{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}

#include "HsNetDef.h"

module Network.Socket.Shutdown (
    ShutdownCmd(..)
  , shutdown
  , gracefulClose
  ) where

import qualified Control.Exception as E
import Foreign.Marshal.Alloc (mallocBytes, free)

import Control.Concurrent (threadDelay)
#if !defined(mingw32_HOST_OS)
import Control.Concurrent (MVar, putMVar, takeMVar, tryPutMVar, newEmptyMVar)
import qualified GHC.Event as Ev
import System.Posix.Types (Fd(..))
#endif

import Network.Socket.Buffer
import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Types

-- | Which direction(s) of a connection 'shutdown' should close.
data ShutdownCmd
  = ShutdownReceive -- ^ Disallow further receives.
  | ShutdownSend    -- ^ Disallow further sends.
  | ShutdownBoth    -- ^ Disallow further sends and receives.
  deriving (Typeable)

-- | Encode a 'ShutdownCmd' as the @how@ argument passed to 'c_shutdown'.
--
-- The literal values 0, 1 and 2 are handed straight to the C call.
-- NOTE(review): these are assumed to match @SHUT_RD@ / @SHUT_WR@ /
-- @SHUT_RDWR@ (and Winsock's @SD_RECEIVE@ / @SD_SEND@ / @SD_BOTH@) on every
-- supported platform; confirm against the platform headers.
sdownCmdToInt :: ShutdownCmd -> CInt
sdownCmdToInt ShutdownReceive = 0
sdownCmdToInt ShutdownSend    = 1
sdownCmdToInt ShutdownBoth    = 2

-- | Shut down one or both halves of the connection, depending on the
-- second argument to the function.  If the second argument is
-- 'ShutdownReceive', further receives are disallowed.  If it is
-- 'ShutdownSend', further sends are disallowed.  If it is
-- 'ShutdownBoth', further sends and receives are disallowed.
--
-- The result of the underlying C @shutdown@ call is checked with
-- 'throwSocketErrorIfMinus1Retry_'.
shutdown :: Socket -> ShutdownCmd -> IO ()
shutdown s stype = withFdSocket s $ \fd ->
  throwSocketErrorIfMinus1Retry_ "Network.Socket.shutdown" $
    c_shutdown fd (sdownCmdToInt stype)

-- | Raw binding to the C @shutdown(fd, how)@ call; returns the C result
-- code unchanged. 'shutdown' is the checked wrapper around it.
-- CALLCONV is a macro from HsNetDef.h (presumably selects the platform
-- calling convention). NOTE(review): imported @unsafe@, which assumes the
-- call never blocks — confirm for all targets.
foreign import CALLCONV unsafe "shutdown"
  c_shutdown :: CInt -> CInt -> IO CInt

#if !defined(mingw32_HOST_OS)
-- | Signal delivered through an 'MVar' to 'gracefulClose' while it waits for
-- the peer's FIN using the event manager: 'MoreData' is sent by the
-- read-readiness callback on the socket's fd, 'TimeoutTripped' by the
-- timeout callback.
data Wait = MoreData | TimeoutTripped
#endif

-- | Close a socket gracefully.
--
--   This sends a TCP FIN (via 'shutdown' with 'ShutdownSend') and then
--   waits for the peer's FIN. The second argument is the timeout for
--   receiving that FIN, in milliseconds. In both the normal and the error
--   case the socket is closed at the end.
--
--   Data that arrives instead of (or before) the FIN is read once into a
--   scratch buffer and discarded: no application can consume it any more.
--
--   When an event manager is available (non-Windows, threaded RTS) the wait
--   is driven by a timer plus a one-shot read-readiness callback. Otherwise
--   the socket is polled every 200 ms until something other than "no data
--   yet" is seen or the timeout would be exceeded.
--
--   Since: 3.1.1.0
gracefulClose :: Socket -> Int -> IO ()
gracefulClose s tmout = sendRecvFIN `E.finally` close s
  where
    sendRecvFIN :: IO ()
    sendRecvFIN = do
        -- Sending TCP FIN.
        shutdown s ShutdownSend
        -- Waiting TCP FIN.
#if defined(mingw32_HOST_OS)
        recvEOFloop
#else
        mevmgr <- Ev.getSystemEventManager
        case mevmgr of
          Nothing    -> recvEOFloop     -- non-threaded RTS
          Just evmgr -> recvEOFev evmgr
#endif

    -- milliseconds. Taken from BSD fast clock value.
    clock :: Int
    clock = 200

    -- Polling fallback: try a non-blocking receive, then sleep 'clock' ms
    -- and retry while the result is -1 and the timeout is not yet reached.
    recvEOFloop :: IO ()
    recvEOFloop = E.bracket (mallocBytes bufSize) free $ loop 0
      where
        loop delay buf = do
            -- We don't check the (positive) length.
            -- In normal case, it's 0. That is, only FIN is received.
            -- In error cases, data is available. But there is no
            -- application which can read it. So, let's stop receiving
            -- to prevent attacks.
            r <- recvBufNoWait s buf bufSize
            let delay' = delay + clock
            -- NOTE(review): -1 is presumably recvBufNoWait's "nothing
            -- available yet" result; confirm in Network.Socket.Buffer.
            when (r == -1 && delay' < tmout) $ do
                threadDelay (clock * 1000)
                loop delay' buf

#if !defined(mingw32_HOST_OS)
    -- Event-manager path: block until the timer fires or the socket
    -- becomes readable, whichever is signalled first.
    recvEOFev :: Ev.EventManager -> IO ()
    recvEOFev evmgr = do
        tmmgr <- Ev.getSystemTimerManager
        mvar <- newEmptyMVar
        E.bracket (register evmgr tmmgr mvar) (unregister evmgr tmmgr) $ \_ -> do
            wait <- takeMVar mvar
            case wait of
              TimeoutTripped -> return ()
              -- We don't check the (positive) length.
              -- In normal case, it's 0. That is, only FIN is received.
              -- In error cases, data is available. But there is no
              -- application which can read it. So, let's stop receiving
              -- to prevent attacks.
              MoreData       -> E.bracket (mallocBytes bufSize)
                                          free
                                          (\buf -> void $ recvBufNoWait s buf bufSize)

    -- Arm both wake-up sources. Each callback signals 'mvar'.
    register :: Ev.EventManager -> Ev.TimerManager -> MVar Wait
             -> IO (Ev.TimeoutKey, Ev.FdKey)
    register evmgr tmmgr mvar = do
        -- GHC's timer/IO managers run these callbacks on their own loop
        -- threads, and both may fire before 'unregister' runs. A plain
        -- 'putMVar' would then block a manager thread on the already-full
        -- MVar (forever, if the waiter was killed). 'tryPutMVar' keeps the
        -- first signal and silently drops the second.
        --
        -- millisecond to microsecond
        key1 <- Ev.registerTimeout tmmgr (tmout * 1000) $
            void $ tryPutMVar mvar TimeoutTripped
        key2 <- withFdSocket s $ \fd' -> do
            let callback _ _ = void $ tryPutMVar mvar MoreData
                fd = Fd fd'
#if __GLASGOW_HASKELL__ < 709
            Ev.registerFd evmgr callback fd Ev.evtRead
#else
            Ev.registerFd evmgr callback fd Ev.evtRead Ev.OneShot
#endif
        return (key1, key2)

    -- Disarm whatever 'register' armed.
    unregister :: Ev.EventManager -> Ev.TimerManager
               -> (Ev.TimeoutKey, Ev.FdKey) -> IO ()
    unregister evmgr tmmgr (key1, key2) = do
        Ev.unregisterTimeout tmmgr key1
        Ev.unregisterFd evmgr key2
#endif

    -- Don't use 4092 here. The GHC runtime takes the global lock
    -- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
    bufSize :: Int
    bufSize = 1024