diff --git a/go.mod b/go.mod index 86faea899..ea82afce0 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/dnsproxy go 1.18 require ( - github.com/AdguardTeam/golibs v0.10.9 + github.com/AdguardTeam/golibs v0.11.2 github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/ameshkov/dnsstamps v1.0.3 github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 @@ -13,7 +13,8 @@ require ( github.com/miekg/dns v1.1.50 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/stretchr/testify v1.8.0 - golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b + golang.org/x/net v0.1.0 + golang.org/x/sys v0.1.1-0.20221102194838-fc697a31fa06 gopkg.in/yaml.v3 v3.0.1 ) @@ -32,10 +33,9 @@ require ( github.com/onsi/ginkgo/v2 v2.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect - golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect + golang.org/x/exp v0.0.0-20221019170559-20944726eadf // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/sys v0.1.1-0.20221102194838-fc697a31fa06 // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/text v0.4.0 // indirect golang.org/x/tools v0.1.12 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect ) diff --git a/go.sum b/go.sum index 891cbead6..2fcf19106 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/golibs v0.10.9 h1:F9oP2da0dQ9RQDM1lGR7LxUTfUWu8hEFOs4icwAkKM0= -github.com/AdguardTeam/golibs v0.10.9/go.mod h1:W+5rznZa1cSNSFt+gPS7f4Wytnr9fOrd5ZYqwadPw14= +github.com/AdguardTeam/golibs v0.11.2 h1:JbQB1Dg2JWStXgHh1QqBbOLWnP4t9oDjppoBH6TVXSE= +github.com/AdguardTeam/golibs v0.11.2/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= @@ -64,8 +64,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM= golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw= -golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20221019170559-20944726eadf h1:nFVjjKDgNY37+ZSYCJmtYf7tOlfQswHqplG2eosjOMg= +golang.org/x/exp v0.0.0-20221019170559-20944726eadf/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -73,8 +73,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY= -golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= @@ -93,8 +93,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go index 5077fd2f4..b6252f167 100644 --- a/upstream/upstream_dot.go +++ b/upstream/upstream_dot.go @@ -24,41 +24,43 @@ const dialTimeout = 10 * time.Second // dnsOverTLS is a struct that implements the Upstream interface for the // DNS-over-TLS protocol. type dnsOverTLS struct { - boot *bootstrapper - conns *sync.Pool - connsInUse *sync.WaitGroup + // boot resolves the hostname upstream addresses. + boot *bootstrapper + // conns stores the connections ready for reuse. + conns *sync.Pool + // connsWG tracks all the connections usages. + connsWG *sync.WaitGroup } // type check var _ Upstream = (*dnsOverTLS)(nil) // newDoT returns the DNS-over-TLS Upstream. -func newDoT(uu *url.URL, opts *Options) (u Upstream, err error) { - addPort(uu, defaultPortDoT) +func newDoT(u *url.URL, opts *Options) (ups Upstream, err error) { + addPort(u, defaultPortDoT) - var b *bootstrapper - b, err = urlToBoot(uu, opts) + boot, err := urlToBoot(u, opts) if err != nil { return nil, fmt.Errorf("creating tls bootstrapper: %w", err) } - u = &dnsOverTLS{ - boot: b, - conns: &sync.Pool{}, - connsInUse: &sync.WaitGroup{}, + ups = &dnsOverTLS{ + boot: boot, + conns: &sync.Pool{}, + connsWG: &sync.WaitGroup{}, } - runtime.SetFinalizer(u, (*dnsOverTLS).Close) + runtime.SetFinalizer(ups, (*dnsOverTLS).Close) - return u, nil + return ups, nil } // Address implements the [Upstream] interface for *dnsOverTLS. func (p *dnsOverTLS) Address() string { return p.boot.URL.String() } -// Get gets a connection from the pool (if there's one available) or creates -// a new TLS connection. -func (p *dnsOverTLS) getConn() (conn net.Conn, err error) { +// conn returns a connection from the pool if there's one available or creates a +// new TLS connection otherwise. +func (p *dnsOverTLS) conn() (conn net.Conn, err error) { c := p.conns.Get() conn, ok := c.(net.Conn) if conn == nil { @@ -85,13 +87,13 @@ func (p *dnsOverTLS) getConn() (conn net.Conn, err error) { // Exchange implements the [Upstream] interface for *dnsOverTLS. func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { - conn, err := p.getConn() + conn, err := p.conn() if err != nil { return nil, fmt.Errorf("getting conn to %s: %w", p.Address(), err) } - p.connsInUse.Add(1) - defer p.connsInUse.Done() + p.connsWG.Add(1) + defer p.connsWG.Done() reply, err = p.exchangeWithConn(conn, m) if err != nil { @@ -122,7 +124,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { // Close implements the [Upstream] interface for *dnsOverTLS. func (p *dnsOverTLS) Close() (err error) { runtime.SetFinalizer(p, nil) - p.connsInUse.Wait() + p.connsWG.Wait() var closeErrs []error for c := p.conns.Get(); c != nil; c = p.conns.Get() { @@ -134,8 +136,7 @@ func (p *dnsOverTLS) Close() (err error) { } closeErr := conn.Close() - if closeErr != nil && p.isVitalErr(closeErr) { - // TODO(e.burkov): !! inspect. + if closeErr != nil && isCriticalTCP(closeErr) { closeErrs = append(closeErrs, closeErr) } } @@ -201,7 +202,7 @@ func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls. err = conn.SetDeadline(time.Now().Add(dialTimeout)) if err != nil { // Must not happen in normal circumstances. - panic(fmt.Errorf("cannot set deadline: %w", err)) + panic(fmt.Errorf("dnsproxy: tls dial: setting deadline: %w", err)) } err = conn.Handshake() @@ -212,8 +213,11 @@ func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls. return conn, nil } -func (p *dnsOverTLS) isVitalErr(err error) (ok bool) { - if netErr := new(net.Error); errors.As(err, netErr) && (*netErr).Timeout() { +// isCriticalTCP returns true if err isn't an expected error in terms of closing +// the TCP connection. +func isCriticalTCP(err error) (ok bool) { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { return false } @@ -222,7 +226,7 @@ func (p *dnsOverTLS) isVitalErr(err error) (ok bool) { errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed), errors.Is(err, os.ErrDeadlineExceeded), - p.isConnBroke(err): + isConnBroken(err): return false default: return true diff --git a/upstream/upstream_dot_test.go b/upstream/upstream_dot_test.go index 70ec89f60..ca93bb0ff 100644 --- a/upstream/upstream_dot_test.go +++ b/upstream/upstream_dot_test.go @@ -107,8 +107,7 @@ func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) { require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, u.Close) - require.IsType(t, &dnsOverTLS{}, u) - p, _ := u.(*dnsOverTLS) + p := testutil.RequireTypeAssert[*dnsOverTLS](t, u) var usedConn net.Conn @@ -120,8 +119,7 @@ func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) { // Now let's close the pooled connection. conn := p.conns.Get() - require.IsType(t, &tls.Conn{}, conn) - usedConn, _ = conn.(net.Conn) + usedConn = testutil.RequireTypeAssert[net.Conn](t, conn) require.NoError(t, usedConn.Close()) @@ -136,7 +134,8 @@ func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) { // Now assert that the number of connections in the pool is not changed. conn = p.conns.Get() - require.IsType(t, &tls.Conn{}, conn) + _ = testutil.RequireTypeAssert[net.Conn](t, conn) + require.Nil(t, p.conns.Get()) assert.NotSame(t, usedConn, conn) @@ -171,12 +170,11 @@ func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) { require.NoError(t, err) requireResponse(t, req, response) - p := u.(*dnsOverTLS) + p := testutil.RequireTypeAssert[*dnsOverTLS](t, u) // Now let's get connection from the pool and use it again. conn := p.conns.Get() - require.IsType(t, &tls.Conn{}, conn) - usedConn, _ := conn.(net.Conn) + usedConn := testutil.RequireTypeAssert[net.Conn](t, conn) response, err = p.exchangeWithConn(usedConn, req) require.NoError(t, err) diff --git a/upstream/upstream_dot_unix.go b/upstream/upstream_dot_unix.go index 10992ad3b..8bf0db3ae 100644 --- a/upstream/upstream_dot_unix.go +++ b/upstream/upstream_dot_unix.go @@ -1,3 +1,5 @@ +//go:build darwin || freebsd || linux || openbsd + package upstream import ( @@ -5,6 +7,7 @@ import ( "golang.org/x/sys/unix" ) -func (p *dnsOverTLS) isConnBroke(err error) (ok bool) { +// isConnBroken returns true if err means that a connection is broken. +func isConnBroken(err error) (ok bool) { return errors.Is(err, unix.EPIPE) || errors.Is(err, unix.ETIMEDOUT) } diff --git a/upstream/upstream_dot_windows.go b/upstream/upstream_dot_windows.go new file mode 100644 index 000000000..a2349adcf --- /dev/null +++ b/upstream/upstream_dot_windows.go @@ -0,0 +1,13 @@ +//go:build windows + +package upstream + +import ( + "github.com/AdguardTeam/golibs/errors" + "golang.org/x/sys/windows" +) + +// isConnBroken always returns false. +func isConnBroken(err error) (ok bool) { + return errors.Is(err, windows.WSAECONNABORTED) +} diff --git a/vendor/github.com/AdguardTeam/golibs/errors/errors.go b/vendor/github.com/AdguardTeam/golibs/errors/errors.go index f94c22afc..8487ae800 100644 --- a/vendor/github.com/AdguardTeam/golibs/errors/errors.go +++ b/vendor/github.com/AdguardTeam/golibs/errors/errors.go @@ -29,14 +29,14 @@ type Wrapper interface { // // It calls errors.As from the Go standard library. See go doc errors.As for // the full documentation. -func As(err error, target interface{}) (ok bool) { +func As(err error, target any) (ok bool) { return stderrors.As(err, target) } // Aser is a copy of the hidden aser interface from the Go standard library. It // is added here for tests, linting, etc. type Aser interface { - As(target interface{}) (ok bool) + As(target any) (ok bool) } // Is reports whether any error in err's chain matches target. @@ -82,9 +82,9 @@ func Unwrap(err error) (unwrapped error) { // dynamically. Users of this API must check it's return value as well as the // result errors.As. // -// if derr := errors.Deferred(nil); errors.As(err, &derr) && derr.Deferred() { -// // … -// } +// if derr := errors.Deferred(nil); errors.As(err, &derr) && derr.Deferred() { +// // … +// } // // See https://dave.cheney.net/2014/12/24/inspecting-errors. type Deferred interface { @@ -137,22 +137,22 @@ func (err *Pair) Unwrap() (unwrapped error) { // WithDeferred is a helper function for deferred errors. For example, to // preserve errors from the Close method, replace this: // -// defer f.Close() +// defer f.Close() // // With this: // -// defer func() { err = errors.WithDeferred(err, f.Close()) } +// defer func() { err = errors.WithDeferred(err, f.Close()) } // // If returned is nil and deferred is non-nil, the returned error implements the // Deferred interface. If both returned and deferred are non-nil, result has // the underlying type of *Pair. // -// Warning +// # Warning // // This function requires that there be only ONE error named "err" in the // function and that it is always the one that is returned. Example (Bad) // provides an example of the incorrect usage of WithDeferred. -func WithDeferred(returned error, deferred error) (result error) { +func WithDeferred(returned, deferred error) (result error) { if deferred == nil { return returned } @@ -219,56 +219,56 @@ func (err *listError) Unwrap() (unwrapped error) { // Annotate annotates the error with the message, unless the error is nil. The // last verb in format must be a verb compatible with errors, for example "%w". // -// In Defers +// # In Defers // // The primary use case for this function is to simplify code like this: // -// func (f *foo) doStuff(s string) (err error) { -// defer func() { -// if err != nil { -// err = fmt.Errorf("bad foo %q: %w", s, err) -// } -// }() +// func (f *foo) doStuff(s string) (err error) { +// defer func() { +// if err != nil { +// err = fmt.Errorf("bad foo %q: %w", s, err) +// } +// }() // -// // … -// } +// // … +// } // // Instead, write: // -// func (f *foo) doStuff(s string) (err error) { -// defer func() { err = errors.Annotate(err, "bad foo %q: %w", s) }() +// func (f *foo) doStuff(s string) (err error) { +// defer func() { err = errors.Annotate(err, "bad foo %q: %w", s) }() // -// // … -// } +// // … +// } // -// At The End Of Functions +// # At The End Of Functions // // Another possible use case is to simplify final checks like this: // -// func (f *foo) doStuff(s string) (err error) { -// // … +// func (f *foo) doStuff(s string) (err error) { +// // … // -// if err != nil { -// return fmt.Errorf("doing stuff with %s: %w", s, err) -// } +// if err != nil { +// return fmt.Errorf("doing stuff with %s: %w", s, err) +// } // -// return nil -// } +// return nil +// } // // Instead, you could write: // -// func (f *foo) doStuff(s string) (err error) { -// // … +// func (f *foo) doStuff(s string) (err error) { +// // … // -// return errors.Annotate(err, "doing stuff with %s: %w", s) -// } +// return errors.Annotate(err, "doing stuff with %s: %w", s) +// } // -// Warning +// # Warning // // This function requires that there be only ONE error named "err" in the // function and that it is always the one that is returned. Example (Bad) // provides an example of the incorrect usage of WithDeferred. -func Annotate(err error, format string, args ...interface{}) (annotated error) { +func Annotate(err error, format string, args ...any) (annotated error) { if err == nil { return nil } diff --git a/vendor/github.com/AdguardTeam/golibs/log/log.go b/vendor/github.com/AdguardTeam/golibs/log/log.go index 2071b484b..3549c11bc 100644 --- a/vendor/github.com/AdguardTeam/golibs/log/log.go +++ b/vendor/github.com/AdguardTeam/golibs/log/log.go @@ -53,7 +53,7 @@ func StartTimer() Timer { } // LogElapsed writes to log message and elapsed time -func (t *Timer) LogElapsed(message string, args ...interface{}) { +func (t *Timer) LogElapsed(message string, args ...any) { var buf strings.Builder buf.WriteString(message) buf.WriteString(fmt.Sprintf("; Elapsed time: %dms", int(time.Since(t.start)/time.Millisecond))) @@ -86,10 +86,7 @@ func GetLevel() (l Level) { // These constants are the same as in the standard package "log". // -// See the output of: -// -// go doc log.Ldate -// +// See the documentation for [log.Ldate], etc. const ( Ldate = 1 << iota Ltime @@ -113,24 +110,24 @@ func SetFlags(flags int) { } // Fatal writes to error log and exits application -func Fatal(args ...interface{}) { +func Fatal(args ...any) { writeLog("fatal", "", "%s", fmt.Sprint(args...)) os.Exit(1) } // Fatalf writes to error log and exits application -func Fatalf(format string, args ...interface{}) { +func Fatalf(format string, args ...any) { writeLog("fatal", "", format, args...) os.Exit(1) } // Error writes to error log -func Error(format string, args ...interface{}) { +func Error(format string, args ...any) { writeLog("error", "", format, args...) } // Panic is equivalent to Print() followed by a call to panic(). -func Panic(args ...interface{}) { +func Panic(args ...any) { s := fmt.Sprint(args...) writeLog("panic", "", "%s", s) @@ -138,7 +135,7 @@ func Panic(args ...interface{}) { } // Panicf is equivalent to Printf() followed by a call to panic(). -func Panicf(format string, args ...interface{}) { +func Panicf(format string, args ...any) { s := fmt.Sprintf(format, args...) writeLog("panic", "", "%s", s) @@ -146,36 +143,36 @@ func Panicf(format string, args ...interface{}) { } // Print writes to info log -func Print(args ...interface{}) { +func Print(args ...any) { Info("%s", fmt.Sprint(args...)) } // Printf writes to info log -func Printf(format string, args ...interface{}) { +func Printf(format string, args ...any) { Info(format, args...) } // Println writes to info log -func Println(args ...interface{}) { +func Println(args ...any) { Info("%s", fmt.Sprint(args...)) } // Info writes to info log -func Info(format string, args ...interface{}) { +func Info(format string, args ...any) { if atomic.LoadUint32(&level) >= uint32(INFO) { writeLog("info", "", format, args...) } } // Debug writes to debug log -func Debug(format string, args ...interface{}) { +func Debug(format string, args ...any) { if atomic.LoadUint32(&level) >= uint32(DEBUG) { writeLog("debug", "", format, args...) } } // Tracef writes to debug log and adds the calling function's name -func Tracef(format string, args ...interface{}) { +func Tracef(format string, args ...any) { if atomic.LoadUint32(&level) >= uint32(DEBUG) { writeLog("debug", getCallerName(), format, args...) } @@ -194,7 +191,7 @@ func goroutineID() uint64 { // Construct a log message and write it // TIME PID#GOID [LEVEL] FUNCNAME(): TEXT -func writeLog(levelStr string, funcName string, format string, args ...interface{}) { +func writeLog(levelStr, funcName, format string, args ...any) { var buf strings.Builder if atomic.LoadUint32(&level) >= uint32(DEBUG) { @@ -237,7 +234,7 @@ func (w *stdLogWriter) Write(p []byte) (n int, err error) { // the message before calling Write. We do the same thing, so trim it. p = bytes.TrimSuffix(p, []byte{'\n'}) - var logFunc func(format string, args ...interface{}) + var logFunc func(format string, args ...any) switch w.level { case ERROR: logFunc = Error @@ -297,11 +294,11 @@ func OnPanicAndExit(prefix string, exitCode int) { // // Instead of: // -// defer f.Close() +// defer f.Close() // // You can now write: // -// defer log.OnCloserError(f, log.DEBUG) +// defer log.OnCloserError(f, log.DEBUG) // // Note that if closer is nil, it is simply ignored. func OnCloserError(closer io.Closer, l Level) { diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/addr.go b/vendor/github.com/AdguardTeam/golibs/netutil/addr.go index 85958f290..ccbc18ac4 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/addr.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/addr.go @@ -10,18 +10,17 @@ import ( "strconv" "strings" + "golang.org/x/exp/slices" "golang.org/x/net/idna" ) // Various Network Address Utilities // CloneMAC returns a clone of a MAC address. +// +// Deprecated: use [slices.Clone]. func CloneMAC(mac net.HardwareAddr) (clone net.HardwareAddr) { - if mac != nil && len(mac) == 0 { - return net.HardwareAddr{} - } - - return append(clone, mac...) + return slices.Clone(mac) } // CloneURL returns a deep clone of u. The User pointer of clone is the same, @@ -58,8 +57,8 @@ func JoinHostPort(host string, port int) (hostport string) { return net.JoinHostPort(strings.Trim(host, "[]"), strconv.Itoa(port)) } -// SplitHostPort is a convenient wrapper for net.SplitHostPort with port of type -// int. +// SplitHostPort is a convenient wrapper for [net.SplitHostPort] with port of +// type int. func SplitHostPort(hostport string) (host string, port int, err error) { var portStr string host, portStr, err = net.SplitHostPort(hostport) @@ -76,8 +75,8 @@ func SplitHostPort(hostport string) (host string, port int, err error) { return host, int(portUint), nil } -// SplitHost is a wrapper for net.SplitHostPort for cases when the hostport may -// or may not contain a port. +// SplitHost is a wrapper for [net.SplitHostPort] for cases when the hostport +// may or may not contain a port. func SplitHost(hostport string) (host string, err error) { host, _, err = net.SplitHostPort(hostport) if err != nil { @@ -124,7 +123,7 @@ func Subdomains(domain string) (sub []string) { // ValidateMAC returns an error if mac is not a valid EUI-48, EUI-64, or // 20-octet InfiniBand link-layer address. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func ValidateMAC(mac net.HardwareAddr) (err error) { defer makeAddrError(&err, mac.String(), AddrKindMAC) @@ -143,19 +142,23 @@ func ValidateMAC(mac net.HardwareAddr) (err error) { } // MaxDomainLabelLen is the maximum allowed length of a domain name label -// according to RFC 1035. +// according to [RFC 1035]. +// +// [RFC 1035]: https://datatracker.ietf.org/doc/html/rfc1035 const MaxDomainLabelLen = 63 // MaxDomainNameLen is the maximum allowed length of a full domain name -// according to RFC 1035. +// according to [RFC 1035]. // // See also: https://stackoverflow.com/a/32294443/1892060. +// +// [RFC 1035]: https://datatracker.ietf.org/doc/html/rfc1035 const MaxDomainNameLen = 253 // ValidateDomainNameLabel returns an error if label is not a valid label of // a domain name. An empty label is considered invalid. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func ValidateDomainNameLabel(label string) (err error) { defer makeAddrError(&err, label, AddrKindLabel) @@ -201,11 +204,14 @@ func ValidateDomainNameLabel(label string) (err error) { } // ValidateDomainName validates the domain name in accordance to RFC 952, -// RFC 1035, and with RFC 1123's inclusion of digits at the start of the host. -// It doesn't validate against two or more hyphens to allow punycode and +// [RFC 1035], and with [RFC 1123]'s inclusion of digits at the start of the +// host. It doesn't validate against two or more hyphens to allow punycode and // internationalized domains. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. +// +// [RFC 1035]: https://datatracker.ietf.org/doc/html/rfc1035 +// [RFC 1123]: https://datatracker.ietf.org/doc/html/rfc1123 func ValidateDomainName(name string) (err error) { defer makeAddrError(&err, name, AddrKindName) @@ -236,13 +242,15 @@ func ValidateDomainName(name string) (err error) { } // MaxServiceLabelLen is the maximum allowed length of a service name label -// according to RFC 6335. +// according to [RFC 6335]. +// +// [RFC 6335]: https://datatracker.ietf.org/doc/html/rfc6335 const MaxServiceLabelLen = 16 // ValidateServiceNameLabel returns an error if label is not a valid label of // a service domain name. An empty label is considered invalid. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func ValidateServiceNameLabel(label string) (err error) { defer makeAddrError(&err, label, AddrKindSRVLabel) @@ -266,7 +274,7 @@ func ValidateServiceNameLabel(label string) (err error) { // TODO(e.burkov): Validate adjacent hyphens since service labels can't be // internationalized. See RFC 6336 Section 5.1. - if err := ValidateDomainNameLabel(label[1:]); err != nil { + if err = ValidateDomainNameLabel(label[1:]); err != nil { err = errors.Unwrap(err) if rerr, ok := err.(*RuneError); ok { rerr.Kind = AddrKindSRVLabel @@ -279,11 +287,14 @@ func ValidateServiceNameLabel(label string) (err error) { } // ValidateSRVDomainName validates of domain name assuming it belongs to the -// superset of service domain names in accordance to RFC 2782 and RFC 6763. It -// doesn't validate against two or more hyphens to allow punycode and +// superset of service domain names in accordance to [RFC 2782] and [RFC 6763]. +// It doesn't validate against two or more hyphens to allow punycode and // internationalized domains. // // Any error returned will have the underlying type of *AddrError. +// +// [RFC 2782]: https://datatracker.ietf.org/doc/html/rfc2782 +// [RFC 6763]: https://datatracker.ietf.org/doc/html/rfc6763 func ValidateSRVDomainName(name string) (err error) { defer makeAddrError(&err, name, AddrKindSRVName) diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/addrconv.go b/vendor/github.com/AdguardTeam/golibs/netutil/addrconv.go new file mode 100644 index 000000000..89afda8be --- /dev/null +++ b/vendor/github.com/AdguardTeam/golibs/netutil/addrconv.go @@ -0,0 +1,124 @@ +package netutil + +import ( + "fmt" + "net" + "net/netip" + + "github.com/AdguardTeam/golibs/errors" +) + +// IPv4Localhost returns the IPv4 localhost address "127.0.0.1". +func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } + +// IPv6Localhost returns the IPv6 localhost address "::1". +func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) } + +// ZeroPrefix returns an IP subnet with prefix 0 and all bytes of the IP address +// set to 0. fam must be either [AddrFamilyIPv4] or [AddrFamilyIPv6]. +func ZeroPrefix(fam AddrFamily) (n netip.Prefix) { + switch fam { + case AddrFamilyIPv4: + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) + case AddrFamilyIPv6: + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + default: + panic(badAddrFam("ZeroPrefix", fam)) + } +} + +// badAddrFam is a helper that returns an informative error for panics caused by +// bad address-family values. +func badAddrFam(fn string, fam AddrFamily) (err error) { + return fmt.Errorf("netutil.%s: bad address family %d", fn, fam) +} + +// IPToAddr converts a [net.IP] into a [netip.Addr] of the given family and +// returns a meaningful error. ip should not be nil. fam must be either +// [AddrFamilyIPv4] or [AddrFamilyIPv6]. +// +// See also [IPToAddrNoMapped]. +func IPToAddr(ip net.IP, fam AddrFamily) (addr netip.Addr, err error) { + if ip == nil { + return netip.Addr{}, errors.Error("nil ip") + } + + switch fam { + case AddrFamilyIPv4: + // Make sure that we use the IPv4 form of the address to make sure that + // netip.Addr doesn't turn out to be an IPv6 one when it really should + // be an IPv4 one. + ip4 := ip.To4() + if ip4 == nil { + return netip.Addr{}, fmt.Errorf("bad ipv4 net.IP %v", ip) + } + + ip = ip4 + case AddrFamilyIPv6: + // Again, make sure that we use the correct form according to the + // address family. + ip = ip.To16() + default: + panic(badAddrFam("IPToAddr", fam)) + } + + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, fmt.Errorf("bad net.IP value %v", ip) + } + + return addr, nil +} + +// IPToAddrNoMapped is like [IPToAddr] but it detects the address family +// automatically by assuming that every IPv6-mapped IPv4 address is actually an +// IPv4 address. Do not use IPToAddrNoMapped where this assumption isn't safe. +func IPToAddrNoMapped(ip net.IP) (addr netip.Addr, err error) { + if ip4 := ip.To4(); ip4 != nil { + return IPToAddr(ip4, AddrFamilyIPv4) + } + + return IPToAddr(ip, AddrFamilyIPv6) +} + +// IPNetToPrefix is a helper that converts a [*net.IPNet] into a [netip.Prefix]. +// subnet should not be nil. fam must be either [AddrFamilyIPv4] or +// [AddrFamilyIPv6]. +// +// See also [IPNetToPrefixNoMapped]. +func IPNetToPrefix(subnet *net.IPNet, fam AddrFamily) (p netip.Prefix, err error) { + if subnet == nil { + return netip.Prefix{}, errors.Error("nil subnet") + } + + addr, err := IPToAddr(subnet.IP, fam) + if err != nil { + return netip.Prefix{}, fmt.Errorf("bad ip for subnet %v: %w", subnet, err) + } + + ones, _ := subnet.Mask.Size() + p = netip.PrefixFrom(addr, ones) + if !p.IsValid() { + return netip.Prefix{}, fmt.Errorf("bad subnet %v", subnet) + } + + return p, nil +} + +// IPNetToPrefixNoMapped is like [IPNetToPrefix] but it detects the address +// family automatically by assuming that every IPv6-mapped IPv4 address is +// actually an IPv4 address. Do not use IPNetToPrefixNoMapped where this +// assumption isn't safe. +func IPNetToPrefixNoMapped(subnet *net.IPNet) (p netip.Prefix, err error) { + if subnet == nil { + return netip.Prefix{}, errors.Error("nil subnet") + } + + if ip4 := subnet.IP.To4(); ip4 != nil { + subnet.IP = ip4 + + return IPNetToPrefix(subnet, AddrFamilyIPv4) + } + + return IPNetToPrefix(subnet, AddrFamilyIPv6) +} diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/addrfam.go b/vendor/github.com/AdguardTeam/golibs/netutil/addrfam.go new file mode 100644 index 000000000..e41e86d06 --- /dev/null +++ b/vendor/github.com/AdguardTeam/golibs/netutil/addrfam.go @@ -0,0 +1,56 @@ +package netutil + +import ( + "fmt" +) + +// AddrFamily is the type for IANA address family numbers. +type AddrFamily uint16 + +// An incomplete list of IANA address family numbers. +// +// See https://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml. +const ( + AddrFamilyNone AddrFamily = 0 + AddrFamilyIPv4 AddrFamily = 1 + AddrFamilyIPv6 AddrFamily = 2 +) + +// type check +var _ fmt.Stringer = AddrFamilyNone + +// String implements the [fmt.Stringer] interface for AddrFamily. +func (f AddrFamily) String() (s string) { + switch f { + case AddrFamilyNone: + return "none" + case AddrFamilyIPv4: + return "ipv4" + case AddrFamilyIPv6: + return "ipv6" + default: + return fmt.Sprintf("!bad_addr_fam_%d", f) + } +} + +// Constants to avoid a dependency on github.com/miekg/dns. +// +// See https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4. +const ( + dnsTypeA uint16 = 1 + dnsTypeAAAA uint16 = 28 +) + +// AddrFamilyFromRRType returns an AddrFamily appropriate for the DNS resource +// record type rr. That is, [AddrFamilyIPv4] for DNS type A (1), +// [AddrFamilyIPv6] for DNS type AAAA (28), and [AddrFamilyNone] otherwise. +func AddrFamilyFromRRType(rr uint16) (fam AddrFamily) { + switch rr { + case dnsTypeA: + return AddrFamilyIPv4 + case dnsTypeAAAA: + return AddrFamilyIPv6 + default: + return AddrFamilyNone + } +} diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/error.go b/vendor/github.com/AdguardTeam/golibs/netutil/error.go index 878a92fa6..91671d75c 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/error.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/error.go @@ -52,12 +52,17 @@ const ( type AddrError struct { // Err is the underlying error, if any. Err error + // Kind is the kind of address or address part. Kind AddrKind + // Addr is the text of the invalid address. Addr string } +// type check +var _ error = (*AddrError)(nil) + // Error implements the error interface for *AddrError. func (err *AddrError) Error() (msg string) { if err.Err != nil { @@ -67,17 +72,19 @@ func (err *AddrError) Error() (msg string) { return fmt.Sprintf("bad %s %q", err.Kind, err.Addr) } -// Unwrap implements the errors.Wrapper interface for *AddrError. It returns +// type check +var _ errors.Wrapper = (*AddrError)(nil) + +// Unwrap implements the [errors.Wrapper] interface for *AddrError. It returns // err.Err. func (err *AddrError) Unwrap() (unwrapped error) { return err.Err } -// makeAddrError is a deferrable helper for functions that return *AddrError. +// makeAddrError is a deferrable helper for functions that return [*AddrError]. // errPtr must be non-nil. Usage example: // -// defer makeAddrError(&err, addr, AddrKindARPA) -// +// defer makeAddrError(&err, addr, AddrKindARPA) func makeAddrError(errPtr *error, addr string, k AddrKind) { err := *errPtr if err == nil { @@ -96,12 +103,15 @@ func makeAddrError(errPtr *error, addr string, k AddrKind) { type LengthError struct { // Kind is the kind of address or address part. Kind AddrKind + // Allowed are the allowed lengths for this kind of address. If allowed // is empty, Max should be non-zero. Allowed []int + // Max is the maximum length for this part or address kind. If Max is // zero, Allowed should be non-empty. Max int + // Length is the length of the provided address. Length int } @@ -125,6 +135,7 @@ func (err *LengthError) Error() (msg string) { type RuneError struct { // Kind is the kind of address or address part. Kind AddrKind + // Rune is the invalid rune. Rune rune } diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/hostport.go b/vendor/github.com/AdguardTeam/golibs/netutil/hostport.go index fd23d3e70..73afdb32d 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/hostport.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/hostport.go @@ -1,5 +1,10 @@ package netutil +import ( + "encoding" + "fmt" +) + // HostPort And Utilities // HostPort is a convenient type for addresses that contain a hostname and @@ -10,7 +15,7 @@ type HostPort struct { } // ParseHostPort parses a HostPort from addr. Any error returned will have the -// underlying type of *AddrError. +// underlying type of [*AddrError]. func ParseHostPort(addr string) (hp *HostPort, err error) { defer makeAddrError(&err, addr, AddrKindHostPort) @@ -53,18 +58,27 @@ func (hp *HostPort) Clone() (clone *HostPort) { } } -// MarshalText implements the encoding.TextMarshaler interface for HostPort. +// type check +var _ encoding.TextMarshaler = HostPort{} + +// MarshalText implements the [encoding.TextMarshaler] interface for HostPort. func (hp HostPort) MarshalText() (b []byte, err error) { return []byte(hp.String()), nil } -// String implements the fmt.Stringer interface for *HostPort. +// type check +var _ fmt.Stringer = HostPort{} + +// String implements the [fmt.Stringer] interface for *HostPort. func (hp HostPort) String() (s string) { return JoinHostPort(hp.Host, hp.Port) } -// UnmarshalText implements the encoding.TextUnmarshaler interface for -// *HostPort. Any error returned will have the underlying type of *AddrError. +// type check +var _ encoding.TextUnmarshaler = (*HostPort)(nil) + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface for +// *HostPort. Any error returned will have the underlying type of [*AddrError]. func (hp *HostPort) UnmarshalText(b []byte) (err error) { var newHP *HostPort newHP, err = ParseHostPort(string(b)) diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/ip.go b/vendor/github.com/AdguardTeam/golibs/netutil/ip.go index ee46ae776..53e953f55 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/ip.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/ip.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "strings" + + "golang.org/x/exp/slices" ) // IP Address Constants And Utilities @@ -16,12 +18,10 @@ const ( // CloneIP returns a clone of an IP address that doesn't share the same // underlying array with it. +// +// Deprecated: use slices.Clone. func CloneIP(ip net.IP) (clone net.IP) { - if ip != nil && len(ip) == 0 { - return net.IP{} - } - - return append(clone, ip...) + return slices.Clone(ip) } // CloneIPs returns a deep clone of ips. @@ -32,14 +32,14 @@ func CloneIPs(ips []net.IP) (clone []net.IP) { clone = make([]net.IP, len(ips)) for i, ip := range ips { - clone[i] = CloneIP(ip) + clone[i] = slices.Clone(ip) } return clone } // IPAndPortFromAddr returns the IP address and the port from addr. If addr is -// neither a *net.TCPAddr nor a *net.UDPAddr, it returns nil and 0. +// neither a [*net.TCPAddr] nor a [*net.UDPAddr], it returns nil and 0. func IPAndPortFromAddr(addr net.Addr) (ip net.IP, port int) { switch addr := addr.(type) { case *net.TCPAddr: @@ -79,7 +79,7 @@ func IPv6Zero() (ip net.IP) { // ParseIP is a wrapper around net.ParseIP that returns a useful error. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func ParseIP(s string) (ip net.IP, err error) { ip = net.ParseIP(s) if ip == nil { @@ -95,7 +95,7 @@ func ParseIP(s string) (ip net.IP, err error) { // ParseIPv4 is a wrapper around net.ParseIP that makes sure that the parsed IP // is an IPv4 address and returns a useful error. // -// Any error returned will have the underlying type of either *AddrError. +// Any error returned will have the underlying type of either [*AddrError]. func ParseIPv4(s string) (ip net.IP, err error) { ip, err = ParseIP(s) if err != nil { @@ -121,9 +121,9 @@ func CloneIPNet(n *net.IPNet) (clone *net.IPNet) { } return &net.IPNet{ - IP: CloneIP(n.IP), + IP: slices.Clone(n.IP), // TODO(e.burkov): Consider adding CloneIPMask. - Mask: net.IPMask(CloneIP(net.IP(n.Mask))), + Mask: net.IPMask(slices.Clone(net.IP(n.Mask))), } } @@ -134,7 +134,7 @@ func CloneIPNet(n *net.IPNet) (clone *net.IPNet) { // If s contains a CIDR with an IP address that is an IPv4-mapped IPv6 address, // the behavior is undefined. // -// Any error returned will have the underlying type of either *AddrError. +// Any error returned will have the underlying type of either [*AddrError]. func ParseSubnet(s string) (n *net.IPNet, err error) { var ip net.IP @@ -220,7 +220,7 @@ func ParseSubnets(ss ...string) (ns []*net.IPNet, err error) { // ValidateIP returns an error if ip is not a valid IPv4 or IPv6 address. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func ValidateIP(ip net.IP) (err error) { // TODO(a.garipov): Get rid of unnecessary allocations in case of valid IP. defer makeAddrError(&err, ip.String(), AddrKindIP) diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/ipmap.go b/vendor/github.com/AdguardTeam/golibs/netutil/ipmap.go index b62e96688..2bcdcd52e 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/ipmap.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/ipmap.go @@ -25,8 +25,10 @@ func ipToArr(ip net.IP) (a ipArr) { } // IPMap is a map of IP addresses. +// +// Deprecated: Use map[netip.Addr]T instead. type IPMap struct { - m map[ipArr]interface{} + m map[ipArr]any } // NewIPMap returns a new empty IP map using hint as a size hint for the @@ -35,7 +37,7 @@ type IPMap struct { // It is not safe for concurrent use, just like the usual Go maps aren't. func NewIPMap(hint int) (m *IPMap) { return &IPMap{ - m: make(map[ipArr]interface{}, hint), + m: make(map[ipArr]any, hint), } } @@ -63,7 +65,7 @@ func (m *IPMap) Del(ip net.IP) { // Get returns the value from the map. Calling Get on a nil *IPMap returns nil // and false, just like indexing on an empty map does. -func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) { +func (m *IPMap) Get(ip net.IP) (v any, ok bool) { if m != nil { v, ok = m.m[ipToArr(ip)] @@ -87,7 +89,7 @@ func (m *IPMap) Len() (n int) { // present in the map in an undefined order. If cont is false, range stops the // iteration. Calling Range on a nil *IPMap has no effect, just like ranging // over a nil map. -func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) { +func (m *IPMap) Range(f func(ip net.IP, v any) (cont bool)) { if m == nil { return } @@ -106,7 +108,7 @@ func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) { // Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map // does. -func (m *IPMap) Set(ip net.IP, v interface{}) { +func (m *IPMap) Set(ip net.IP, v any) { if m == nil { panic(errors.Error("assignment to entry in nil *netutil.IPMap")) } @@ -121,7 +123,7 @@ func (m *IPMap) ShallowClone() (sclone *IPMap) { } sclone = NewIPMap(m.Len()) - m.Range(func(ip net.IP, v interface{}) (cont bool) { + m.Range(func(ip net.IP, v any) (cont bool) { sclone.Set(ip, v) return true diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/ipport.go b/vendor/github.com/AdguardTeam/golibs/netutil/ipport.go index e12413474..eacfba5e6 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/ipport.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/ipport.go @@ -1,11 +1,17 @@ package netutil -import "net" +import ( + "net" + + "golang.org/x/exp/slices" +) // IPPort And Utilities // IPPort is a convenient type for network addresses that contain an IP address // and a port, like "1.2.3.4:56789" or "[1234::cdef]:12345". +// +// Deprecated: use netip.AddrPort. type IPPort struct { IP net.IP Port int @@ -20,7 +26,7 @@ func IPPortFromAddr(a net.Addr) (ipp *IPPort) { } return &IPPort{ - IP: CloneIP(ip), + IP: slices.Clone(ip), Port: port, } } @@ -70,7 +76,7 @@ func (ipp *IPPort) Clone() (clone *IPPort) { } return &IPPort{ - IP: CloneIP(ipp.IP), + IP: slices.Clone(ipp.IP), Port: ipp.Port, } } @@ -93,7 +99,7 @@ func (ipp IPPort) String() (s string) { // TCP returns a *net.TCPAddr with a clone of ipp's IP address and its port. func (ipp *IPPort) TCP() (a *net.TCPAddr) { return &net.TCPAddr{ - IP: CloneIP(ipp.IP), + IP: slices.Clone(ipp.IP), Port: ipp.Port, } } @@ -101,7 +107,7 @@ func (ipp *IPPort) TCP() (a *net.TCPAddr) { // UDP returns a *net.UDPAddr with a clone of ipp's IP address and its port. func (ipp *IPPort) UDP() (a *net.UDPAddr) { return &net.UDPAddr{ - IP: CloneIP(ipp.IP), + IP: slices.Clone(ipp.IP), Port: ipp.Port, } } diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/reversed.go b/vendor/github.com/AdguardTeam/golibs/netutil/reversed.go index 22332a08f..6b791a5a7 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/reversed.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/reversed.go @@ -39,12 +39,11 @@ const ( // // An example of IPv4 with a maximum length: // -// 49.91.20.104.in-addr.arpa +// 49.91.20.104.in-addr.arpa // // An example of IPv6 with a maximum length: // -// 1.3.b.5.4.1.8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.7.4.6.0.6.2.ip6.arpa -// +// 1.3.b.5.4.1.8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.7.4.6.0.6.2.ip6.arpa const ( arpaV4MaxIPLen = len("000.000.000.000") arpaV6MaxIPLen = len("0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0") @@ -156,7 +155,7 @@ func IPFromReversedAddr(arpa string) (ip net.IP, err error) { // DNS (PTR) record lookups. This is a modified version of function ReverseAddr // from package github.com/miekg/dns package that accepts an IP. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func IPToReversedAddr(ip net.IP) (arpa string, err error) { const dot = "." @@ -332,7 +331,7 @@ func subnetFromReversedV6(arpa string) (subnet *net.IPNet, err error) { // SubnetFromReversedAddr tries to convert a reversed ARPA address to an IP // network. arpa can be domain name or an FQDN. // -// Any error returned will have the underlying type of *AddrError. +// Any error returned will have the underlying type of [*AddrError]. func SubnetFromReversedAddr(arpa string) (subnet *net.IPNet, err error) { arpa = strings.TrimSuffix(arpa, ".") err = ValidateDomainName(arpa) diff --git a/vendor/github.com/AdguardTeam/golibs/netutil/subnetset.go b/vendor/github.com/AdguardTeam/golibs/netutil/subnetset.go index ff6d2b669..b54a4c86d 100644 --- a/vendor/github.com/AdguardTeam/golibs/netutil/subnetset.go +++ b/vendor/github.com/AdguardTeam/golibs/netutil/subnetset.go @@ -47,25 +47,27 @@ func (f SubnetSetFunc) Contains(ip net.IP) (ok bool) { return f(ip) } // Optimized Implementations Of Some Commonly Used Sets Of Networks -// IsLocallyServed checks if ip belongs to any network defined by RFC 6303: +// IsLocallyServed checks if ip belongs to any network defined by [RFC 6303]: // -// 10.0.0.0/8 -// 172.16.0.0/12 -// 192.168.0.0/16 -// 127.0.0.0/8 -// 169.254.0.0/16 -// 192.0.2.0/24 -// 198.51.100.0/24 -// 203.0.113.0/24 -// 255.255.255.255/32 +// 10.0.0.0/8 +// 127.0.0.0/8 +// 169.254.0.0/16 +// 172.16.0.0/12 +// 192.0.2.0/24 +// 192.168.0.0/16 +// 198.51.100.0/24 +// 203.0.113.0/24 +// 255.255.255.255/32 // -// ::/128 -// ::1/128 -// fe80::/10 -// 2001:db8::/32 -// fd00::/8 +// ::/128 +// ::1/128 +// 2001:db8::/32 +// fd00::/8 +// fe80::/10 // -// It may also be used as a SubnetSetFunc. +// It may also be used as a [SubnetSetFunc]. +// +// [RFC 6303]: https://datatracker.ietf.org/doc/html/rfc6303 func IsLocallyServed(ip net.IP) (ok bool) { if ip == nil { return false @@ -81,9 +83,11 @@ func IsLocallyServed(ip net.IP) (ok bool) { } // isLocallyServedV6 returns true if ip belongs to at least one of networks -// listed in RFC 6303. The ip is expected to be a valid IPv6. +// listed in [RFC 6303]. The ip is expected to be a valid IPv6. +// +// See also [IsLocallyServed]. // -// See go doc IsLocallyServed. +// [RFC 6303]: https://datatracker.ietf.org/doc/html/rfc6303 func isLocallyServedV6(ip net.IP) (ok bool) { switch ip[0] { case 0x00: @@ -99,9 +103,11 @@ func isLocallyServedV6(ip net.IP) (ok bool) { } // isLocallyServedV4 returns true if ip belongs to at least one of networks -// listed in RFC 6303. The ip is expected to be a valid IPv4. +// listed in [RFC 6303]. The ip is expected to be a valid IPv4. +// +// See also [IsLocallyServed]. // -// See go doc IsLocallyServed. +// [RFC 6303]: https://datatracker.ietf.org/doc/html/rfc6303 func isLocallyServedV4(ip net.IP) (ok bool) { switch ip[0] { case 10, 127: @@ -124,47 +130,47 @@ func isLocallyServedV4(ip net.IP) (ok bool) { // IsSpecialPurpose checks if ip belongs to any network defined by IANA // Special-Purpose Address Registry: // -// 0.0.0.0/8 "This host on this network". -// 10.0.0.0/8 Private-Use. -// 100.64.0.0/10 Shared Address Space. -// 127.0.0.0/8 Loopback. -// 169.254.0.0/16 Link Local. -// 172.16.0.0/12 Private-Use. -// 192.0.0.0/24 IETF Protocol Assignments. -// 192.0.0.0/29 DS-Lite. -// 192.0.2.0/24 Documentation (TEST-NET-1) -// 192.88.99.0/24 6to4 Relay Anycast. -// 192.168.0.0/16 Private-Use. -// 198.18.0.0/15 Benchmarking. -// 198.51.100.0/24 Documentation (TEST-NET-2). -// 203.0.113.0/24 Documentation (TEST-NET-3). -// 240.0.0.0/4 Reserved. -// 255.255.255.255/32 Limited Broadcast. +// 0.0.0.0/8 "This host on this network". +// 10.0.0.0/8 Private-Use. +// 100.64.0.0/10 Shared Address Space. +// 127.0.0.0/8 Loopback. +// 169.254.0.0/16 Link Local. +// 172.16.0.0/12 Private-Use. +// 192.0.0.0/24 IETF Protocol Assignments. +// 192.0.0.0/29 DS-Lite. +// 192.0.2.0/24 Documentation (TEST-NET-1) +// 192.88.99.0/24 6to4 Relay Anycast. +// 192.168.0.0/16 Private-Use. +// 198.18.0.0/15 Benchmarking. +// 198.51.100.0/24 Documentation (TEST-NET-2). +// 203.0.113.0/24 Documentation (TEST-NET-3). +// 240.0.0.0/4 Reserved. +// 255.255.255.255/32 Limited Broadcast. // -// ::/128 Unspecified Address. -// ::1/128 Loopback Address. -// 64:ff9b::/96 IPv4-IPv6 Translation Address. -// 64:ff9b:1::/48 IPv4-IPv6 Translation Address. -// 100::/64 Discard-Only Address Block. -// 2001::/23 IETF Protocol Assignments. -// 2001::/32 TEREDO. -// 2001:1::1/128 Port Control Protocol Anycast. -// 2001:1::2/128 Traversal Using Relays around NAT Anycast. -// 2001:2::/48 Benchmarking. -// 2001:3::/32 AMT. -// 2001:4:112::/48 AS112-v6. -// 2001:10::/28 ORCHID. -// 2001:20::/28 ORCHIDv2. -// 2001:db8::/32 Documentation. -// 2002::/16 6to4. -// 2620:4f:8000::/48 Direct Delegation AS112 Service. -// fc00::/7 Unique-Local. -// fe80::/10 Linked-Scoped Unicast. +// ::/128 Unspecified Address. +// ::1/128 Loopback Address. +// 64:ff9b::/96 IPv4-IPv6 Translation Address. +// 64:ff9b:1::/48 IPv4-IPv6 Translation Address. +// 100::/64 Discard-Only Address Block. +// 2001::/23 IETF Protocol Assignments. +// 2001::/32 TEREDO. +// 2001:1::1/128 Port Control Protocol Anycast. +// 2001:1::2/128 Traversal Using Relays around NAT Anycast. +// 2001:2::/48 Benchmarking. +// 2001:3::/32 AMT. +// 2001:4:112::/48 AS112-v6. +// 2001:10::/28 ORCHID. +// 2001:20::/28 ORCHIDv2. +// 2001:db8::/32 Documentation. +// 2002::/16 6to4. +// 2620:4f:8000::/48 Direct Delegation AS112 Service. +// fc00::/7 Unique-Local. +// fe80::/10 Linked-Scoped Unicast. // // See https://www.iana.org/assignments/iana-ipv4-special-registry and // https://www.iana.org/assignments/iana-ipv6-special-registry. // -// It may also be used as a SubnetSetFunc. +// It may also be used as a [SubnetSetFunc]. func IsSpecialPurpose(ip net.IP) (ok bool) { if ip == nil { return false @@ -183,7 +189,7 @@ func IsSpecialPurpose(ip net.IP) (ok bool) { // from special-purpose address registries. The ip is expected to be a valid // IPv6. // -// See go doc IsSpecialPurpose. +// See also [IsSpecialPurpose]. func isSpecialPurposeV6(ip net.IP) (ok bool) { switch ip[0] { case 0x00: @@ -208,7 +214,7 @@ func isSpecialPurposeV6(ip net.IP) (ok bool) { // from special-purpose address registries. The ip is expected to be a valid // IPv4. // -// See go doc IsSpecialPurpose. +// See also [IsSpecialPurpose]. func isSpecialPurposeV4(ip net.IP) (ok bool) { switch ip[0] { case 0: diff --git a/vendor/github.com/AdguardTeam/golibs/testutil/log.go b/vendor/github.com/AdguardTeam/golibs/testutil/log.go new file mode 100644 index 000000000..d589f63f1 --- /dev/null +++ b/vendor/github.com/AdguardTeam/golibs/testutil/log.go @@ -0,0 +1,17 @@ +package testutil + +import ( + "io" + "log" + "os" + "testing" +) + +// DiscardLogOutput runs tests with discarded logger output. +// +// TODO(a.garipov): Refactor project that use this to not use a global logger. +func DiscardLogOutput(m *testing.M) { + log.SetOutput(io.Discard) + + os.Exit(m.Run()) +} diff --git a/vendor/github.com/AdguardTeam/golibs/testutil/panict.go b/vendor/github.com/AdguardTeam/golibs/testutil/panict.go index 42410582d..c81f566fb 100644 --- a/vendor/github.com/AdguardTeam/golibs/testutil/panict.go +++ b/vendor/github.com/AdguardTeam/golibs/testutil/panict.go @@ -19,7 +19,7 @@ var _ require.TestingT = PanicT{} // Errorf implements the require.TestingT interface for PanicT. It panics with // an error with the given format. -func (PanicT) Errorf(format string, args ...interface{}) { +func (PanicT) Errorf(format string, args ...any) { panic(fmt.Errorf(format, args...)) } diff --git a/vendor/github.com/AdguardTeam/golibs/testutil/sync.go b/vendor/github.com/AdguardTeam/golibs/testutil/sync.go new file mode 100644 index 000000000..5e2262a6f --- /dev/null +++ b/vendor/github.com/AdguardTeam/golibs/testutil/sync.go @@ -0,0 +1,49 @@ +package testutil + +import ( + "time" + + "github.com/stretchr/testify/require" +) + +// RequireSend waits until v is sent to ch or until the timeout is exceeded. If +// the timeout is exceeded, the test is failed. +func RequireSend[T any](t require.TestingT, ch chan<- T, v T, timeout time.Duration) { + if h, ok := t.(interface{ Helper() }); ok { + h.Helper() + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case ch <- v: + // Go on. + case <-timer.C: + t.Errorf("did not send after %s", timeout) + t.FailNow() + } +} + +// RequireReceive waits until res is received from ch or until the timeout is +// exceeded. If the timeout is exceeded, the test is failed. +func RequireReceive[T any](t require.TestingT, ch <-chan T, timeout time.Duration) (res T, ok bool) { + if h, isHelper := t.(interface{ Helper() }); isHelper { + h.Helper() + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case res, ok = <-ch: + return res, ok + case <-timer.C: + t.Errorf("did not receive after %s", timeout) + t.FailNow() + + } + + // Generally unreachable. + return res, ok +} diff --git a/vendor/github.com/AdguardTeam/golibs/testutil/testutil.go b/vendor/github.com/AdguardTeam/golibs/testutil/testutil.go index 845a51ff0..422399417 100644 --- a/vendor/github.com/AdguardTeam/golibs/testutil/testutil.go +++ b/vendor/github.com/AdguardTeam/golibs/testutil/testutil.go @@ -65,25 +65,25 @@ func newStringCodecChecker(s string) (c *stringCodecChecker) { // newGenericCodecChecker constructs a pointer to value of a type similar to the // following: // -// type checker struct { -// PtrMap map[string]*T `json:"ptr_map"` -// Map map[string]T `json:"map"` +// type checker struct { +// PtrMap map[string]*T `json:"ptr_map"` +// Map map[string]T `json:"map"` // -// PtrValue *T `json:"ptr_value"` -// Value T `json:"value"` +// PtrValue *T `json:"ptr_value"` +// Value T `json:"value"` // -// PtrArray [1]*T `json:"ptr_array"` -// Array [1]T `json:"array"` +// PtrArray [1]*T `json:"ptr_array"` +// Array [1]T `json:"array"` // -// PtrSlice []*T `json:"ptr_slice"` -// Slice []T `json:"slice"` -// } +// PtrSlice []*T `json:"ptr_slice"` +// Slice []T `json:"slice"` +// } // // where T is the type v points to. The slice and pointer fields are properly // initialized. // // TODO(a.garipov): Redo this with type parameters in Go 1.18. -func newGenericCodecChecker(v interface{}) (checkerVal reflect.Value) { +func newGenericCodecChecker(v any) (checkerVal reflect.Value) { strTyp := reflect.TypeOf("") ptrTyp := reflect.TypeOf(v) @@ -142,7 +142,7 @@ func newGenericCodecChecker(v interface{}) (checkerVal reflect.Value) { } // assignGenericCodecChecker assigns all fields to v or the value v points to. -func assignGenericCodecChecker(checkerVal reflect.Value, v interface{}) { +func assignGenericCodecChecker(checkerVal reflect.Value, v any) { keyVal := reflect.ValueOf("1") valPtr := reflect.ValueOf(v) val := valPtr.Elem() @@ -217,3 +217,13 @@ func CleanupAndRequireSuccess(t testing.TB, f func() (err error)) { require.NoError(t, err) }) } + +// RequireTypeAssert is a helper that first requires the desired type and then, +// if the type is correct, converts and returns the value. +func RequireTypeAssert[T any](t testing.TB, v any) (res T) { + t.Helper() + + require.IsType(t, res, v) + + return v.(T) +} diff --git a/vendor/golang.org/x/exp/slices/slices.go b/vendor/golang.org/x/exp/slices/slices.go new file mode 100644 index 000000000..a9fe63f52 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/slices.go @@ -0,0 +1,231 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package slices defines various functions useful with slices of any type. +// Unless otherwise specified, these functions all apply to the elements +// of a slice at index 0 <= i < len(s). +// +// Note that the less function in IsSortedFunc, SortFunc, SortStableFunc requires a +// strict weak ordering (https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings), +// or the sorting may fail to sort correctly. A common case is when sorting slices of +// floating-point numbers containing NaN values. +package slices + +import "golang.org/x/exp/constraints" + +// Equal reports whether two slices are equal: the same length and all +// elements equal. If the lengths are different, Equal returns false. +// Otherwise, the elements are compared in increasing index order, and the +// comparison stops at the first unequal pair. +// Floating point NaNs are not considered equal. +func Equal[E comparable](s1, s2 []E) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} + +// EqualFunc reports whether two slices are equal using a comparison +// function on each pair of elements. If the lengths are different, +// EqualFunc returns false. Otherwise, the elements are compared in +// increasing index order, and the comparison stops at the first index +// for which eq returns false. +func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { + if len(s1) != len(s2) { + return false + } + for i, v1 := range s1 { + v2 := s2[i] + if !eq(v1, v2) { + return false + } + } + return true +} + +// Compare compares the elements of s1 and s2. +// The elements are compared sequentially, starting at index 0, +// until one element is not equal to the other. +// The result of comparing the first non-matching elements is returned. +// If both slices are equal until one of them ends, the shorter slice is +// considered less than the longer one. +// The result is 0 if s1 == s2, -1 if s1 < s2, and +1 if s1 > s2. +// Comparisons involving floating point NaNs are ignored. +func Compare[E constraints.Ordered](s1, s2 []E) int { + s2len := len(s2) + for i, v1 := range s1 { + if i >= s2len { + return +1 + } + v2 := s2[i] + switch { + case v1 < v2: + return -1 + case v1 > v2: + return +1 + } + } + if len(s1) < s2len { + return -1 + } + return 0 +} + +// CompareFunc is like Compare but uses a comparison function +// on each pair of elements. The elements are compared in increasing +// index order, and the comparisons stop after the first time cmp +// returns non-zero. +// The result is the first non-zero result of cmp; if cmp always +// returns 0 the result is 0 if len(s1) == len(s2), -1 if len(s1) < len(s2), +// and +1 if len(s1) > len(s2). +func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { + s2len := len(s2) + for i, v1 := range s1 { + if i >= s2len { + return +1 + } + v2 := s2[i] + if c := cmp(v1, v2); c != 0 { + return c + } + } + if len(s1) < s2len { + return -1 + } + return 0 +} + +// Index returns the index of the first occurrence of v in s, +// or -1 if not present. +func Index[E comparable](s []E, v E) int { + for i, vs := range s { + if v == vs { + return i + } + } + return -1 +} + +// IndexFunc returns the first index i satisfying f(s[i]), +// or -1 if none do. +func IndexFunc[E any](s []E, f func(E) bool) int { + for i, v := range s { + if f(v) { + return i + } + } + return -1 +} + +// Contains reports whether v is present in s. +func Contains[E comparable](s []E, v E) bool { + return Index(s, v) >= 0 +} + +// Insert inserts the values v... into s at index i, +// returning the modified slice. +// In the returned slice r, r[i] == v[0]. +// Insert panics if i is out of range. +// This function is O(len(s) + len(v)). +func Insert[S ~[]E, E any](s S, i int, v ...E) S { + tot := len(s) + len(v) + if tot <= cap(s) { + s2 := s[:tot] + copy(s2[i+len(v):], s[i:]) + copy(s2[i:], v) + return s2 + } + s2 := make(S, tot) + copy(s2, s[:i]) + copy(s2[i:], v) + copy(s2[i+len(v):], s[i:]) + return s2 +} + +// Delete removes the elements s[i:j] from s, returning the modified slice. +// Delete panics if s[i:j] is not a valid slice of s. +// Delete modifies the contents of the slice s; it does not create a new slice. +// Delete is O(len(s)-j), so if many items must be deleted, it is better to +// make a single call deleting them all together than to delete one at a time. +// Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those +// elements contain pointers you might consider zeroing those elements so that +// objects they reference can be garbage collected. +func Delete[S ~[]E, E any](s S, i, j int) S { + _ = s[i:j] // bounds check + + return append(s[:i], s[j:]...) +} + +// Clone returns a copy of the slice. +// The elements are copied using assignment, so this is a shallow clone. +func Clone[S ~[]E, E any](s S) S { + // Preserve nil in case it matters. + if s == nil { + return nil + } + return append(S([]E{}), s...) +} + +// Compact replaces consecutive runs of equal elements with a single copy. +// This is like the uniq command found on Unix. +// Compact modifies the contents of the slice s; it does not create a new slice. +func Compact[S ~[]E, E comparable](s S) S { + if len(s) == 0 { + return s + } + i := 1 + last := s[0] + for _, v := range s[1:] { + if v != last { + s[i] = v + i++ + last = v + } + } + return s[:i] +} + +// CompactFunc is like Compact but uses a comparison function. +func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S { + if len(s) == 0 { + return s + } + i := 1 + last := s[0] + for _, v := range s[1:] { + if !eq(v, last) { + s[i] = v + i++ + last = v + } + } + return s[:i] +} + +// Grow increases the slice's capacity, if necessary, to guarantee space for +// another n elements. After Grow(n), at least n elements can be appended +// to the slice without another allocation. If n is negative or too large to +// allocate the memory, Grow panics. +func Grow[S ~[]E, E any](s S, n int) S { + if n < 0 { + panic("cannot be negative") + } + if n -= cap(s) - len(s); n > 0 { + // TODO(https://go.dev/issue/53888): Make using []E instead of S + // to workaround a compiler bug where the runtime.growslice optimization + // does not take effect. Revert when the compiler is fixed. + s = append([]E(s)[:cap(s)], make([]E, n)...)[:len(s)] + } + return s +} + +// Clip removes unused capacity from the slice, returning s[:len(s):len(s)]. +func Clip[S ~[]E, E any](s S) S { + return s[:len(s):len(s)] +} diff --git a/vendor/golang.org/x/exp/slices/sort.go b/vendor/golang.org/x/exp/slices/sort.go new file mode 100644 index 000000000..c22e74bd1 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/sort.go @@ -0,0 +1,127 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +import ( + "math/bits" + + "golang.org/x/exp/constraints" +) + +// Sort sorts a slice of any ordered type in ascending order. +// Sort may fail to sort correctly when sorting slices of floating-point +// numbers containing Not-a-number (NaN) values. +// Use slices.SortFunc(x, func(a, b float64) bool {return a < b || (math.IsNaN(a) && !math.IsNaN(b))}) +// instead if the input may contain NaNs. +func Sort[E constraints.Ordered](x []E) { + n := len(x) + pdqsortOrdered(x, 0, n, bits.Len(uint(n))) +} + +// SortFunc sorts the slice x in ascending order as determined by the less function. +// This sort is not guaranteed to be stable. +// +// SortFunc requires that less is a strict weak ordering. +// See https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings. +func SortFunc[E any](x []E, less func(a, b E) bool) { + n := len(x) + pdqsortLessFunc(x, 0, n, bits.Len(uint(n)), less) +} + +// SortStable sorts the slice x while keeping the original order of equal +// elements, using less to compare elements. +func SortStableFunc[E any](x []E, less func(a, b E) bool) { + stableLessFunc(x, len(x), less) +} + +// IsSorted reports whether x is sorted in ascending order. +func IsSorted[E constraints.Ordered](x []E) bool { + for i := len(x) - 1; i > 0; i-- { + if x[i] < x[i-1] { + return false + } + } + return true +} + +// IsSortedFunc reports whether x is sorted in ascending order, with less as the +// comparison function. +func IsSortedFunc[E any](x []E, less func(a, b E) bool) bool { + for i := len(x) - 1; i > 0; i-- { + if less(x[i], x[i-1]) { + return false + } + } + return true +} + +// BinarySearch searches for target in a sorted slice and returns the position +// where target is found, or the position where target would appear in the +// sort order; it also returns a bool saying whether the target is really found +// in the slice. The slice must be sorted in increasing order. +func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { + // search returns the leftmost position where f returns true, or len(x) if f + // returns false for all x. This is the insertion position for target in x, + // and could point to an element that's either == target or not. + pos := search(len(x), func(i int) bool { return x[i] >= target }) + if pos >= len(x) || x[pos] != target { + return pos, false + } else { + return pos, true + } +} + +// BinarySearchFunc works like BinarySearch, but uses a custom comparison +// function. The slice must be sorted in increasing order, where "increasing" is +// defined by cmp. cmp(a, b) is expected to return an integer comparing the two +// parameters: 0 if a == b, a negative number if a < b and a positive number if +// a > b. +func BinarySearchFunc[E any](x []E, target E, cmp func(E, E) int) (int, bool) { + pos := search(len(x), func(i int) bool { return cmp(x[i], target) >= 0 }) + if pos >= len(x) || cmp(x[pos], target) != 0 { + return pos, false + } else { + return pos, true + } +} + +func search(n int, f func(int) bool) int { + // Define f(-1) == false and f(n) == true. + // Invariant: f(i-1) == false, f(j) == true. + i, j := 0, n + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + if !f(h) { + i = h + 1 // preserves f(i-1) == false + } else { + j = h // preserves f(j) == true + } + } + // i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i. + return i +} + +type sortedHint int // hint for pdqsort when choosing the pivot + +const ( + unknownHint sortedHint = iota + increasingHint + decreasingHint +) + +// xorshift paper: https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf +type xorshift uint64 + +func (r *xorshift) Next() uint64 { + *r ^= *r << 13 + *r ^= *r >> 17 + *r ^= *r << 5 + return uint64(*r) +} + +func nextPowerOfTwo(length int) uint { + return 1 << bits.Len(uint(length)) +} diff --git a/vendor/golang.org/x/exp/slices/zsortfunc.go b/vendor/golang.org/x/exp/slices/zsortfunc.go new file mode 100644 index 000000000..2a632476c --- /dev/null +++ b/vendor/golang.org/x/exp/slices/zsortfunc.go @@ -0,0 +1,479 @@ +// Code generated by gen_sort_variants.go; DO NOT EDIT. + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +// insertionSortLessFunc sorts data[a:b] using insertion sort. +func insertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { + for i := a + 1; i < b; i++ { + for j := i; j > a && less(data[j], data[j-1]); j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// siftDownLessFunc implements the heap property on data[lo:hi]. +// first is an offset into the array where the root of the heap lies. +func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && less(data[first+child], data[first+child+1]) { + child++ + } + if !less(data[first+root], data[first+child]) { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} + +func heapSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { + first := a + lo := 0 + hi := b - a + + // Build heap with greatest element at top. + for i := (hi - 1) / 2; i >= 0; i-- { + siftDownLessFunc(data, i, hi, first, less) + } + + // Pop elements, largest first, into end of data. + for i := hi - 1; i >= 0; i-- { + data[first], data[first+i] = data[first+i], data[first] + siftDownLessFunc(data, lo, i, first, less) + } +} + +// pdqsortLessFunc sorts data[a:b]. +// The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort. +// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf +// C++ implementation: https://github.com/orlp/pdqsort +// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/ +// limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort. +func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { + const maxInsertion = 12 + + var ( + wasBalanced = true // whether the last partitioning was reasonably balanced + wasPartitioned = true // whether the slice was already partitioned + ) + + for { + length := b - a + + if length <= maxInsertion { + insertionSortLessFunc(data, a, b, less) + return + } + + // Fall back to heapsort if too many bad choices were made. + if limit == 0 { + heapSortLessFunc(data, a, b, less) + return + } + + // If the last partitioning was imbalanced, we need to breaking patterns. + if !wasBalanced { + breakPatternsLessFunc(data, a, b, less) + limit-- + } + + pivot, hint := choosePivotLessFunc(data, a, b, less) + if hint == decreasingHint { + reverseRangeLessFunc(data, a, b, less) + // The chosen pivot was pivot-a elements after the start of the array. + // After reversing it is pivot-a elements before the end of the array. + // The idea came from Rust's implementation. + pivot = (b - 1) - (pivot - a) + hint = increasingHint + } + + // The slice is likely already sorted. + if wasBalanced && wasPartitioned && hint == increasingHint { + if partialInsertionSortLessFunc(data, a, b, less) { + return + } + } + + // Probably the slice contains many duplicate elements, partition the slice into + // elements equal to and elements greater than the pivot. + if a > 0 && !less(data[a-1], data[pivot]) { + mid := partitionEqualLessFunc(data, a, b, pivot, less) + a = mid + continue + } + + mid, alreadyPartitioned := partitionLessFunc(data, a, b, pivot, less) + wasPartitioned = alreadyPartitioned + + leftLen, rightLen := mid-a, b-mid + balanceThreshold := length / 8 + if leftLen < rightLen { + wasBalanced = leftLen >= balanceThreshold + pdqsortLessFunc(data, a, mid, limit, less) + a = mid + 1 + } else { + wasBalanced = rightLen >= balanceThreshold + pdqsortLessFunc(data, mid+1, b, limit, less) + b = mid + } + } +} + +// partitionLessFunc does one quicksort partition. +// Let p = data[pivot] +// Moves elements in data[a:b] around, so that data[i]
=p for i =p for i