From b4abccfebe3fa166917d352f67980f9372072b67 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 11 Oct 2023 15:36:12 +0300 Subject: [PATCH] Pull request 288: close bootstraps Merge in GO/dnsproxy from close-bootstraps to master Squashed commit of the following: commit 35b1163fb541d3247ae22e31a3d5e668be0acac1 Author: Eugene Burkov Date: Wed Oct 11 15:17:24 2023 +0300 upstream: imp doc commit 03fa1e5f1cacbe04bd0141e09ed3b52422c41af1 Author: Eugene Burkov Date: Wed Oct 11 14:01:31 2023 +0300 upstream: close bootstraps --- upstream/doh.go | 12 ++++--- upstream/dot.go | 12 ++++--- upstream/plain.go | 8 +++-- upstream/quic.go | 6 +++- upstream/upstream.go | 82 ++++++++++++++++++++++++++++++-------------- 5 files changed, 82 insertions(+), 38 deletions(-) diff --git a/upstream/doh.go b/upstream/doh.go index 960302f48..1138263c9 100644 --- a/upstream/doh.go +++ b/upstream/doh.go @@ -49,6 +49,9 @@ type dnsOverHTTPS struct { // one. getDialer DialerInitializer + // closeBoot is the function to close the bootstrap upstreams. + closeBoot closeFunc + // addr is the DNS-over-HTTPS server URL. addr *url.URL @@ -89,7 +92,7 @@ func newDoH(addr *url.URL, opts *Options) (u Upstream, err error) { httpVersions = DefaultHTTPVersions } - getDialer, err := newDialerInitializer(addr, opts) + getDialer, closeBoot, err := newDialerInitializer(addr, opts) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err @@ -97,6 +100,7 @@ func newDoH(addr *url.URL, opts *Options) (u Upstream, err error) { ups := &dnsOverHTTPS{ getDialer: getDialer, + closeBoot: closeBoot, addr: addr, quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, @@ -191,11 +195,11 @@ func (p *dnsOverHTTPS) Close() (err error) { runtime.SetFinalizer(p, nil) - if p.client == nil { - return nil + if p.client != nil { + err = p.closeClient(p.client) } - return p.closeClient(p.client) + return errors.Join(err, errors.Annotate(p.closeBoot(), "closing bootstrap: %w")) } // closeClient cleans up resources used by client if necessary. Note, that at diff --git a/upstream/dot.go b/upstream/dot.go index e34f5f960..c8508e175 100644 --- a/upstream/dot.go +++ b/upstream/dot.go @@ -31,6 +31,9 @@ type dnsOverTLS struct { // new one. getDialer DialerInitializer + // closeBoot is the function to close the bootstrap upstreams. + closeBoot closeFunc + // tlsConf is the configuration of TLS. tlsConf *tls.Config @@ -54,7 +57,7 @@ var _ Upstream = (*dnsOverTLS)(nil) func newDoT(addr *url.URL, opts *Options) (ups Upstream, err error) { addPort(addr, defaultPortDoT) - getDialer, err := newDialerInitializer(addr, opts) + getDialer, closeBoot, err := newDialerInitializer(addr, opts) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err @@ -63,6 +66,7 @@ func newDoT(addr *url.URL, opts *Options) (ups Upstream, err error) { tlsUps := &dnsOverTLS{ addr: addr, getDialer: getDialer, + closeBoot: closeBoot, // #nosec G402 -- TLS certificate verification could be disabled by // configuration. tlsConf: &tls.Config{ @@ -147,11 +151,9 @@ func (p *dnsOverTLS) Close() (err error) { } } - if len(closeErrs) > 0 { - return errors.List("closing tls conns", closeErrs...) - } + closeErrs = append(closeErrs, errors.Annotate(p.closeBoot(), "closing bootstrap: %w")) - return nil + return errors.Join(closeErrs...) } // conn returns the first available connection from the pool if there is any, or diff --git a/upstream/plain.go b/upstream/plain.go index 99d4b150f..806cb00ba 100644 --- a/upstream/plain.go +++ b/upstream/plain.go @@ -37,6 +37,9 @@ type plainDNS struct { // one. getDialer DialerInitializer + // closeBoot is the function to close the bootstrap upstreams. + closeBoot closeFunc + // net is the network of the connections. net network @@ -59,7 +62,7 @@ func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) { addPort(addr, defaultPortPlain) - getDialer, err := newDialerInitializer(addr, opts) + getDialer, closeBoot, err := newDialerInitializer(addr, opts) if err != nil { return nil, err } @@ -67,6 +70,7 @@ func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) { return &plainDNS{ addr: addr, getDialer: getDialer, + closeBoot: closeBoot, net: addr.Scheme, timeout: opts.Timeout, }, nil @@ -175,7 +179,7 @@ func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { // Close implements the [Upstream] interface for *plainDNS. func (p *plainDNS) Close() (err error) { - return nil + return errors.Annotate(p.closeBoot(), "closing bootstrap: %w") } // errQuestion is returned when a message has malformed question section. diff --git a/upstream/quic.go b/upstream/quic.go index 387cb39e2..df62c1a43 100644 --- a/upstream/quic.go +++ b/upstream/quic.go @@ -57,6 +57,9 @@ type dnsOverQUIC struct { // one. getDialer DialerInitializer + // closeBoot is the function to close the bootstrap upstreams. + closeBoot closeFunc + // addr is the DNS-over-QUIC server URL. addr *url.URL @@ -97,13 +100,14 @@ var _ Upstream = (*dnsOverQUIC)(nil) func newDoQ(addr *url.URL, opts *Options) (u Upstream, err error) { addPort(addr, defaultPortDoQ) - getDialer, err := newDialerInitializer(addr, opts) + getDialer, closeBoot, err := newDialerInitializer(addr, opts) if err != nil { return nil, err } u = &dnsOverQUIC{ getDialer: getDialer, + closeBoot: closeBoot, addr: addr, quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, diff --git a/upstream/upstream.go b/upstream/upstream.go index e1fa199c8..1da04b815 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -17,6 +17,7 @@ import ( "time" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/ameshkov/dnscrypt/v2" @@ -24,6 +25,7 @@ import ( "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/logging" + "golang.org/x/exp/slices" ) // Upstream is an interface for a DNS resolver. @@ -76,10 +78,11 @@ type Options struct { // CipherSuites is a custom list of TLSv1.2 ciphers. CipherSuites []uint16 - // Bootstrap is a list of DNS servers to be used to resolve - // DNS-over-HTTPS/DNS-over-TLS hostnames. Plain DNS, DNSCrypt, or - // DNS-over-HTTPS/DNS-over-TLS with IP addresses (not hostnames) could be - // used. + // Bootstrap is a list of DNS servers to be used to resolve DoH/DoT/DoQ + // hostnames. Plain DNS, DNSCrypt, or DoH/DoT/DoQ with IP addresses (not + // hostnames) could be used. Those servers will be turned to upstream + // servers and will be closed as soon as the resolved upstream itself is + // closed. Bootstrap []string // List of IP addresses of the upstream DNS server. If not empty, bootstrap @@ -306,12 +309,26 @@ func logFinish(upstreamAddress string, n network, err error) { // resolving will be performed only once. type DialerInitializer func() (handler bootstrap.DialHandler, err error) +// closeFunc is the signature of a function that closes an upstream. +type closeFunc func() (err error) + +// nopClose is the [closeFunc] that does nothing. +func nopClose() (err error) { return nil } + // newDialerInitializer creates an initializer of the dialer that will dial the // addresses resolved from u using opts. -func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer, err error) { +// +// TODO(e.burkov): Returning closeFunc is a temporary solution. It's needed +// to close the bootstrap upstreams, which may require closing. It should be +// gone when the [Options.Bootstrap] will be turned into [Resolver] and it's +// closing will be handled by the caller. +func newDialerInitializer( + u *url.URL, + opts *Options, +) (di DialerInitializer, closeBoot closeFunc, err error) { host, port, err := netutil.SplitHostPort(u.Host) if err != nil { - return nil, fmt.Errorf("invalid address: %s: %w", u.Host, err) + return nil, nopClose, fmt.Errorf("invalid address: %s: %w", u.Host, err) } if addrsLen := len(opts.ServerIPAddrs); addrsLen > 0 { @@ -324,58 +341,58 @@ func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer, err handler := bootstrap.NewDialContext(opts.Timeout, addrs...) - return func() (bootstrap.DialHandler, error) { return handler, nil }, nil + return func() (h bootstrap.DialHandler, err error) { return handler, nil }, nopClose, nil } else if _, err = netip.ParseAddr(host); err == nil { // Don't resolve the address of the server since it's already an IP. handler := bootstrap.NewDialContext(opts.Timeout, u.Host) - return func() (bootstrap.DialHandler, error) { return handler, nil }, nil + return func() (h bootstrap.DialHandler, err error) { return handler, nil }, nopClose, nil } - resolvers, err := newResolvers(opts) + resolvers, closeBoot, err := newResolvers(opts) if err != nil { - // Don't wrap the error since it's informative enough as is. - return nil, err + return nil, nopClose, errors.Join(err, closeBoot()) } - var dialHandler atomic.Value + var dialHandler atomic.Pointer[bootstrap.DialHandler] di = func() (h bootstrap.DialHandler, resErr error) { // Check if the dial handler has already been created. - h, ok := dialHandler.Load().(bootstrap.DialHandler) - if ok { - return h, nil + if hPtr := dialHandler.Load(); hPtr != nil { + return *hPtr, nil } // TODO(e.burkov): It may appear that several exchanges will try to // resolve the upstream hostname at the same time. Currently, the last // successful value will be stored in dialHandler, but ideally we should - // resolve only once. + // resolve only once at a time. h, resolveErr := bootstrap.ResolveDialContext(u, opts.Timeout, resolvers, opts.PreferIPv6) if resolveErr != nil { return nil, fmt.Errorf("creating dial handler: %w", resolveErr) } - if !dialHandler.CompareAndSwap(nil, h) { - return dialHandler.Load().(bootstrap.DialHandler), nil + if !dialHandler.CompareAndSwap(nil, &h) { + // The dial handler has just been created by another exchange. + return *dialHandler.Load(), nil } return h, nil } - return di, nil + return di, closeBoot, nil } // newResolvers prepares resolvers for bootstrapping. If opts.Bootstrap is // empty, the only new [net.Resolver] will be returned. Otherwise, the it will // be added for each occurrence of an empty string in [Options.Bootstrap]. -func newResolvers(opts *Options) (resolvers []Resolver, err error) { +func newResolvers(opts *Options) (resolvers []Resolver, closeBoot closeFunc, err error) { bootstraps := opts.Bootstrap - if len(bootstraps) == 0 { - return []Resolver{&net.Resolver{}}, nil + l := len(bootstraps) + if l == 0 { + return []Resolver{&net.Resolver{}}, nopClose, nil } - resolvers = make([]Resolver, 0, len(bootstraps)) - for _, boot := range bootstraps { + resolvers, closeBoots := make([]Resolver, 0, l), make([]closeFunc, 0, l) + for i, boot := range bootstraps { if boot == "" { resolvers = append(resolvers, &net.Resolver{}) @@ -384,11 +401,24 @@ func newResolvers(opts *Options) (resolvers []Resolver, err error) { r, rErr := NewUpstreamResolver(boot, opts) if rErr != nil { - return nil, fmt.Errorf("preparing bootstrap resolver: %w", rErr) + resolvers = nil + err = fmt.Errorf("preparing bootstrap resolver at index %d: %w", i, rErr) + + break } resolvers = append(resolvers, r) + closeBoots = append(closeBoots, r.(upstreamResolver).Close) } - return resolvers, nil + closeBoots = slices.Clip(closeBoots) + + return resolvers, func() (closeErr error) { + var errs []error + for _, cb := range closeBoots { + errs = append(errs, cb()) + } + + return errors.Join(errs...) + }, err }