Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add socket diagnosis for UDP #927

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions inet_diag.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ type InetDiagTCPInfoResp struct {
TCPInfo *TCPInfo
TCPBBRInfo *TCPBBRInfo
}

type InetDiagUDPInfoResp struct {
InetDiagMsg *Socket
Memory *MemInfo
}
142 changes: 117 additions & 25 deletions socket_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,18 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {

// SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info.
func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) {
// Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{
Family: family,
Protocol: unix.IPPROTO_TCP,
Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)),
States: uint32(0xfff), // all states
})

// Do the query and parse the result
var result []*InetDiagTCPInfoResp
err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error {
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
Expand All @@ -201,8 +211,18 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) {

// SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket.
func SocketDiagTCP(family uint8) ([]*Socket, error) {
// Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{
Family: family,
Protocol: unix.IPPROTO_TCP,
Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)),
States: uint32(0xfff), // all states
})

// Do the query and parse the result
var result []*Socket
err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error {
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
Expand All @@ -216,21 +236,82 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) {
return result, nil
}

// socketDiagTCPExecutor requests INET_DIAG_INFO for TCP protocol for specified family type.
func socketDiagTCPExecutor(family uint8, receiver func(syscall.NetlinkMessage) error) error {
s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
// SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info.
func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) {
// Construct the request
var extensions uint8
extensions = 1 << (INET_DIAG_VEGASINFO - 1)
extensions |= 1 << (INET_DIAG_INFO - 1)
extensions |= 1 << (INET_DIAG_MEMINFO - 1)

req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{
Family: family,
Protocol: unix.IPPROTO_UDP,
Ext: extensions,
States: uint32(0xfff), // all states
})

// Do the query and parse the result
var result []*InetDiagUDPInfoResp
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
}
attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:])
if err != nil {
return err
}

res, err := attrsToInetDiagUDPInfoResp(attrs, sockInfo)
if err != nil {
return err
}

result = append(result, res)
return nil
})
if err != nil {
return err
return nil, err
}
defer s.Close()
return result, nil
}

// SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket.
func SocketDiagUDP(family uint8) ([]*Socket, error) {
// Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{
Family: family,
Protocol: unix.IPPROTO_TCP,
Protocol: unix.IPPROTO_UDP,
Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)),
States: uint32(0xfff), // All TCP states
States: uint32(0xfff), // all states
})

// Do the query and parse the result
var result []*Socket
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
}
result = append(result, sockInfo)
return nil
})
if err != nil {
return nil, err
}
return result, nil
}

// socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request.
func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error {
s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
if err != nil {
return err
}
defer s.Close()
s.Send(req)

loop:
Expand All @@ -240,7 +321,7 @@ loop:
return err
}
if from.Pid != nl.PidKernel {
return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
return fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
}
if len(msgs) == 0 {
return errors.New("no message nor error from netlink")
Expand All @@ -263,29 +344,40 @@ loop:
}

func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) {
var tcpInfo *TCPInfo
var tcpBBRInfo *TCPBBRInfo
info := &InetDiagTCPInfoResp{
InetDiagMsg: sockInfo,
}
for _, a := range attrs {
if a.Attr.Type == INET_DIAG_INFO {
tcpInfo = &TCPInfo{}
if err := tcpInfo.deserialize(a.Value); err != nil {
switch a.Attr.Type {
case INET_DIAG_INFO:
info.TCPInfo = &TCPInfo{}
if err := info.TCPInfo.deserialize(a.Value); err != nil {
return nil, err
}
case INET_DIAG_BBRINFO:
info.TCPBBRInfo = &TCPBBRInfo{}
if err := info.TCPBBRInfo.deserialize(a.Value); err != nil {
return nil, err
}
continue
}
}

return info, nil
}

if a.Attr.Type == INET_DIAG_BBRINFO {
tcpBBRInfo = &TCPBBRInfo{}
if err := tcpBBRInfo.deserialize(a.Value); err != nil {
func attrsToInetDiagUDPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagUDPInfoResp, error) {
info := &InetDiagUDPInfoResp{
InetDiagMsg: sockInfo,
}
for _, a := range attrs {
switch a.Attr.Type {
case INET_DIAG_MEMINFO:
info.Memory = &MemInfo{}
if err := info.Memory.deserialize(a.Value); err != nil {
return nil, err
}
continue
}
}

return &InetDiagTCPInfoResp{
InetDiagMsg: sockInfo,
TCPInfo: tcpInfo,
TCPBBRInfo: tcpBBRInfo,
}, nil
return info, nil
}
16 changes: 16 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build linux
// +build linux

package netlink
Expand Down Expand Up @@ -75,3 +76,18 @@ func TestSocketDiagTCPInfo(t *testing.T) {
}
}
}

func TestSocketDiagUDPnfo(t *testing.T) {
for _, want := range []uint8{syscall.AF_INET, syscall.AF_INET6} {
result, err := SocketDiagUDPInfo(want)
if err != nil {
t.Fatal(err)
}

for _, r := range result {
if got := r.InetDiagMsg.Family; got != want {
t.Fatalf("protocol family = %v, want %v", got, want)
}
}
}
}
8 changes: 8 additions & 0 deletions tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ type TCPBBRInfo struct {
BBRPacingGain uint32
BBRCwndGain uint32
}

// According to https://man7.org/linux/man-pages/man7/sock_diag.7.html
type MemInfo struct {
RMem uint32
WMem uint32
FMem uint32
TMem uint32
}
15 changes: 15 additions & 0 deletions tcp_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

const (
tcpBBRInfoLen = 20
memInfoLen = 16
)

func checkDeserErr(err error) error {
Expand Down Expand Up @@ -351,3 +352,17 @@ func (t *TCPBBRInfo) deserialize(b []byte) error {

return nil
}

func (m *MemInfo) deserialize(b []byte) error {
if len(b) != memInfoLen {
return errors.New("Invalid length")
}

rb := bytes.NewBuffer(b)
m.RMem = native.Uint32(rb.Next(4))
m.WMem = native.Uint32(rb.Next(4))
m.FMem = native.Uint32(rb.Next(4))
m.TMem = native.Uint32(rb.Next(4))

return nil
}
Loading