Skip to content

Commit

Permalink
Fix SetSendTimeout/SetReceiveTimeout
Browse files Browse the repository at this point in the history
They were implemented using SO_SNDTIMEO/SO_RCVTIMEO on the
socket descriptor - but that doesn't work now the socket is
non-blocking. Instead, set deadlines on the file read/write.

Signed-off-by: Rob Murray <rob.murray@docker.com>
  • Loading branch information
robmry authored and aboch committed Sep 4, 2024
1 parent 0cd1f79 commit e194da5
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 40 deletions.
41 changes: 12 additions & 29 deletions handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"

"github.com/vishvananda/netlink/nl"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
)

func TestHandleCreateClose(t *testing.T) {
Expand Down Expand Up @@ -122,13 +120,22 @@ func TestHandleTimeout(t *testing.T) {
defer h.Close()

for _, sh := range h.sockets {
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0})
verifySockTimeVal(t, sh.Socket, time.Duration(0))
}

h.SetSocketTimeout(2*time.Second + 8*time.Millisecond)
const timeout = 2*time.Second + 8*time.Millisecond
h.SetSocketTimeout(timeout)

for _, sh := range h.sockets {
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000})
verifySockTimeVal(t, sh.Socket, timeout)
}
}

func verifySockTimeVal(t *testing.T, socket *nl.NetlinkSocket, expTimeout time.Duration) {
t.Helper()
send, receive := socket.GetTimeouts()
if send != expTimeout || receive != expTimeout {
t.Fatalf("Expected timeout: %v, got Send: %v, Receive: %v", expTimeout, send, receive)
}
}

Expand Down Expand Up @@ -157,30 +164,6 @@ func TestHandleReceiveBuffer(t *testing.T) {
}
}

func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) {
var (
tr unix.Timeval
v = uint32(0x10)
)
_, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
if errno != 0 {
t.Fatal(errno)
}

if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
}

_, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
if errno != 0 {
t.Fatal(errno)
}

if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
}
}

var (
iter = 10
numThread = uint32(4)
Expand Down
73 changes: 62 additions & 11 deletions nl/nl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ package nl
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"os"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"

"github.com/vishvananda/netns"
Expand Down Expand Up @@ -656,9 +658,11 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
}

type NetlinkSocket struct {
fd int32
file *os.File
lsa unix.SockaddrNetlink
fd int32
file *os.File
lsa unix.SockaddrNetlink
sendTimeout int64 // Access using atomic.Load/StoreInt64
receiveTimeout int64 // Access using atomic.Load/StoreInt64
sync.Mutex
}

Expand Down Expand Up @@ -802,8 +806,44 @@ func (s *NetlinkSocket) GetFd() int {
return int(s.fd)
}

func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) {
return time.Duration(atomic.LoadInt64(&s.sendTimeout)),
time.Duration(atomic.LoadInt64(&s.receiveTimeout))
}

func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
rawConn, err := s.file.SyscallConn()
if err != nil {
return err
}
var (
deadline time.Time
innerErr error
)
sendTimeout := atomic.LoadInt64(&s.sendTimeout)
if sendTimeout != 0 {
deadline = time.Now().Add(time.Duration(sendTimeout))
}
if err := s.file.SetWriteDeadline(deadline); err != nil {
return err
}
serializedReq := request.Serialize()
err = rawConn.Write(func(fd uintptr) (done bool) {
innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa)
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
return innerErr
}
if err != nil {
// The timeout was previously implemented using SO_SNDTIMEO on a blocking
// socket. So, continue to return EAGAIN when the timeout is reached.
if errors.Is(err, os.ErrDeadlineExceeded) {
return unix.EAGAIN
}
return err
}
return nil
}

func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
Expand All @@ -812,20 +852,33 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
return nil, nil, err
}
var (
deadline time.Time
fromAddr *unix.SockaddrNetlink
rb [RECEIVE_BUFFER_SIZE]byte
nr int
from unix.Sockaddr
innerErr error
)
receiveTimeout := atomic.LoadInt64(&s.receiveTimeout)
if receiveTimeout != 0 {
deadline = time.Now().Add(time.Duration(receiveTimeout))
}
if err := s.file.SetReadDeadline(deadline); err != nil {
return nil, nil, err
}
err = rawConn.Read(func(fd uintptr) (done bool) {
nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
return nil, nil, innerErr
}
if err != nil {
// The timeout was previously implemented using SO_RCVTIMEO on a blocking
// socket. So, continue to return EAGAIN when the timeout is reached.
if errors.Is(err, os.ErrDeadlineExceeded) {
return nil, nil, unix.EAGAIN
}
return nil, nil, err
}
fromAddr, ok := from.(*unix.SockaddrNetlink)
Expand All @@ -847,16 +900,14 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli

// SetSendTimeout allows to set a send timeout on the socket
func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
// remains stuck on a send on a closed fd
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
atomic.StoreInt64(&s.sendTimeout, timeout.Nano())
return nil
}

// SetReceiveTimeout allows to set a receive timeout on the socket
func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
// remains stuck on a recvmsg on a closed fd
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
atomic.StoreInt64(&s.receiveTimeout, timeout.Nano())
return nil
}

// SetReceiveBufferSize allows to set a receive buffer size on the socket
Expand Down
63 changes: 63 additions & 0 deletions nl/nl_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,69 @@ func TestIfSocketCloses(t *testing.T) {
}
}

func TestReceiveTimeout(t *testing.T) {
nlSock, err := getNetlinkSocket(unix.NETLINK_ROUTE)
if err != nil {
t.Fatalf("Error creating the socket: %v", err)
}
// Even if the test fails because the timeout doesn't work, closing the
// socket at the end of the test should result in an EAGAIN (as long as
// TestIfSocketCloses completed, otherwise this test will leak the
// goroutines running the Receive).
defer nlSock.Close()
const failAfter = time.Second

tests := []struct {
name string
timeout time.Duration
}{
{
name: "1us timeout", // The smallest value accepted by Handle.SetSocketTimeout
timeout: time.Microsecond,
},
{
name: "100ms timeout",
timeout: 100 * time.Millisecond,
},
{
name: "500ms timeout",
timeout: 500 * time.Millisecond,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
timeout := unix.NsecToTimeval(int64(tc.timeout))
nlSock.SetReceiveTimeout(&timeout)

doneC := make(chan time.Duration)
errC := make(chan error)
go func() {
start := time.Now()
_, _, err := nlSock.Receive()
dur := time.Since(start)
if err != unix.EAGAIN {
errC <- err
return
}
doneC <- dur
}()

failTimerC := time.After(failAfter)
select {
case dur := <-doneC:
if dur < tc.timeout || dur > (tc.timeout+(100*time.Millisecond)) {
t.Fatalf("Expected timeout %v got %v", tc.timeout, dur)
}
case err := <-errC:
t.Fatalf("Expected EAGAIN, but got: %v", err)
case <-failTimerC:
t.Fatalf("No timeout received")
}
})
}
}

func (msg *CnMsgOp) write(b []byte) {
native := NativeEndian()
native.PutUint32(b[0:4], msg.ID.Idx)
Expand Down

0 comments on commit e194da5

Please sign in to comment.