diff --git a/go.mod b/go.mod index c6c4ff25c0b..13e3e74a3da 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/fsnotify/fsnotify v1.4.9 github.com/go-ping/ping v0.0.0-20210506233800-ff8be3320020 github.com/google/go-cmp v0.5.5 + github.com/google/gopacket v1.1.19 github.com/google/renameio v1.0.1 github.com/insomniacslk/dhcp v0.0.0-20210310193751-cfd4d47082c2 github.com/kardianos/service v1.2.0 diff --git a/go.sum b/go.sum index 017d64771e0..b3dcd9f3983 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,8 @@ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/renameio v1.0.1 h1:Lh/jXZmvZxb0BBeSY5VKEfidcbcbenKjZFzM/q0fSeU= @@ -268,6 +270,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -377,6 +380,7 @@ golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/dhcpd/conn_unix.go b/internal/dhcpd/conn_unix.go index 7fa6ee32f5a..8e6fc0f32aa 100644 --- a/internal/dhcpd/conn_unix.go +++ b/internal/dhcpd/conn_unix.go @@ -4,7 +4,6 @@ package dhcpd import ( - "encoding/binary" "fmt" "net" "os" @@ -12,15 +11,17 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4/server4" "github.com/mdlayher/ethernet" "github.com/mdlayher/raw" ) -// dhcpUcastAddr is the combination of MAC and IP addresses for responding to +// dhcpUnicastAddr is the combination of MAC and IP addresses for responding to // the unconfigured host. -type dhcpUcastAddr struct { +type dhcpUnicastAddr struct { // raw.Addr is embedded here to make *dhcpUcastAddr a net.Addr without // actually implementing all methods. It also contains the client's // hardware address. @@ -79,15 +80,16 @@ func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error) }, nil } -// wrapErrs is a helper to wrap the errors from two independent connections. -func (c *dhcpConn) wrapErrs(action string, uerr, rerr error) (err error) { +// wrapErrs is a helper to wrap the errors from two independent underlying +// connections. +func (c *dhcpConn) wrapErrs(action string, udpConnErr, rawConnErr error) (err error) { switch { - case uerr != nil && rerr != nil: - return errors.List(fmt.Sprintf("%s both connections", action), uerr, rerr) - case uerr != nil: - return fmt.Errorf("%s udp connection: %w", action, uerr) - case rerr != nil: - return fmt.Errorf("%s raw connection: %w", action, rerr) + case udpConnErr != nil && rawConnErr != nil: + return errors.List(fmt.Sprintf("%s both connections", action), udpConnErr, rawConnErr) + case udpConnErr != nil: + return fmt.Errorf("%s udp connection: %w", action, udpConnErr) + case rawConnErr != nil: + return fmt.Errorf("%s raw connection: %w", action, rawConnErr) default: return nil } @@ -97,7 +99,7 @@ func (c *dhcpConn) wrapErrs(action string, uerr, rerr error) (err error) { // connection to write to based on the type of addr. func (c *dhcpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { switch addr := addr.(type) { - case *dhcpUcastAddr: + case *dhcpUnicastAddr: // Unicast the message to the client's MAC address. Use the raw // connection. // @@ -136,9 +138,9 @@ func (c *dhcpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } // unicast wraps respData with required frames and writes it to the peer. -func (c *dhcpConn) unicast(respData []byte, peer *dhcpUcastAddr) (n int, err error) { +func (c *dhcpConn) unicast(respData []byte, peer *dhcpUnicastAddr) (n int, err error) { var data []byte - data, err = c.buildEtherPacket(respData, peer) + data, err = c.buildEtherPkt(respData, peer) if err != nil { return 0, err } @@ -185,199 +187,47 @@ func (c *dhcpConn) SetWriteDeadline(t time.Time) error { ) } -// ipv4HdrVals describes the values of IPv4 packet's header as defined in -// RFC-791 (https://datatracker.ietf.org/doc/html/rfc791#section-3.1). -type ipv4HdrVals struct { - // src is the Source Address field of the header. - src net.IP - - // dst is the Destination Address field of the header. - dst net.IP - - // fragOffset is the Fragment Offset field of the header. - fragOffset uint16 - - // totalLen is the Total Length field of the header. - totalLen uint16 - - // id is the Identification field of the header. - id uint16 - - // ihl is the Internet Header Length field of the header. - ihl uint8 - - // tos is the Type of Service field of the header. - tos uint8 - - // flags is the Flags field of the header. - flags uint8 - - // ttl is the Time to Live field of the header. - ttl uint8 - - // proto is the Protocol field of the header. - proto uint8 -} - -// ipv4Hdr is used to encode the ipv4HdrVals into the underlying byte slice. -type ipv4Hdr []byte - const ( - // ipv4Sz is the lowest size of an IPv4 packet. - ipv4Sz = 20 - - // ipv4FlagDontFrag is the value for flags to avoid rfragmentation of - // the packet. - // - // See https://datatracker.ietf.org/doc/html/rfc791#section-3.1. - ipv4FlagDontFrag = 1 << 1 - - // ipv4DefaultTTL is the Time to Live values as recommended by RFC-1700 - // (https://datatracker.ietf.org/doc/html/rfc1700). + // ipv4DefaultTTL is the default Time to Live value as recommended by + // RFC-1700 (https://datatracker.ietf.org/doc/html/rfc1700) in seconds. ipv4DefaultTTL = 64 ) -// Indexes of IPv4 header's fields in bytes slice representation. -const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksumOff = 10 - srcAddr = 12 - dstAddr = 16 -) - -// encode writes all the data from i to b. -func (b ipv4Hdr) encode(i ipv4HdrVals) { - b[versIHL] = (4 << 4) | ((i.ihl / 4) & 0xf) - b[tos] = i.tos - binary.BigEndian.PutUint16(b[totalLen:], i.totalLen) - binary.BigEndian.PutUint16(b[id:], i.id) - binary.BigEndian.PutUint16(b[flagsFO:], (uint16(i.flags)<<13)|(i.fragOffset>>3)) - b[ttl] = i.ttl - b[protocol] = i.proto - copy(b[srcAddr:srcAddr+net.IPv4len], i.src) - copy(b[dstAddr:dstAddr+net.IPv4len], i.dst) - - hl := (b[versIHL] & 0x0f) * 4 - binary.BigEndian.PutUint16(b[checksumOff:], ^checksum(b[:hl], 0)) -} - -// udpHdrVals describes the values of UDP header as defined by RFC-768 -// (https://datatracker.ietf.org/doc/html/rfc768). -type udpHdrVals struct { - // srcPort is the Source Port field of the header. - srcPort uint16 - - // dstPort is the Destination Port field of the header. - dstPort uint16 - - // length is the Length field of the header. - length uint16 -} - -// ipv4Hdr is used to encode the udpHdrVals into the underlying byte slice. -type udpHdr []byte - -const ( - // udpSz is the minimum size of a valid udp packet. - udpSz = 8 - - // udpProto is udp's transport protocol number. - udpProto = 17 -) - -// Indexes of UDP header's fields in bytes slice representation. -const ( - udpSrcPort = 0 - udpDstPort = 2 - udpLength = 4 - udpChecksum = 6 -) - -// encode writes all the data from u to b. -func (b udpHdr) encode(u udpHdrVals, src, dst net.IP, proto uint8, payload []byte) { - binary.BigEndian.PutUint16(b[udpSrcPort:], u.srcPort) - binary.BigEndian.PutUint16(b[udpDstPort:], u.dstPort) - binary.BigEndian.PutUint16(b[udpLength:], u.length) - - xsum := checksum([]byte(src), 0) - for _, buf := range [...][]byte{ - []byte(dst), - {0, proto}, - payload, - b[udpLength : udpLength+2], - b[:udpSz], - } { - xsum = checksum(buf, xsum) +func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []byte, err error) { + dhcpPkt := gopacket.NewPacket(payload, layers.LayerTypeDHCPv4, gopacket.DecodeOptions{ + NoCopy: true, + }).Layer(layers.LayerTypeDHCPv4) + dhcpLayer, ok := dhcpPkt.(gopacket.SerializableLayer) + if !ok { + return nil, fmt.Errorf("layer %s is not serializable", dhcpLayer.LayerType()) } - binary.BigEndian.PutUint16(b[udpChecksum:], ^xsum) -} - -// checksum calculates the internet checksum as defined in RFC-1071 -// (https://datatracker.ietf.org/doc/html/rfc1071). -// -// The initial checksum must have been calculated on a byte slice of even -// length. -func checksum(buf []byte, initial uint16) (c uint16) { - v := uint32(initial) - - l := len(buf) - if l%2 != 0 { - l-- - v += uint32(buf[l]) << 8 + udpLayer := &layers.UDP{ + SrcPort: dhcpv4.ServerPort, + DstPort: dhcpv4.ClientPort, } - - for i := 0; i < l; i += 2 { - v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + ipv4Layer := &layers.IPv4{ + Version: uint8(layers.IPProtocolIPv4), + Flags: layers.IPv4DontFragment, + TTL: ipv4DefaultTTL, + Protocol: layers.IPProtocolUDP, + SrcIP: c.srcIP, + DstIP: peer.yiaddr, + } + ethLayer := &layers.Ethernet{ + SrcMAC: c.srcMAC, + DstMAC: peer.HardwareAddr, + EthernetType: layers.EthernetTypeIPv4, } - v = uint32(uint16(v)) + uint32(uint16(v>>16)) - - return uint16(v + v>>16) -} - -// buildEtherPacket wraps the payload with IPv4, UDP and Ethernet frames. -func (c *dhcpConn) buildEtherPacket(payload []byte, peer *dhcpUcastAddr) (pkt []byte, err error) { - pkt = make([]byte, ipv4Sz+udpSz+len(payload)) - - // TODO(e.burkov): Think about generalizing all the layers marshalling - // using interfaces. - - ipLayer := ipv4Hdr(pkt[:ipv4Sz]) - ipLayer.encode(ipv4HdrVals{ - ihl: ipv4Sz, - totalLen: uint16(ipv4Sz + udpSz + len(payload)), - flags: ipv4FlagDontFrag, - ttl: ipv4DefaultTTL, - proto: udpProto, - src: c.srcIP, - dst: peer.yiaddr, - }) - - udpLayer := udpHdr(pkt[ipv4Sz : ipv4Sz+udpSz]) - udpLayer.encode(udpHdrVals{ - srcPort: dhcpv4.ServerPort, - dstPort: dhcpv4.ClientPort, - length: uint16(udpSz + len(payload)), - }, c.srcIP, peer.yiaddr, ipLayer[protocol], payload) - - copy(pkt[ipv4Sz+udpSz:], payload) - - pkt, err = (ðernet.Frame{ - Destination: peer.HardwareAddr, - Source: c.srcMAC, - EtherType: ethernet.EtherTypeIPv4, - Payload: pkt, - }).MarshalBinary() + buf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + }, ethLayer, ipv4Layer, udpLayer, dhcpLayer) if err != nil { - return nil, fmt.Errorf("marshaling ethernet frame: %w", err) + return nil, fmt.Errorf("serializing layers: %w", err) } - return pkt, nil + return buf.Bytes(), nil } diff --git a/internal/dhcpd/conn_unix_test.go b/internal/dhcpd/conn_unix_test.go index 886636ccb6b..477990c8048 100644 --- a/internal/dhcpd/conn_unix_test.go +++ b/internal/dhcpd/conn_unix_test.go @@ -46,11 +46,3 @@ func TestDHCPConn_WriteTo_common(t *testing.T) { assert.Zero(t, n) }) } - -// TODO(e.burkov): Cover network layers' encoders with tests: -// -// ipv4Hdr.encode -// udpHdr.encode -// checksum -// dhcpConn.buildEtherPacket -// diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 7edc4bcdc8c..6c2d4b5ec24 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -984,7 +984,7 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH case !req.IsBroadcast() && req.ClientHWAddr != nil: // Unicast DHCPOFFER and DHCPACK messages to the client's // hardware address and yiaddr. - peer = &dhcpUcastAddr{ + peer = &dhcpUnicastAddr{ Addr: raw.Addr{HardwareAddr: req.ClientHWAddr}, yiaddr: resp.YourIPAddr, } diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index b6081c59f2e..d20e715c410 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -440,7 +440,7 @@ func TestV4Server_Send(t *testing.T) { name: "chaddr", req: &dhcpv4.DHCPv4{ClientHWAddr: knownMAC}, resp: &dhcpv4.DHCPv4{YourIPAddr: knownIP}, - want: &dhcpUcastAddr{ + want: &dhcpUnicastAddr{ Addr: raw.Addr{HardwareAddr: knownMAC}, yiaddr: knownIP, },