module Foreign.Marshal.Pool (
   
   Pool,
   newPool,
   freePool,
   withPool,
   
   pooledMalloc,
   pooledMallocBytes,
   pooledRealloc,
   pooledReallocBytes,
   pooledMallocArray,
   pooledMallocArray0,
   pooledReallocArray,
   pooledReallocArray0,
   
   pooledNew,
   pooledNewArray,
   pooledNewArray0
) where
import GHC.Base              ( Int, Monad(..), (.), liftM, not )
import GHC.Err               ( undefined )
import GHC.Exception         ( throw )
import GHC.IO                ( IO, mask, catchAny )
import GHC.IORef             ( IORef, newIORef, readIORef, writeIORef )
import GHC.List              ( elem, length )
import GHC.Num               ( Num(..) )
import Data.OldList          ( delete )
import Foreign.Marshal.Alloc ( mallocBytes, reallocBytes, free )
import Foreign.Marshal.Array ( pokeArray, pokeArray0 )
import Foreign.Marshal.Error ( throwIf )
import Foreign.Ptr           ( Ptr, castPtr )
import Foreign.Storable      ( Storable(sizeOf, poke) )
-- | A memory pool.
--
-- Internally it is a mutable list of every pointer allocated through the
-- @pooled*@ functions, so that all of them can be released together by
-- 'freePool'.
newtype Pool = Pool (IORef [Ptr ()])
-- | Create a new memory pool that does not yet own any allocations.
newPool :: IO Pool
newPool = do
   ref <- newIORef []
   return (Pool ref)
-- | Deallocate every block that was allocated in the pool and leave the
-- pool empty.
--
-- The pointer list is cleared /before/ the blocks are freed. As a result,
-- calling 'freePool' a second time on the same pool no longer frees the
-- same pointers twice. One such case is an explicit call inside 'withPool',
-- which frees the pool again when the action finishes. Pointers obtained
-- from the pool are dangling after this call. A later 'pooledRealloc' on
-- one of them is rejected by the pool's membership check instead of
-- touching freed memory.
freePool :: Pool -> IO ()
freePool (Pool pool) = do
      ptrs <- readIORef pool
      -- Forget the pointers first so no later call can free them again.
      writeIORef pool []
      freeAll ptrs
   where freeAll []     = return ()
         freeAll (p:ps) = free p >> freeAll ps
-- | Run an action with a freshly created pool and free the pool afterwards.
--
-- The pool is freed whether the action returns normally or throws. If it
-- throws, the pool is freed and then the exception is re-thrown. Pool
-- creation and cleanup happen inside 'mask'. Only the user action runs
-- under @restore@, i.e. with the caller's original masking state.
withPool :: (Pool -> IO b) -> IO b
withPool act = mask (\restore -> do
   pool <- newPool
   let body = restore (act pool)
   result <- body `catchAny` (\e -> freePool pool >> throw e)
   freePool pool
   return result)
-- | Allocate space in the pool for a single value of type @a@.
-- The size comes from the 'Storable' instance for @a@.
pooledMalloc :: forall a . Storable a => Pool -> IO (Ptr a)
pooledMalloc pool = pooledMallocBytes pool elemSize
   where elemSize = sizeOf (undefined :: a)
-- | Allocate the given number of bytes in the pool.
-- The new pointer is recorded in the pool so that 'freePool' releases it.
pooledMallocBytes :: Pool -> Int -> IO (Ptr a)
pooledMallocBytes (Pool pool) size = do
   raw <- mallocBytes size
   -- Prepend the new block to the pool's list of owned pointers.
   readIORef pool >>= writeIORef pool . (raw :)
   return (castPtr raw)
-- | Resize a block allocated in the pool so that it holds exactly one
-- value of type @a@.
pooledRealloc :: forall a . Storable a => Pool -> Ptr a -> IO (Ptr a)
pooledRealloc pool ptr = pooledReallocBytes pool ptr elemSize
   where elemSize = sizeOf (undefined :: a)
-- | Resize a block previously allocated in this pool to the given number
-- of bytes.
--
-- Fails via 'throwIf' with the message @pointer not in pool@ if the
-- pointer is not tracked by the pool. Otherwise the old pointer is removed
-- from the pool's list. The pointer returned by 'reallocBytes' is recorded
-- in its place and returned.
pooledReallocBytes :: Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocBytes (Pool pool) ptr size = do
   let old = castPtr ptr
       notTracked ptrs = not (old `elem` ptrs)
   _ <- throwIf notTracked (\_ -> "pointer not in pool") (readIORef pool)
   fresh <- reallocBytes old size
   remaining <- liftM (delete old) (readIORef pool)
   writeIORef pool (fresh : remaining)
   return (castPtr fresh)
-- | Allocate room in the pool for @size@ values of type @a@.
pooledMallocArray :: forall a . Storable a => Pool -> Int -> IO (Ptr a)
pooledMallocArray pool size = pooledMallocBytes pool totalBytes
   where totalBytes = size * sizeOf (undefined :: a)
-- | Like 'pooledMallocArray', but reserves one extra element after the
-- @size@ requested ones (used for a terminating marker by
-- 'pooledNewArray0').
pooledMallocArray0 :: Storable a => Pool -> Int -> IO (Ptr a)
pooledMallocArray0 pool size = pooledMallocArray pool (1 + size)
-- | Resize a pool-allocated block so that it holds @size@ values of
-- type @a@.
pooledReallocArray :: forall a . Storable a => Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocArray pool ptr size = pooledReallocBytes pool ptr totalBytes
   where totalBytes = size * sizeOf (undefined :: a)
-- | Like 'pooledReallocArray', but keeps room for one extra element after
-- the @size@ requested ones.
pooledReallocArray0 :: Storable a => Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocArray0 pool ptr size = pooledReallocArray pool ptr (1 + size)
-- | Allocate a single value in the pool, store @val@ into it, and return
-- the pointer.
pooledNew :: Storable a => Pool -> a -> IO (Ptr a)
pooledNew pool val =
   pooledMalloc pool >>= \cell -> poke cell val >> return cell
-- | Allocate an array in the pool sized to the given list, copy the list
-- elements into it, and return the pointer.
pooledNewArray :: Storable a => Pool -> [a] -> IO (Ptr a)
pooledNewArray pool vals = do
   let count = length vals
   arr <- pooledMallocArray pool count
   pokeArray arr vals
   return arr
-- | Allocate an array in the pool with room for the list plus one extra
-- element. The list elements are written first and @marker@ after them.
-- Returns the pointer.
pooledNewArray0 :: Storable a => Pool -> a -> [a] -> IO (Ptr a)
pooledNewArray0 pool marker vals = do
   let count = length vals
   arr <- pooledMallocArray0 pool count
   pokeArray0 marker arr vals
   return arr