Skip to content

Commit

Permalink
fix: dns dial to wrong target
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Jun 15, 2024
1 parent ad5bc51 commit 40f40f6
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 167 deletions.
3 changes: 1 addition & 2 deletions adapter/outbound/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package outbound
import (
"context"
"errors"
"net/netip"
"os"
"strconv"

Expand Down Expand Up @@ -58,7 +57,7 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
}
metadata.DstIP = ip
}
pc, err := dialer.NewDialer(d.Base.DialOptions(opts...)...).ListenPacket(ctx, "udp", "", netip.AddrPortFrom(metadata.DstIP, metadata.DstPort))
pc, err := dialer.NewDialer(d.Base.DialOptions(opts...)...).ListenPacket(ctx, "udp", "", metadata.AddrPort())
if err != nil {
return nil, err
}
Expand Down
47 changes: 7 additions & 40 deletions dns/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,20 @@ import (
"crypto/tls"
"fmt"
"net"
"net/netip"
"strings"

"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"

"github.com/metacubex/randv2"
D "github.com/miekg/dns"
)

type client struct {
*D.Client
r *Resolver
port string
host string
iface string
proxyAdapter C.ProxyAdapter
proxyName string
addr string
port string
host string
dialer *dnsDialer
addr string
}

var _ dnsClient = (*client)(nil)
Expand All @@ -49,38 +41,13 @@ func (c *client) Address() string {
}

func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) {
var (
ip netip.Addr
err error
)
if c.r == nil {
// a default ip dns
if ip, err = netip.ParseAddr(c.host); err != nil {
return nil, fmt.Errorf("dns %s not a valid ip", c.host)
}
} else {
ips, err := resolver.LookupIPWithResolver(ctx, c.host, c.r)
if err != nil {
return nil, fmt.Errorf("use default dns resolve failed: %w", err)
} else if len(ips) == 0 {
return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, c.host)
}
ip = ips[randv2.IntN(len(ips))]
}

network := "udp"
if strings.HasPrefix(c.Client.Net, "tcp") {
network = "tcp"
}

var options []dialer.Option
if c.iface != "" {
options = append(options, dialer.WithInterface(c.iface))
}

dialHandler := getDialHandler(c.r, c.proxyAdapter, c.proxyName, options...)
addr := net.JoinHostPort(ip.String(), c.port)
conn, err := dialHandler(ctx, network, addr)
addr := net.JoinHostPort(c.host, c.port)
conn, err := c.dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -115,7 +82,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
tcpClient.Net = "tcp"
network = "tcp"
log.Debugln("[DNS] Truncated reply from %s:%s for %s over UDP, retrying over TCP", c.host, c.port, m.Question[0].String())
dConn.Conn, err = dialHandler(ctx, network, addr)
dConn.Conn, err = c.dialer.DialContext(ctx, network, addr)
if err != nil {
ch <- result{msg, err}
return
Expand Down
5 changes: 2 additions & 3 deletions dns/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import "github.com/metacubex/mihomo/tunnel"

const RespectRules = tunnel.DnsRespectRules

type dialHandler = tunnel.DnsDialHandler
type dnsDialer = tunnel.DNSDialer

var getDialHandler = tunnel.GetDnsDialHandler
var listenPacket = tunnel.DnsListenPacket
var newDNSDialer = tunnel.NewDNSDialer
35 changes: 14 additions & 21 deletions dns/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ type dnsOverHTTPS struct {
quicConfig *quic.Config
quicConfigGuard sync.Mutex
url *url.URL
r *Resolver
httpVersions []C.HTTPVersion
proxyAdapter C.ProxyAdapter
proxyName string
dialer *dnsDialer
addr string
}

Expand All @@ -85,11 +83,9 @@ func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[strin
}

doh := &dnsOverHTTPS{
url: u,
addr: u.String(),
r: r,
proxyAdapter: proxyAdapter,
proxyName: proxyName,
url: u,
addr: u.String(),
dialer: newDNSDialer(r, proxyAdapter, proxyName),
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Expand Down Expand Up @@ -388,13 +384,12 @@ func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripp
nextProtos = append(nextProtos, string(v))
}
tlsConfig.NextProtos = nextProtos
dialContext := getDialHandler(doh.r, doh.proxyAdapter, doh.proxyName)

if slices.Contains(doh.httpVersions, C.HTTPVersion3) {
// First, we attempt to create an HTTP3 transport. If the probe QUIC
// connection is established successfully, we'll be using HTTP3 for this
// upstream.
transportH3, err := doh.createTransportH3(ctx, tlsConfig, dialContext)
transportH3, err := doh.createTransportH3(ctx, tlsConfig)
if err == nil {
log.Debugln("[%s] using HTTP/3 for this upstream: QUIC was faster", doh.url.String())
return transportH3, nil
Expand All @@ -410,7 +405,7 @@ func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripp
transport := &http.Transport{
TLSClientConfig: tlsConfig,
DisableCompression: true,
DialContext: dialContext,
DialContext: doh.dialer.DialContext,
IdleConnTimeout: transportDefaultIdleConnTimeout,
MaxConnsPerHost: dohMaxConnsPerHost,
MaxIdleConns: dohMaxIdleConns,
Expand Down Expand Up @@ -490,13 +485,12 @@ func (h *http3Transport) Close() (err error) {
func (doh *dnsOverHTTPS) createTransportH3(
ctx context.Context,
tlsConfig *tls.Config,
dialContext dialHandler,
) (roundTripper http.RoundTripper, err error) {
if !doh.supportsH3() {
return nil, errors.New("HTTP3 support is not enabled")
}

addr, err := doh.probeH3(ctx, tlsConfig, dialContext)
addr, err := doh.probeH3(ctx, tlsConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -534,7 +528,7 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.
IP: net.ParseIP(ip),
Port: portInt,
}
conn, err := listenPacket(ctx, doh.proxyAdapter, doh.proxyName, "udp", addr, doh.r)
conn, err := doh.dialer.ListenPacket(ctx, "udp", addr)
if err != nil {
return nil, err
}
Expand All @@ -557,12 +551,11 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.
func (doh *dnsOverHTTPS) probeH3(
ctx context.Context,
tlsConfig *tls.Config,
dialContext dialHandler,
) (addr string, err error) {
// We're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there are v4/v6 addresses).
rawConn, err := dialContext(ctx, "udp", doh.url.Host)
rawConn, err := doh.dialer.DialContext(ctx, "udp", doh.url.Host)
if err != nil {
return "", fmt.Errorf("failed to dial: %w", err)
}
Expand Down Expand Up @@ -592,7 +585,7 @@ func (doh *dnsOverHTTPS) probeH3(
chQuic := make(chan error, 1)
chTLS := make(chan error, 1)
go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic)
go doh.probeTLS(ctx, dialContext, probeTLSCfg, chTLS)
go doh.probeTLS(ctx, probeTLSCfg, chTLS)

select {
case quicErr := <-chQuic:
Expand Down Expand Up @@ -635,10 +628,10 @@ func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *

// probeTLS attempts to establish a TLS connection to the specified address. We
// run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()

conn, err := doh.tlsDial(ctx, dialContext, "tcp", tlsConfig)
conn, err := doh.tlsDial(ctx, "tcp", tlsConfig)
if err != nil {
ch <- fmt.Errorf("opening TLS connection: %w", err)
return
Expand Down Expand Up @@ -694,10 +687,10 @@ func isHTTP3(client *http.Client) (ok bool) {

// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection.
func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, network string, config *tls.Config) (*tls.Conn, error) {
// We're using bootstrapped address instead of what's passed
// to the function.
rawConn, err := dialContext(ctx, network, doh.url.Host)
rawConn, err := doh.dialer.DialContext(ctx, network, doh.url.Host)
if err != nil {
return nil, err
}
Expand Down
16 changes: 6 additions & 10 deletions dns/doq.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,8 @@ type dnsOverQUIC struct {
bytesPool *sync.Pool
bytesPoolGuard sync.Mutex

addr string
proxyAdapter C.ProxyAdapter
proxyName string
r *Resolver
addr string
dialer *dnsDialer
}

// type check
Expand All @@ -72,10 +70,8 @@ var _ dnsClient = (*dnsOverQUIC)(nil)
// newDoQ returns the DNS-over-QUIC Upstream.
func newDoQ(resolver *Resolver, addr string, proxyAdapter C.ProxyAdapter, proxyName string) (dnsClient, error) {
doq := &dnsOverQUIC{
addr: addr,
proxyAdapter: proxyAdapter,
proxyName: proxyName,
r: resolver,
addr: addr,
dialer: newDNSDialer(resolver, proxyAdapter, proxyName),
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Expand Down Expand Up @@ -300,7 +296,7 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio
// we're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there're v4/v6 addresses).
rawConn, err := getDialHandler(doq.r, doq.proxyAdapter, doq.proxyName)(ctx, "udp", doq.addr)
rawConn, err := doq.dialer.DialContext(ctx, "udp", doq.addr)
if err != nil {
return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
}
Expand All @@ -315,7 +311,7 @@ func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connectio

p, err := strconv.Atoi(port)
udpAddr := net.UDPAddr{IP: net.ParseIP(ip), Port: p}
udp, err := listenPacket(ctx, doq.proxyAdapter, doq.proxyName, "udp", addr, doq.r)
udp, err := doq.dialer.ListenPacket(ctx, "udp", addr)
if err != nil {
return nil, err
}
Expand Down
15 changes: 9 additions & 6 deletions dns/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/metacubex/mihomo/common/nnip"
"github.com/metacubex/mihomo/common/picker"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
"github.com/metacubex/mihomo/log"

Expand Down Expand Up @@ -115,6 +116,11 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient {
continue
}

var options []dialer.Option
if s.Interface != "" {
options = append(options, dialer.WithInterface(s.Interface))
}

host, port, _ := net.SplitHostPort(s.Addr)
ret = append(ret, &client{
Client: &D.Client{
Expand All @@ -125,12 +131,9 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient {
UDPSize: 4096,
Timeout: 5 * time.Second,
},
port: port,
host: host,
iface: s.Interface,
r: resolver,
proxyAdapter: s.ProxyAdapter,
proxyName: s.ProxyName,
port: port,
host: host,
dialer: newDNSDialer(resolver, s.ProxyAdapter, s.ProxyName, options...),
})
}
return ret
Expand Down
Loading

0 comments on commit 40f40f6

Please sign in to comment.