diff --git a/CHANGELOG.md b/CHANGELOG.md index 2baec6ebfdc..5e7a7cc4107 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,8 @@ and this project adheres to ### Fixed +- Inconsistent resolving of DHCP clients when the DHCP server is disabled + ([#2934]). - Comment handling in clients' custom upstreams ([#2947]). - Overwriting of DHCPv4 options when using the HTTP API ([#2927]). - Assumption that MAC addresses always have the length of 6 octets ([#2828]). @@ -75,6 +77,7 @@ and this project adheres to [#2838]: https://github.com/AdguardTeam/AdGuardHome/issues/2838 [#2889]: https://github.com/AdguardTeam/AdGuardHome/issues/2889 [#2927]: https://github.com/AdguardTeam/AdGuardHome/issues/2927 +[#2934]: https://github.com/AdguardTeam/AdGuardHome/issues/2934 [#2945]: https://github.com/AdguardTeam/AdGuardHome/issues/2945 [#2947]: https://github.com/AdguardTeam/AdGuardHome/issues/2947 diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 2d54af39218..26388cb013a 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -115,6 +115,7 @@ const ( LeaseChangedAdded = iota LeaseChangedAddedStatic LeaseChangedRemovedStatic + LeaseChangedRemovedAll LeaseChangedDBStore ) diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index d0a4a7a150c..8a0ec6303b0 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -801,6 +801,10 @@ func (s *v4Server) Start() error { log.Debug("dhcpv4: srv.Serve: %s", err) }() + // Signal to the clients containers in packages home and dnsforward that + // it should reload the DHCP clients. + s.conf.notify(LeaseChangedAdded) + return nil } @@ -815,7 +819,11 @@ func (s *v4Server) Stop() { if err != nil { log.Error("dhcpv4: srv.Close: %s", err) } - // now s.srv.Serve() will return + + // Signal to the clients containers in packages home and dnsforward that + // it should remove all DHCP clients. + s.conf.notify(LeaseChangedRemovedAll) + s.srv = nil } diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index d84a704a869..b655bec99be 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -23,7 +23,7 @@ func TestV4_AddRemove_static(t *testing.T) { SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, }) - require.Nil(t, err) + require.NoError(t, err) ls := s.GetLeases(LeasesStatic) assert.Empty(t, ls) @@ -33,23 +33,30 @@ func TestV4_AddRemove_static(t *testing.T) { IP: net.IP{192, 168, 10, 150}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - require.Nil(t, s.AddStaticLease(l)) - assert.NotNil(t, s.AddStaticLease(l)) + + err = s.AddStaticLease(l) + require.NoError(t, err) + + err = s.AddStaticLease(l) + assert.Error(t, err) ls = s.GetLeases(LeasesStatic) require.Len(t, ls, 1) + assert.True(t, l.IP.Equal(ls[0].IP)) assert.Equal(t, l.HWAddr, ls[0].HWAddr) assert.True(t, ls[0].IsStatic()) // Try to remove static lease. - assert.NotNil(t, s.RemoveStaticLease(Lease{ + err = s.RemoveStaticLease(Lease{ IP: net.IP{192, 168, 10, 110}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, - })) + }) + assert.Error(t, err) // Remove static lease. - require.Nil(t, s.RemoveStaticLease(l)) + err = s.RemoveStaticLease(l) + require.NoError(t, err) ls = s.GetLeases(LeasesStatic) assert.Empty(t, ls) } @@ -63,7 +70,7 @@ func TestV4_AddReplace(t *testing.T) { SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, }) - require.Nil(t, err) + require.NoError(t, err) s, ok := sIface.(*v4Server) require.True(t, ok) @@ -78,7 +85,7 @@ func TestV4_AddReplace(t *testing.T) { for i := range dynLeases { err = s.addLease(&dynLeases[i]) - require.Nil(t, err) + require.NoError(t, err) } stLeases := []Lease{{ @@ -90,7 +97,8 @@ func TestV4_AddReplace(t *testing.T) { }} for _, l := range stLeases { - require.Nil(t, s.AddStaticLease(l)) + err = s.AddStaticLease(l) + require.NoError(t, err) } ls := s.GetLeases(LeasesStatic) @@ -113,32 +121,35 @@ func TestV4StaticLease_Get(t *testing.T) { SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, }) - require.Nil(t, err) + require.NoError(t, err) s, ok := sIface.(*v4Server) require.True(t, ok) + s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} l := Lease{ IP: net.IP{192, 168, 10, 150}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - require.Nil(t, s.AddStaticLease(l)) + err = s.AddStaticLease(l) + require.NoError(t, err) var req, resp *dhcpv4.DHCPv4 mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} t.Run("discover", func(t *testing.T) { - var terr error + req, err = dhcpv4.NewDiscovery(mac) + require.NoError(t, err) - req, terr = dhcpv4.NewDiscovery(mac) - require.Nil(t, terr) + resp, err = dhcpv4.NewReplyFromRequest(req) + require.NoError(t, err) - resp, terr = dhcpv4.NewReplyFromRequest(req) - require.Nil(t, terr) assert.Equal(t, 1, s.process(req, resp)) }) - require.Nil(t, err) + + // Don't continue if we got any errors in the previous subtest. + require.NoError(t, err) t.Run("offer", func(t *testing.T) { assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) @@ -152,13 +163,15 @@ func TestV4StaticLease_Get(t *testing.T) { t.Run("request", func(t *testing.T) { req, err = dhcpv4.NewRequestFromOffer(resp) - require.Nil(t, err) + require.NoError(t, err) resp, err = dhcpv4.NewReplyFromRequest(req) - require.Nil(t, err) + require.NoError(t, err) + assert.Equal(t, 1, s.process(req, resp)) }) - require.Nil(t, err) + + require.NoError(t, err) t.Run("ack", func(t *testing.T) { assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) @@ -172,11 +185,13 @@ func TestV4StaticLease_Get(t *testing.T) { dnsAddrs := resp.DNS() require.Len(t, dnsAddrs, 1) + assert.True(t, s.conf.GatewayIP.Equal(dnsAddrs[0])) t.Run("check_lease", func(t *testing.T) { ls := s.GetLeases(LeasesStatic) require.Len(t, ls, 1) + assert.True(t, l.IP.Equal(ls[0].IP)) assert.Equal(t, mac, ls[0].HWAddr) }) @@ -196,10 +211,11 @@ func TestV4DynamicLease_Get(t *testing.T) { "82 ip 1.2.3.4", }, }) - require.Nil(t, err) + require.NoError(t, err) s, ok := sIface.(*v4Server) require.True(t, ok) + s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} var req, resp *dhcpv4.DHCPv4 @@ -207,15 +223,16 @@ func TestV4DynamicLease_Get(t *testing.T) { t.Run("discover", func(t *testing.T) { req, err = dhcpv4.NewDiscovery(mac) - require.Nil(t, err) + require.NoError(t, err) resp, err = dhcpv4.NewReplyFromRequest(req) - require.Nil(t, err) + require.NoError(t, err) + assert.Equal(t, 1, s.process(req, resp)) }) // Don't continue if we got any errors in the previous subtest. - require.Nil(t, err) + require.NoError(t, err) t.Run("offer", func(t *testing.T) { assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) @@ -226,6 +243,7 @@ func TestV4DynamicLease_Get(t *testing.T) { router := resp.Router() require.Len(t, router, 1) + assert.Equal(t, s.conf.GatewayIP, router[0]) assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) @@ -236,16 +254,16 @@ func TestV4DynamicLease_Get(t *testing.T) { }) t.Run("request", func(t *testing.T) { - var terr error + req, err = dhcpv4.NewRequestFromOffer(resp) + require.NoError(t, err) - req, terr = dhcpv4.NewRequestFromOffer(resp) - require.Nil(t, terr) + resp, err = dhcpv4.NewReplyFromRequest(req) + require.NoError(t, err) - resp, terr = dhcpv4.NewReplyFromRequest(req) - require.Nil(t, terr) assert.Equal(t, 1, s.process(req, resp)) }) - require.Nil(t, err) + + require.NoError(t, err) t.Run("ack", func(t *testing.T) { assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) @@ -259,12 +277,14 @@ func TestV4DynamicLease_Get(t *testing.T) { dnsAddrs := resp.DNS() require.Len(t, dnsAddrs, 1) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0])) // check lease t.Run("check_lease", func(t *testing.T) { ls := s.GetLeases(LeasesDynamic) - assert.Len(t, ls, 1) + require.Len(t, ls, 1) + assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP)) assert.Equal(t, mac, ls[0].HWAddr) }) diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 47aad0641b2..31a402215fc 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -154,44 +154,60 @@ func isHostnameOK(hostname string) bool { return true } +func (s *Server) setTableHostToIP(t hostToIPTable) { + s.tableHostToIPLock.Lock() + defer s.tableHostToIPLock.Unlock() + + s.tableHostToIP = t +} + +func (s *Server) setTableIPToHost(t ipToHostTable) { + s.tableIPToHostLock.Lock() + defer s.tableIPToHostLock.Unlock() + + s.tableIPToHost = t +} + func (s *Server) onDHCPLeaseChanged(flags int) { + add := true switch flags { case dhcpd.LeaseChangedAdded, dhcpd.LeaseChangedAddedStatic, dhcpd.LeaseChangedRemovedStatic: - // + // Go on. + case dhcpd.LeaseChangedRemovedAll: + add = false default: return } - hostToIP := make(map[string]net.IP) - m := make(map[string]string) - - ll := s.dhcpServer.Leases(dhcpd.LeasesAll) + var hostToIP hostToIPTable + var ipToHost ipToHostTable + if add { + hostToIP = make(hostToIPTable) + ipToHost = make(ipToHostTable) - for _, l := range ll { - if len(l.Hostname) == 0 || !isHostnameOK(l.Hostname) { - continue - } + ll := s.dhcpServer.Leases(dhcpd.LeasesAll) - lowhost := strings.ToLower(l.Hostname) + for _, l := range ll { + if len(l.Hostname) == 0 || !isHostnameOK(l.Hostname) { + continue + } - m[l.IP.String()] = lowhost + lowhost := strings.ToLower(l.Hostname) - ip := make(net.IP, 4) - copy(ip, l.IP.To4()) - hostToIP[lowhost] = ip - } + ipToHost[l.IP.String()] = lowhost - log.Debug("dns: added %d A/PTR entries from DHCP", len(m)) + ip := make(net.IP, 4) + copy(ip, l.IP.To4()) + hostToIP[lowhost] = ip + } - s.tableHostToIPLock.Lock() - s.tableHostToIP = hostToIP - s.tableHostToIPLock.Unlock() + log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost)) + } - s.tablePTRLock.Lock() - s.tablePTR = m - s.tablePTRLock.Unlock() + s.setTableHostToIP(hostToIP) + s.setTableIPToHost(ipToHost) } // processDetermineLocal determines if the client's IP address is from @@ -336,14 +352,14 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { // ipToHost tries to get a hostname leased by DHCP. It's safe for concurrent // use. func (s *Server) ipToHost(ip net.IP) (host string, ok bool) { - s.tablePTRLock.Lock() - defer s.tablePTRLock.Unlock() + s.tableIPToHostLock.Lock() + defer s.tableIPToHostLock.Unlock() - if s.tablePTR == nil { + if s.tableIPToHost == nil { return "", false } - host, ok = s.tablePTR[ip.String()] + host, ok = s.tableIPToHost[ip.String()] return host, ok } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 1b060cbb4fc..0addc008623 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -91,7 +91,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := &Server{ autohostSuffix: defaultAutohostSuffix, - tableHostToIP: map[string]net.IP{ + tableHostToIP: hostToIPTable{ "example": knownIP, }, } @@ -202,7 +202,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := &Server{ autohostSuffix: tc.suffix, - tableHostToIP: map[string]net.IP{ + tableHostToIP: hostToIPTable{ "example": knownIP, }, } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 8951f45d11c..57ce30004fb 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -43,6 +43,15 @@ var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"} var webRegistered bool +// hostToIPTable is an alias for the type of Server.tableHostToIP. +type hostToIPTable = map[string]net.IP + +// ipToHostTable is an alias for the type of Server.tableIPToHost. +// +// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other +// places? +type ipToHostTable = map[string]string + // Server is the main way to start a DNS server. // // Example: @@ -69,11 +78,11 @@ type Server struct { subnetDetector *aghnet.SubnetDetector localResolvers *proxy.Proxy - tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP) + tableHostToIP hostToIPTable tableHostToIPLock sync.Mutex - tablePTR map[string]string // "IP -> hostname" table for reverse lookup - tablePTRLock sync.Mutex + tableIPToHost ipToHostTable + tableIPToHostLock sync.Mutex // DNS proxy instance for internal usage // We don't Start() it and so no listen port is required. diff --git a/internal/home/clients.go b/internal/home/clients.go index 882f5586c2e..dccf95c5dd8 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -121,7 +121,7 @@ func (clients *clientsContainer) Init( clients.addFromConfig(objects) if !clients.testing { - clients.addFromDHCP() + clients.updateFromDHCP(true) if clients.dhcpServer != nil { clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) } @@ -244,7 +244,9 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) { case dhcpd.LeaseChangedAdded, dhcpd.LeaseChangedAddedStatic, dhcpd.LeaseChangedRemovedStatic: - clients.addFromDHCP() + clients.updateFromDHCP(true) + case dhcpd.LeaseChangedRemovedAll: + clients.updateFromDHCP(false) } } @@ -768,9 +770,9 @@ func (clients *clientsContainer) addFromSystemARP() { log.Debug("clients: added %d client aliases from 'arp -a' command output", n) } -// addFromDHCP adds the clients that have a non-empty hostname from the DHCP +// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP // server. -func (clients *clientsContainer) addFromDHCP() { +func (clients *clientsContainer) updateFromDHCP(add bool) { if clients.dhcpServer == nil { return } @@ -780,6 +782,10 @@ func (clients *clientsContainer) addFromDHCP() { clients.rmHostsBySrc(ClientSourceDHCP) + if !add { + return + } + leases := clients.dhcpServer.Leases(dhcpd.LeasesAll) n := 0 for _, l := range leases {