Skip to content

Commit

Permalink
add support for a proxy dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 2, 2023
1 parent 1897e7c commit 0945776
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
23 changes: 14 additions & 9 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"

"github.com/miekg/dns"
"golang.org/x/net/proxy"

logging "github.com/ipfs/go-log/v2"
)
Expand All @@ -19,7 +20,7 @@ const (

var log = logging.Logger("doh")

func doRequest(ctx context.Context, url string, m *dns.Msg) (*dns.Msg, error) {
func doRequest(ctx context.Context, url string, m *dns.Msg, dialer proxy.Dialer) (*dns.Msg, error) {
data, err := m.Pack()
if err != nil {
return nil, err
Expand All @@ -34,14 +35,18 @@ func doRequest(ctx context.Context, url string, m *dns.Msg) (*dns.Msg, error) {
req.Header.Set("Accept", dohMimeType)

req = req.WithContext(ctx)

resp, err := http.DefaultClient.Do(req)
client := http.DefaultClient
if dialer != nil {
client = &http.Client{Transport: &http.Transport{Dial: dialer.Dial}}
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
fmt.Println(resp.Status, resp.StatusCode)
return nil, fmt.Errorf("HTTP error: %q [%d]", resp.Status, resp.StatusCode)
}

Expand All @@ -62,13 +67,13 @@ func doRequest(ctx context.Context, url string, m *dns.Msg) (*dns.Msg, error) {
return r, nil
}

func doRequestA(ctx context.Context, url string, domain string) ([]net.IPAddr, uint32, error) {
func doRequestA(ctx context.Context, url string, domain string, dialer proxy.Dialer) ([]net.IPAddr, uint32, error) {
fqdn := dns.Fqdn(domain)

m := new(dns.Msg)
m.SetQuestion(fqdn, dns.TypeA)

r, err := doRequest(ctx, url, m)
r, err := doRequest(ctx, url, m, dialer)
if err != nil {
return nil, 0, err
}
Expand All @@ -90,13 +95,13 @@ func doRequestA(ctx context.Context, url string, domain string) ([]net.IPAddr, u
return result, ttl, nil
}

func doRequestAAAA(ctx context.Context, url string, domain string) ([]net.IPAddr, uint32, error) {
func doRequestAAAA(ctx context.Context, url string, domain string, dialer proxy.Dialer) ([]net.IPAddr, uint32, error) {
fqdn := dns.Fqdn(domain)

m := new(dns.Msg)
m.SetQuestion(fqdn, dns.TypeAAAA)

r, err := doRequest(ctx, url, m)
r, err := doRequest(ctx, url, m, dialer)
if err != nil {
return nil, 0, err
}
Expand All @@ -119,13 +124,13 @@ func doRequestAAAA(ctx context.Context, url string, domain string) ([]net.IPAddr
return result, ttl, nil
}

func doRequestTXT(ctx context.Context, url string, domain string) ([]string, uint32, error) {
func doRequestTXT(ctx context.Context, url string, domain string, dialer proxy.Dialer) ([]string, uint32, error) {
fqdn := dns.Fqdn(domain)

m := new(dns.Msg)
m.SetQuestion(fqdn, dns.TypeTXT)

r, err := doRequest(ctx, url, m)
r, err := doRequest(ctx, url, m, dialer)
if err != nil {
return nil, 0, err
}
Expand Down
15 changes: 12 additions & 3 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/miekg/dns"
"golang.org/x/net/proxy"

madns "github.com/multiformats/go-multiaddr-dns"
)
Expand All @@ -21,6 +22,7 @@ type Resolver struct {
ipCache map[string]ipAddrEntry
txtCache map[string]txtEntry
maxCacheTTL time.Duration
dialer proxy.Dialer
}

type ipAddrEntry struct {
Expand Down Expand Up @@ -51,6 +53,13 @@ func WithCacheDisabled() Option {
}
}

func WithDialer(dialer proxy.Dialer) Option {
return func(tr *Resolver) error {
tr.dialer = dialer
return nil
}
}

func NewResolver(url string, opts ...Option) (*Resolver, error) {
if !strings.HasPrefix(url, "https:") {
url = "https://" + url
Expand Down Expand Up @@ -88,12 +97,12 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne

resch := make(chan response, 2)
go func() {
ip4, ttl, err := doRequestA(ctx, r.url, domain)
ip4, ttl, err := doRequestA(ctx, r.url, domain, r.dialer)
resch <- response{ip4, ttl, err}
}()

go func() {
ip6, ttl, err := doRequestAAAA(ctx, r.url, domain)
ip6, ttl, err := doRequestAAAA(ctx, r.url, domain, r.dialer)
resch <- response{ip6, ttl, err}
}()

Expand Down Expand Up @@ -121,7 +130,7 @@ func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, erro
return result, nil
}

result, ttl, err := doRequestTXT(ctx, r.url, domain)
result, ttl, err := doRequestTXT(ctx, r.url, domain, r.dialer)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 0945776

Please sign in to comment.