From 40f40f6d2423c093add49820894dd6d22346a9f4 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sat, 15 Jun 2024 00:33:03 +0800 Subject: [PATCH] fix: dns dial to wrong target --- adapter/outbound/direct.go | 3 +- dns/client.go | 47 ++-------- dns/dialer.go | 5 +- dns/doh.go | 35 +++----- dns/doq.go | 16 ++-- dns/util.go | 15 ++-- tunnel/dns_dialer.go | 180 +++++++++++++++++++------------------ 7 files changed, 134 insertions(+), 167 deletions(-) diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 09b9696b2f..c904d3b62a 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -3,7 +3,6 @@ package outbound import ( "context" "errors" - "net/netip" "os" "strconv" @@ -58,7 +57,7 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, } metadata.DstIP = ip } - pc, err := dialer.NewDialer(d.Base.DialOptions(opts...)...).ListenPacket(ctx, "udp", "", netip.AddrPortFrom(metadata.DstIP, metadata.DstPort)) + pc, err := dialer.NewDialer(d.Base.DialOptions(opts...)...).ListenPacket(ctx, "udp", "", metadata.AddrPort()) if err != nil { return nil, err } diff --git a/dns/client.go b/dns/client.go index a6f0a7d492..096b96a7f5 100644 --- a/dns/client.go +++ b/dns/client.go @@ -5,28 +5,20 @@ import ( "crypto/tls" "fmt" "net" - "net/netip" "strings" "github.com/metacubex/mihomo/component/ca" - "github.com/metacubex/mihomo/component/dialer" - "github.com/metacubex/mihomo/component/resolver" - C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" - "github.com/metacubex/randv2" D "github.com/miekg/dns" ) type client struct { *D.Client - r *Resolver - port string - host string - iface string - proxyAdapter C.ProxyAdapter - proxyName string - addr string + port string + host string + dialer *dnsDialer + addr string } var _ dnsClient = (*client)(nil) @@ -49,38 +41,13 @@ func (c *client) Address() string { } func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) { - var ( - ip netip.Addr - err error - ) - if c.r == nil { - // a default ip dns - if ip, err = netip.ParseAddr(c.host); err != nil { - return nil, fmt.Errorf("dns %s not a valid ip", c.host) - } - } else { - ips, err := resolver.LookupIPWithResolver(ctx, c.host, c.r) - if err != nil { - return nil, fmt.Errorf("use default dns resolve failed: %w", err) - } else if len(ips) == 0 { - return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, c.host) - } - ip = ips[randv2.IntN(len(ips))] - } - network := "udp" if strings.HasPrefix(c.Client.Net, "tcp") { network = "tcp" } - var options []dialer.Option - if c.iface != "" { - options = append(options, dialer.WithInterface(c.iface)) - } - - dialHandler := getDialHandler(c.r, c.proxyAdapter, c.proxyName, options...) - addr := net.JoinHostPort(ip.String(), c.port) - conn, err := dialHandler(ctx, network, addr) + addr := net.JoinHostPort(c.host, c.port) + conn, err := c.dialer.DialContext(ctx, network, addr) if err != nil { return nil, err } @@ -115,7 +82,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) tcpClient.Net = "tcp" network = "tcp" log.Debugln("[DNS] Truncated reply from %s:%s for %s over UDP, retrying over TCP", c.host, c.port, m.Question[0].String()) - dConn.Conn, err = dialHandler(ctx, network, addr) + dConn.Conn, err = c.dialer.DialContext(ctx, network, addr) if err != nil { ch <- result{msg, err} return diff --git a/dns/dialer.go b/dns/dialer.go index 309b6c1886..f4d9e128f8 100644 --- a/dns/dialer.go +++ b/dns/dialer.go @@ -6,7 +6,6 @@ import "github.com/metacubex/mihomo/tunnel" const RespectRules = tunnel.DnsRespectRules -type dialHandler = tunnel.DnsDialHandler +type dnsDialer = tunnel.DNSDialer -var getDialHandler = tunnel.GetDnsDialHandler -var listenPacket = tunnel.DnsListenPacket +var newDNSDialer = tunnel.NewDNSDialer diff --git a/dns/doh.go b/dns/doh.go index 09d311b585..54b8279657 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -62,10 +62,8 @@ type dnsOverHTTPS struct { quicConfig *quic.Config quicConfigGuard sync.Mutex url *url.URL - r *Resolver httpVersions []C.HTTPVersion - proxyAdapter C.ProxyAdapter - proxyName string + dialer *dnsDialer addr string } @@ -85,11 +83,9 @@ func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[strin } doh := &dnsOverHTTPS{ - url: u, - addr: u.String(), - r: r, - proxyAdapter: proxyAdapter, - proxyName: proxyName, + url: u, + addr: u.String(), + dialer: newDNSDialer(r, proxyAdapter, proxyName), quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, TokenStore: newQUICTokenStore(), @@ -388,13 +384,12 @@ func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripp nextProtos = append(nextProtos, string(v)) } tlsConfig.NextProtos = nextProtos - dialContext := getDialHandler(doh.r, doh.proxyAdapter, doh.proxyName) if slices.Contains(doh.httpVersions, C.HTTPVersion3) { // First, we attempt to create an HTTP3 transport. If the probe QUIC // connection is established successfully, we'll be using HTTP3 for this // upstream. - transportH3, err := doh.createTransportH3(ctx, tlsConfig, dialContext) + transportH3, err := doh.createTransportH3(ctx, tlsConfig) if err == nil { log.Debugln("[%s] using HTTP/3 for this upstream: QUIC was faster", doh.url.String()) return transportH3, nil @@ -410,7 +405,7 @@ func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripp transport := &http.Transport{ TLSClientConfig: tlsConfig, DisableCompression: true, - DialContext: dialContext, + DialContext: doh.dialer.DialContext, IdleConnTimeout: transportDefaultIdleConnTimeout, MaxConnsPerHost: dohMaxConnsPerHost, MaxIdleConns: dohMaxIdleConns, @@ -490,13 +485,12 @@ func (h *http3Transport) Close() (err error) { func (doh *dnsOverHTTPS) createTransportH3( ctx context.Context, tlsConfig *tls.Config, - dialContext dialHandler, ) (roundTripper http.RoundTripper, err error) { if !doh.supportsH3() { return nil, errors.New("HTTP3 support is not enabled") } - addr, err := doh.probeH3(ctx, tlsConfig, dialContext) + addr, err := doh.probeH3(ctx, tlsConfig) if err != nil { return nil, err } @@ -534,7 +528,7 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls. IP: net.ParseIP(ip), Port: portInt, } - conn, err := listenPacket(ctx, doh.proxyAdapter, doh.proxyName, "udp", addr, doh.r) + conn, err := doh.dialer.ListenPacket(ctx, "udp", addr) if err != nil { return nil, err } @@ -557,12 +551,11 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls. func (doh *dnsOverHTTPS) probeH3( ctx context.Context, tlsConfig *tls.Config, - dialContext dialHandler, ) (addr string, err error) { // We're using bootstrapped address instead of what's passed to the function // it does not create an actual connection, but it helps us determine // what IP is actually reachable (when there are v4/v6 addresses). - rawConn, err := dialContext(ctx, "udp", doh.url.Host) + rawConn, err := doh.dialer.DialContext(ctx, "udp", doh.url.Host) if err != nil { return "", fmt.Errorf("failed to dial: %w", err) } @@ -592,7 +585,7 @@ func (doh *dnsOverHTTPS) probeH3( chQuic := make(chan error, 1) chTLS := make(chan error, 1) go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic) - go doh.probeTLS(ctx, dialContext, probeTLSCfg, chTLS) + go doh.probeTLS(ctx, probeTLSCfg, chTLS) select { case quicErr := <-chQuic: @@ -635,10 +628,10 @@ func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig * // probeTLS attempts to establish a TLS connection to the specified address. We // run probeQUIC and probeTLS in parallel and see which one is faster. -func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { +func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, tlsConfig *tls.Config, ch chan error) { startTime := time.Now() - conn, err := doh.tlsDial(ctx, dialContext, "tcp", tlsConfig) + conn, err := doh.tlsDial(ctx, "tcp", tlsConfig) if err != nil { ch <- fmt.Errorf("opening TLS connection: %w", err) return @@ -694,10 +687,10 @@ func isHTTP3(client *http.Client) (ok bool) { // tlsDial is basically the same as tls.DialWithDialer, but we will call our own // dialContext function to get connection. -func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) { +func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, network string, config *tls.Config) (*tls.Conn, error) { // We're using bootstrapped address instead of what's passed // to the function. - rawConn, err := dialContext(ctx, network, doh.url.Host) + rawConn, err := doh.dialer.DialContext(ctx, network, doh.url.Host) if err != nil { return nil, err } diff --git a/dns/doq.go b/dns/doq.go index 70b67c2a61..ad936f9575 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -60,10 +60,8 @@ type dnsOverQUIC struct { bytesPool *sync.Pool bytesPoolGuard sync.Mutex - addr string - proxyAdapter C.ProxyAdapter - proxyName string - r *Resolver + addr string + dialer *dnsDialer } // type check @@ -72,10 +70,8 @@ var _ dnsClient = (*dnsOverQUIC)(nil) // newDoQ returns the DNS-over-QUIC Upstream. func newDoQ(resolver *Resolver, addr string, proxyAdapter C.ProxyAdapter, proxyName string) (dnsClient, error) { doq := &dnsOverQUIC{ - addr: addr, - proxyAdapter: proxyAdapter, - proxyName: proxyName, - r: resolver, + addr: addr, + dialer: newDNSDialer(resolver, proxyAdapter, proxyName), quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, TokenStore: newQUICTokenStore(), @@ -300,7 +296,7 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio // we're using bootstrapped address instead of what's passed to the function // it does not create an actual connection, but it helps us determine // what IP is actually reachable (when there're v4/v6 addresses). - rawConn, err := getDialHandler(doq.r, doq.proxyAdapter, doq.proxyName)(ctx, "udp", doq.addr) + rawConn, err := doq.dialer.DialContext(ctx, "udp", doq.addr) if err != nil { return nil, fmt.Errorf("failed to open a QUIC connection: %w", err) } @@ -315,7 +311,7 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio p, err := strconv.Atoi(port) udpAddr := net.UDPAddr{IP: net.ParseIP(ip), Port: p} - udp, err := listenPacket(ctx, doq.proxyAdapter, doq.proxyName, "udp", addr, doq.r) + udp, err := doq.dialer.ListenPacket(ctx, "udp", addr) if err != nil { return nil, err } diff --git a/dns/util.go b/dns/util.go index 9c5a0b585a..e4ec5917cf 100644 --- a/dns/util.go +++ b/dns/util.go @@ -12,6 +12,7 @@ import ( "github.com/metacubex/mihomo/common/nnip" "github.com/metacubex/mihomo/common/picker" + "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/log" @@ -115,6 +116,11 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { continue } + var options []dialer.Option + if s.Interface != "" { + options = append(options, dialer.WithInterface(s.Interface)) + } + host, port, _ := net.SplitHostPort(s.Addr) ret = append(ret, &client{ Client: &D.Client{ @@ -125,12 +131,9 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { UDPSize: 4096, Timeout: 5 * time.Second, }, - port: port, - host: host, - iface: s.Interface, - r: resolver, - proxyAdapter: s.ProxyAdapter, - proxyName: s.ProxyName, + port: port, + host: host, + dialer: newDNSDialer(resolver, s.ProxyAdapter, s.ProxyName, options...), }) } return ret diff --git a/tunnel/dns_dialer.go b/tunnel/dns_dialer.go index f4c0be7139..1839869b4a 100644 --- a/tunnel/dns_dialer.go +++ b/tunnel/dns_dialer.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "net" - "net/netip" "strings" N "github.com/metacubex/mihomo/common/net" @@ -18,110 +17,121 @@ import ( const DnsRespectRules = "RULES" -type DnsDialHandler func(ctx context.Context, network, addr string) (net.Conn, error) +type DNSDialer struct { + r resolver.Resolver + proxyAdapter C.ProxyAdapter + proxyName string + opts []dialer.Option +} -func GetDnsDialHandler(r resolver.Resolver, proxyAdapter C.ProxyAdapter, proxyName string, opts ...dialer.Option) DnsDialHandler { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - if len(proxyName) == 0 && proxyAdapter == nil { - opts = append(opts, dialer.WithResolver(r)) - return dialer.DialContext(ctx, network, addr, opts...) - } else { - metadata := &C.Metadata{ - NetWork: C.TCP, - Type: C.INNER, - } - err := metadata.SetRemoteAddress(addr) // tcp can resolve host by remote +func NewDNSDialer(r resolver.Resolver, proxyAdapter C.ProxyAdapter, proxyName string, opts ...dialer.Option) *DNSDialer { + return &DNSDialer{r: r, proxyAdapter: proxyAdapter, proxyName: proxyName, opts: opts} +} + +func (d *DNSDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + r := d.r + proxyName := d.proxyName + proxyAdapter := d.proxyAdapter + opts := d.opts + var rule C.Rule + metadata := &C.Metadata{ + NetWork: C.TCP, + Type: C.INNER, + } + err := metadata.SetRemoteAddress(addr) // tcp can resolve host by remote + if err != nil { + return nil, err + } + if !strings.Contains(network, "tcp") { + metadata.NetWork = C.UDP + if !metadata.Resolved() { + // udp must resolve host first + dstIP, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) if err != nil { return nil, err } - if !strings.Contains(network, "tcp") { - metadata.NetWork = C.UDP - if !metadata.Resolved() { - // udp must resolve host first - dstIP, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) - if err != nil { - return nil, err - } - metadata.DstIP = dstIP - } - } + metadata.DstIP = dstIP + } + } - var rule C.Rule - if proxyAdapter == nil { - if proxyName == DnsRespectRules { - if !metadata.Resolved() { - // resolve here before resolveMetadata to avoid its inner resolver.ResolveIP - dstIP, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) - if err != nil { - return nil, err - } - metadata.DstIP = dstIP - } - proxyAdapter, rule, err = resolveMetadata(metadata) - if err != nil { - return nil, err - } - } else { - var ok bool - proxyAdapter, ok = Proxies()[proxyName] - if !ok { - opts = append(opts, dialer.WithInterface(proxyName)) - } + if proxyAdapter == nil && len(proxyName) != 0 { + if proxyName == DnsRespectRules { + if !metadata.Resolved() { + // resolve here before resolveMetadata to avoid its inner resolver.ResolveIP + dstIP, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) + if err != nil { + return nil, err } + metadata.DstIP = dstIP } + proxyAdapter, rule, err = resolveMetadata(metadata) + if err != nil { + return nil, err + } + } else { + var ok bool + proxyAdapter, ok = Proxies()[proxyName] + if !ok { + opts = append(opts, dialer.WithInterface(proxyName)) + } + } + } - if strings.Contains(network, "tcp") { - if proxyAdapter == nil { - opts = append(opts, dialer.WithResolver(r)) - return dialer.DialContext(ctx, network, addr, opts...) - } - - if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback - if !metadata.Resolved() { - dstIP, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) - if err != nil { - return nil, err - } - metadata.DstIP = dstIP - } - metadata.Host = "" // clear host to avoid double resolve in proxy - } + if metadata.NetWork == C.TCP { + if proxyAdapter == nil { + opts = append(opts, dialer.WithResolver(r)) + return dialer.DialContext(ctx, network, addr, opts...) + } - conn, err := proxyAdapter.DialContext(ctx, metadata, opts...) + if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback + if !metadata.Resolved() { + dstIP, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) if err != nil { - logMetadataErr(metadata, rule, proxyAdapter, err) return nil, err } - logMetadata(metadata, rule, conn) - - conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, metadata, rule, 0, 0, false) + metadata.DstIP = dstIP + } + metadata.Host = "" // clear host to avoid double resolve in proxy + } - return conn, nil - } else { - if proxyAdapter == nil { - return dialer.DialContext(ctx, network, addr, opts...) - } + conn, err := proxyAdapter.DialContext(ctx, metadata, opts...) + if err != nil { + logMetadataErr(metadata, rule, proxyAdapter, err) + return nil, err + } + logMetadata(metadata, rule, conn) - if !proxyAdapter.SupportUDP() { - return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", proxyAdapter) - } + conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, metadata, rule, 0, 0, false) - packetConn, err := proxyAdapter.ListenPacketContext(ctx, metadata, opts...) - if err != nil { - logMetadataErr(metadata, rule, proxyAdapter, err) - return nil, err - } - logMetadata(metadata, rule, packetConn) + return conn, nil + } else { + if proxyAdapter == nil { + return dialer.DialContext(ctx, network, metadata.AddrPort().String(), opts...) + } - packetConn = statistic.NewUDPTracker(packetConn, statistic.DefaultManager, metadata, rule, 0, 0, false) + if !proxyAdapter.SupportUDP() { + return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", proxyAdapter) + } - return N.NewBindPacketConn(packetConn, metadata.UDPAddr()), nil - } + packetConn, err := proxyAdapter.ListenPacketContext(ctx, metadata, opts...) + if err != nil { + logMetadataErr(metadata, rule, proxyAdapter, err) + return nil, err } + logMetadata(metadata, rule, packetConn) + + packetConn = statistic.NewUDPTracker(packetConn, statistic.DefaultManager, metadata, rule, 0, 0, false) + + return N.NewBindPacketConn(packetConn, metadata.UDPAddr()), nil } + } -func DnsListenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName string, network string, addr string, r resolver.Resolver, opts ...dialer.Option) (net.PacketConn, error) { +func (d *DNSDialer) ListenPacket(ctx context.Context, network, addr string) (net.PacketConn, error) { + r := d.r + proxyAdapter := d.proxyAdapter + proxyName := d.proxyName + opts := d.opts metadata := &C.Metadata{ NetWork: C.UDP, Type: C.INNER, @@ -156,7 +166,7 @@ func DnsListenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName } if proxyAdapter == nil { - return dialer.NewDialer(opts...).ListenPacket(ctx, network, "", netip.AddrPortFrom(metadata.DstIP, metadata.DstPort)) + return dialer.NewDialer(opts...).ListenPacket(ctx, network, "", metadata.AddrPort()) } if !proxyAdapter.SupportUDP() {