From ac27db95c12858b6ef182a0bd4acebab67a23993 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 18 Jul 2023 15:47:17 +0300 Subject: [PATCH] all: imp code --- internal/dnsforward/dialcontext.go | 69 ++++++++++++++---------------- internal/dnsforward/dnsforward.go | 2 +- internal/home/clients.go | 11 ++++- internal/home/dns.go | 5 ++- internal/home/httpclient.go | 6 +-- 5 files changed, 50 insertions(+), 43 deletions(-) diff --git a/internal/dnsforward/dialcontext.go b/internal/dnsforward/dialcontext.go index ae33e7e8551..db32dd3db64 100644 --- a/internal/dnsforward/dialcontext.go +++ b/internal/dnsforward/dialcontext.go @@ -6,55 +6,52 @@ import ( "net" "time" - "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) -// DialContext returns a DialContextFunc that uses s to resolve hostnames. -func (s *Server) DialContext() (f whois.DialContextFunc) { - return func(ctx context.Context, network, addr string) (conn net.Conn, err error) { - log.Debug("dnsforward: dialing %q for network %q", addr, network) +// DialContext is a [whois.DialContextFunc] that uses s to resolve hostnames. +func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { + log.Debug("dnsforward: dialing %q for network %q", addr, network) - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - dialer := &net.Dialer{ - // TODO(a.garipov): Consider making configurable. - Timeout: time.Minute * 5, - } + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } - if net.ParseIP(host) != nil { - return dialer.DialContext(ctx, network, addr) - } + dialer := &net.Dialer{ + // TODO(a.garipov): Consider making configurable. + Timeout: time.Minute * 5, + } - addrs, err := s.Resolve(host) - if err != nil { - return nil, fmt.Errorf("resolving %q: %w", host, err) - } + if net.ParseIP(host) != nil { + return dialer.DialContext(ctx, network, addr) + } - log.Debug("dnsforward: resolving %q: %v", host, addrs) + addrs, err := s.Resolve(host) + if err != nil { + return nil, fmt.Errorf("resolving %q: %w", host, err) + } - if len(addrs) == 0 { - return nil, fmt.Errorf("no addresses for host %q", host) - } + log.Debug("dnsforward: resolving %q: %v", host, addrs) - var dialErrs []error - for _, a := range addrs { - addr = net.JoinHostPort(a.String(), port) - conn, err = dialer.DialContext(ctx, network, addr) - if err != nil { - dialErrs = append(dialErrs, err) + if len(addrs) == 0 { + return nil, fmt.Errorf("no addresses for host %q", host) + } - continue - } + var dialErrs []error + for _, a := range addrs { + addr = net.JoinHostPort(a.String(), port) + conn, err = dialer.DialContext(ctx, network, addr) + if err != nil { + dialErrs = append(dialErrs, err) - return conn, err + continue } - // TODO(a.garipov): Use errors.Join in Go 1.20. - return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...) + return conn, err } + + // TODO(a.garipov): Use errors.Join in Go 1.20. + return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...) } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index d5f064eaa69..bb389969e95 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -556,7 +556,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { s.addrProc = client.EmptyAddrProc{} } else { c := s.conf.AddrProcConf - c.DialContext = s.DialContext() + c.DialContext = s.DialContext c.PrivateSubnets = s.privateNets c.UsePrivateRDNS = s.conf.UsePrivateRDNS s.addrProc = client.NewDefaultAddrProc(s.conf.AddrProcConf) diff --git a/internal/home/clients.go b/internal/home/clients.go index 9fee7a34682..049710bc8d0 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -10,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -787,9 +788,17 @@ func (clients *clientsContainer) addHost( return clients.addHostLocked(ip, host, src) } -// UpdateAddress implements the [client.ClientStorage] interface for +// type check +var _ client.AddressUpdater = (*clientsContainer)(nil) + +// UpdateAddress implements the [client.AddressUpdater] interface for // *clientsContainer func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { + // Common fast path optimization. + if host == "" && info == nil { + return + } + clients.lock.Lock() defer clients.lock.Unlock() diff --git a/internal/home/dns.go b/internal/home/dns.go index 74dedbee0c8..7dfa8b85fa5 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -240,13 +240,14 @@ func newServerConfig( DNS64Prefixes: config.DNS.DNS64Prefixes, } + const initialClientsNum = 100 + // Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they // are set by [dnsforward.Server.Prepare]. - const topClientsNumber = 100 newConf.AddrProcConf = &client.DefaultAddrProcConfig{ Exchanger: Context.dnsServer, AddressUpdater: &Context.clients, - InitialAddresses: Context.stats.TopClientsIP(topClientsNumber), + InitialAddresses: Context.stats.TopClientsIP(initialClientsNum), UseRDNS: config.Clients.Sources.RDNS, UseWHOIS: config.Clients.Sources.WHOIS, } diff --git a/internal/home/httpclient.go b/internal/home/httpclient.go index 30230e72a8d..ae41d6acd66 100644 --- a/internal/home/httpclient.go +++ b/internal/home/httpclient.go @@ -15,10 +15,10 @@ import ( // // TODO(a.garipov, e.burkov): This is rather messy. Refactor. func httpClient() (c *http.Client) { + // Do not use Context.dnsServer.DialContext directly in the struct literal + // below, since Context.dnsServer may be nil when this function is called. dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) { - f := Context.dnsServer.DialContext() - - return f(ctx, network, addr) + return Context.dnsServer.DialContext(ctx, network, addr) } return &http.Client{