Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sockpair #554

Merged
merged 2 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,7 @@ unpackBits ((k,v):xs) r
-- SockAddr

instance Show SockAddr where
#if defined(DOMAIN_SOCKET_SUPPORT)
showsPrec _ (SockAddrUnix str) = showString str
#else
showsPrec _ SockAddrUnix{} = error "showsPrec: not supported"
#endif
showsPrec _ (SockAddrInet port ha)
= showHostAddress ha
. showString ":"
Expand Down
18 changes: 0 additions & 18 deletions Network/Socket/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ import GHC.IO (IO (..))

import qualified Text.Read as P

#if defined(DOMAIN_SOCKET_SUPPORT)
import Foreign.Marshal.Array
#endif

import Network.Socket.Imports

Expand Down Expand Up @@ -1075,11 +1073,7 @@ isSupportedSockAddr :: SockAddr -> Bool
isSupportedSockAddr addr = case addr of
SockAddrInet{} -> True
SockAddrInet6{} -> True
#if defined(DOMAIN_SOCKET_SUPPORT)
SockAddrUnix{} -> True
#else
SockAddrUnix{} -> False
#endif

instance SocketAddress SockAddr where
sizeOfSocketAddress = sizeOfSockAddr
Expand All @@ -1098,7 +1092,6 @@ type CSaFamily = (#type sa_family_t)
-- 'SockAddr'. This function differs from 'Foreign.Storable.sizeOf'
-- in that the value of the argument /is/ used.
sizeOfSockAddr :: SockAddr -> Int
#if defined(DOMAIN_SOCKET_SUPPORT)
# ifdef linux_HOST_OS
-- http://man7.org/linux/man-pages/man7/unix.7.html says:
-- "an abstract socket address is distinguished (from a
Expand All @@ -1118,9 +1111,6 @@ sizeOfSockAddr (SockAddrUnix path) =
# else
sizeOfSockAddr SockAddrUnix{} = #const sizeof(struct sockaddr_un)
# endif
#else
sizeOfSockAddr SockAddrUnix{} = error "sizeOfSockAddr: not supported"
#endif
sizeOfSockAddr SockAddrInet{} = #const sizeof(struct sockaddr_in)
sizeOfSockAddr SockAddrInet6{} = #const sizeof(struct sockaddr_in6)

Expand All @@ -1135,10 +1125,8 @@ withSockAddr addr f = do
-- structure, and attempting to do so could overflow the allocated storage
-- space. This constant holds the maximum allowable path length.
--
#if defined(DOMAIN_SOCKET_SUPPORT)
unixPathMax :: Int
unixPathMax = #const sizeof(((struct sockaddr_un *)NULL)->sun_path)
#endif

-- We can't write an instance of 'Storable' for 'SockAddr' because
-- @sockaddr@ is a sum type of variable size but
Expand All @@ -1149,7 +1137,6 @@ unixPathMax = #const sizeof(((struct sockaddr_un *)NULL)->sun_path)

-- | Write the given 'SockAddr' to the given memory location.
pokeSockAddr :: Ptr a -> SockAddr -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
pokeSockAddr p sa@(SockAddrUnix path) = do
when (length path > unixPathMax) $ error
$ "pokeSockAddr: path is too long in SockAddrUnix " <> show path
Expand All @@ -1162,9 +1149,6 @@ pokeSockAddr p sa@(SockAddrUnix path) = do
let pathC = map castCharToCChar path
-- the buffer is already filled with nulls.
pokeArray ((#ptr struct sockaddr_un, sun_path) p) pathC
#else
pokeSockAddr _ SockAddrUnix{} = error "pokeSockAddr: not supported"
#endif
pokeSockAddr p (SockAddrInet port addr) = do
zeroMemory p (#const sizeof(struct sockaddr_in))
#if defined(HAVE_STRUCT_SOCKADDR_SA_LEN)
Expand All @@ -1189,11 +1173,9 @@ peekSockAddr :: Ptr SockAddr -> IO SockAddr
peekSockAddr p = do
family <- (#peek struct sockaddr, sa_family) p
case family :: CSaFamily of
#if defined(DOMAIN_SOCKET_SUPPORT)
(#const AF_UNIX) -> do
str <- peekCAString ((#ptr struct sockaddr_un, sun_path) p)
return (SockAddrUnix str)
#endif
(#const AF_INET) -> do
addr <- (#peek struct sockaddr_in, sin_addr) p
port <- (#peek struct sockaddr_in, sin_port) p
Expand Down
51 changes: 27 additions & 24 deletions Network/Socket/Unix.hsc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}

#include "HsNet.h"
##include "HsNetDef.h"
Expand All @@ -13,30 +14,32 @@ module Network.Socket.Unix (
, getPeerEid
) where

import System.Posix.Types (Fd(..))

import Foreign.Marshal.Alloc (allocaBytes)
import Network.Socket.Buffer
import Network.Socket.Fcntl
import Network.Socket.Imports
import Network.Socket.Types
import System.Posix.Types (Fd(..))

#if defined(mingw32_HOST_OS)
import Network.Socket.Syscall
import Network.Socket.Win32.Cmsg
import System.Directory
import System.IO
import System.IO.Temp
#else
import Foreign.Marshal.Array (peekArray)
import Network.Socket.Internal
import Network.Socket.Posix.Cmsg
#endif
import Network.Socket.Types

#if defined(HAVE_GETPEEREID)
import System.IO.Error (catchIOError)
#endif
#ifdef HAVE_GETPEEREID
import Foreign.Marshal.Alloc (alloca)
#endif
#ifdef DOMAIN_SOCKET_SUPPORT
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Marshal.Array (peekArray)

import Network.Socket.Fcntl
import Network.Socket.Internal
#endif
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
import Network.Socket.Options
#endif
Expand Down Expand Up @@ -126,11 +129,7 @@ getPeerEid _ = return (0, 0)
--
-- Since 2.7.0.0.
isUnixDomainSocketAvailable :: Bool
#if defined(DOMAIN_SOCKET_SUPPORT)
isUnixDomainSocketAvailable = True
#else
isUnixDomainSocketAvailable = False
#endif

data NullSockAddr = NullSockAddr

Expand All @@ -143,33 +142,25 @@ instance SocketAddress NullSockAddr where
-- Use this function in the case where 'isUnixDomainSocketAvailable' is
-- 'True'.
sendFd :: Socket -> CInt -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do
let cmsg = encodeCmsg $ Fd outfd
sendBufMsg s NullSockAddr [(buf,dummyBufSize)] [cmsg] mempty
where
dummyBufSize = 1
#else
sendFd _ _ = error "Network.Socket.sendFd"
#endif

-- | Receive a file descriptor over a UNIX-domain socket. Note that the resulting
-- file descriptor may have to be put into non-blocking mode in order to be
-- used safely. See 'setNonBlockIfNeeded'.
-- Use this function in the case where 'isUnixDomainSocketAvailable' is
-- 'True'.
recvFd :: Socket -> IO CInt
#if defined(DOMAIN_SOCKET_SUPPORT)
recvFd s = allocaBytes dummyBufSize $ \buf -> do
(NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf,dummyBufSize)] 32 mempty
case (lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg) :: Maybe Fd of
Nothing -> return (-1)
Just (Fd fd) -> return fd
where
dummyBufSize = 16
#else
recvFd _ = error "Network.Socket.recvFd"
#endif

-- | Build a pair of connected socket objects.
-- For portability, use this function in the case
Expand All @@ -179,7 +170,21 @@ socketPair :: Family -- Family Name (usually AF_UNIX)
-> SocketType -- Socket Type (usually Stream)
-> ProtocolNumber -- Protocol Number
-> IO (Socket, Socket) -- unnamed and connected.
#if defined(DOMAIN_SOCKET_SUPPORT)
#if defined(mingw32_HOST_OS)
socketPair _ _ _ = withSystemTempFile "temp-for-pair" $ \file hdl -> do
hClose hdl
removeFile file
listenSock <- socket AF_UNIX Stream defaultProtocol
bind listenSock $ SockAddrUnix file
listen listenSock 10
clientSock <- socket AF_UNIX Stream defaultProtocol
connect clientSock $ SockAddrUnix file
(serverSock, _ :: SockAddr) <- accept listenSock
close listenSock
withFdSocket clientSock setNonBlockIfNeeded
withFdSocket serverSock setNonBlockIfNeeded
return (clientSock, serverSock)
#else
socketPair family stype protocol =
allocaBytes (2 * sizeOf (1 :: CInt)) $ \ fdArr -> do
let c_stype = packSocketType stype
Expand All @@ -194,6 +199,4 @@ socketPair family stype protocol =

foreign import ccall unsafe "socketpair"
c_socketpair :: CInt -> CInt -> CInt -> Ptr CInt -> IO CInt
#else
socketPair _ _ _ = error "Network.Socket.socketPair"
#endif
2 changes: 0 additions & 2 deletions include/HsNetDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#undef PACKAGE_TARNAME
#undef PACKAGE_VERSION

#define DOMAIN_SOCKET_SUPPORT 1

#if defined(HAVE_STRUCT_UCRED) && HAVE_DECL_SO_PEERCRED
# define HAVE_STRUCT_UCRED_SO_PEERCRED 1
#else
Expand Down
3 changes: 3 additions & 0 deletions network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ library
cpp-options: -D_WIN32_WINNT=0x0600
cc-options: -D_WIN32_WINNT=0x0600

build-depends:
temporary

test-suite spec
type: exitcode-stdio-1.0
main-is: Spec.hs
Expand Down