diff --git a/go.mod b/go.mod index f22141c2c0e..a4f866583a9 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.14 require ( - github.com/AdguardTeam/dnsproxy v0.33.7 + github.com/AdguardTeam/dnsproxy v0.33.9 github.com/AdguardTeam/golibs v0.4.4 github.com/AdguardTeam/urlfilter v0.14.2 github.com/NYTimes/gziphandler v1.1.1 @@ -17,6 +17,7 @@ require ( github.com/insomniacslk/dhcp v0.0.0-20201112113307-4de412bc85d8 github.com/kardianos/service v1.2.0 github.com/karrick/godirwalk v1.16.1 // indirect + github.com/lucas-clemente/quic-go v0.19.3 github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 github.com/miekg/dns v1.1.35 diff --git a/go.sum b/go.sum index 1c488f95010..b298768bfed 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= -github.com/AdguardTeam/dnsproxy v0.33.7 h1:DXsLTJoBSUejB2ZqVHyMG0/kXD8PzuVPbLCsGKBdaDc= -github.com/AdguardTeam/dnsproxy v0.33.7/go.mod h1:dkI9VWh43XlOzF2XogDm1EmoVl7PANOR4isQV6X9LZs= +github.com/AdguardTeam/dnsproxy v0.33.9 h1:HUwywkhUV/M73E7qWcBAF+SdsNq742s82Lvox4pr/tM= +github.com/AdguardTeam/dnsproxy v0.33.9/go.mod h1:dkI9VWh43XlOzF2XogDm1EmoVl7PANOR4isQV6X9LZs= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 4c94fc9929d..f8e7bff06e6 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" + "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" ) @@ -273,11 +274,6 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID stri return clientID, nil } -// tlsConn is a narrow interface for *tls.Conn to simplify testing. -type tlsConn interface { - ConnectionState() (cs tls.ConnectionState) -} - // processClientIDHTTPS extracts the client's ID from the path of the // client's DNS-over-HTTPS request. func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { @@ -326,6 +322,16 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } +// tlsConn is a narrow interface for *tls.Conn to simplify testing. +type tlsConn interface { + ConnectionState() (cs tls.ConnectionState) +} + +// quicSession is a narrow interface for quic.Session to simplify testing. +type quicSession interface { + ConnectionState() (cs quic.ConnectionState) +} + // processClientID extracts the client's ID from the server name of the client's // DOT or DOQ request or the path of the client's DOH. func processClientID(ctx *dnsContext) (rc resultCode) { @@ -342,15 +348,28 @@ func processClientID(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - conn := pctx.Conn - tc, ok := conn.(tlsConn) - if !ok { - ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) + cliSrvName := "" + if proto == proxy.ProtoTLS { + conn := pctx.Conn + tc, ok := conn.(tlsConn) + if !ok { + ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) - return resultCodeError + return resultCodeError + } + + cliSrvName = tc.ConnectionState().ServerName + } else if proto == proxy.ProtoQUIC { + qs, ok := pctx.QUICSession.(quicSession) + if !ok { + ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) + + return resultCodeError + } + + cliSrvName = qs.ConnectionState().ServerName } - cliSrvName := tc.ConnectionState().ServerName clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName) if err != nil { ctx.err = fmt.Errorf("client id check: %w", err) diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 0ae18d54f20..7904acf815f 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/lucas-clemente/quic-go" "github.com/stretchr/testify/assert" ) @@ -27,6 +28,22 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) { return cs } +// testQUICSession is a quicSession for tests. +type testQUICSession struct { + // Session is embedded here simply to make testQUICSession + // a quic.Session without acctually implementing the methods. + quic.Session + + serverName string +} + +// ConnectionState implements the quicSession interface for testQUICSession. +func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) { + cs.ServerName = c.serverName + + return cs +} + func TestProcessClientID(t *testing.T) { testCases := []struct { name string @@ -84,17 +101,32 @@ func TestProcessClientID(t *testing.T) { wantClientID: "", wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`, wantRes: resultCodeError, + }, { + name: "quic_client_id", + proto: proxy.ProtoQUIC, + hostSrvName: "example.com", + cliSrvName: "cli.example.com", + wantClientID: "cli", + wantErrMsg: "", + wantRes: resultCodeSuccess, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - conn := testTLSConn{ - serverName: tc.cliSrvName, + var conn net.Conn + if tc.proto == proxy.ProtoTLS { + conn = testTLSConn{serverName: tc.cliSrvName} + } + + var qs quic.Session + if tc.proto == proxy.ProtoQUIC { + qs = testQUICSession{serverName: tc.cliSrvName} } - conn.ConnectionState() + pctx := &proxy.DNSContext{ - Proto: tc.proto, - Conn: conn, + Proto: tc.proto, + Conn: conn, + QUICSession: qs, } tlsConf := TLSConfig{ServerName: tc.hostSrvName}