Skip to content

Commit

Permalink
feat: new dialWrap with optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed May 18, 2024
1 parent f44b8ee commit 6a73dd4
Show file tree
Hide file tree
Showing 4 changed files with 790 additions and 245 deletions.
301 changes: 57 additions & 244 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,19 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strings"
"time"

"github.com/Mzack9999/gcache"
gounit "github.com/docker/go-units"
"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
"github.com/projectdiscovery/fastdialer/fastdialer/metafiles"
"github.com/projectdiscovery/fastdialer/fastdialer/utils"
"github.com/projectdiscovery/hmap/store/hybrid"
"github.com/projectdiscovery/networkpolicy"
retryabledns "github.com/projectdiscovery/retryabledns"
cryptoutil "github.com/projectdiscovery/utils/crypto"
"github.com/projectdiscovery/utils/env"
errorutil "github.com/projectdiscovery/utils/errors"
iputil "github.com/projectdiscovery/utils/ip"
ptrutil "github.com/projectdiscovery/utils/ptr"
utls "github.com/refraction-networking/utls"
"github.com/zmap/zcrypto/encoding/asn1"
ztls "github.com/zmap/zcrypto/tls"
"golang.org/x/net/proxy"
Expand All @@ -36,6 +29,7 @@ var (
disableZTLSFallback = false
MaxDNSCacheSize int64
MaxDNSItems = 1024
MaxDialCacheSize = 10000
)

func init() {
Expand Down Expand Up @@ -65,6 +59,7 @@ type Dialer struct {
dialer *net.Dialer
proxyDialer *proxy.Dialer
networkpolicy *networkpolicy.NetworkPolicy
dialCache gcache.Cache[string, *utils.DialWrap]
}

// NewDialer instance
Expand Down Expand Up @@ -162,7 +157,6 @@ func NewDialer(options Options) (*Dialer, error) {
if err != nil {
return nil, err
}

return &Dialer{
dnsclient: dnsclient,
mDnsCache: dnsCache,
Expand All @@ -174,13 +168,22 @@ func NewDialer(options Options) (*Dialer, error) {
proxyDialer: options.ProxyDialer,
options: &options,
networkpolicy: np,
dialCache: gcache.New[string, *utils.DialWrap](MaxDialCacheSize).Build(),
}, nil
}

// Dial function compatible with net/http
func (d *Dialer) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
conn, err = d.dial(ctx, network, address, false, false, nil, nil, impersonate.None, nil)
return
return d.dial(ctx, &dialOptions{
network: network,
address: address,
shouldUseTLS: false,
shouldUseZTLS: false,
tlsconfig: nil,
ztlsconfig: nil,
impersonateStrategy: impersonate.None,
impersonateIdentity: nil,
})
}

// DialTLS with encrypted connection
Expand All @@ -199,14 +202,30 @@ func (d *Dialer) DialZTLS(ctx context.Context, network, address string) (conn ne

// DialTLS with encrypted connection
func (d *Dialer) DialTLSWithConfig(ctx context.Context, network, address string, config *tls.Config) (conn net.Conn, err error) {
conn, err = d.dial(ctx, network, address, true, false, config, nil, impersonate.None, nil)
return
return d.dial(ctx, &dialOptions{
network: network,
address: address,
shouldUseTLS: true,
shouldUseZTLS: false,
tlsconfig: config,
ztlsconfig: nil,
impersonateStrategy: impersonate.None,
impersonateIdentity: nil,
})
}

// DialTLSWithConfigImpersonate dials tls with impersonation
func (d *Dialer) DialTLSWithConfigImpersonate(ctx context.Context, network, address string, config *tls.Config, impersonate impersonate.Strategy, identity *impersonate.Identity) (conn net.Conn, err error) {
conn, err = d.dial(ctx, network, address, true, false, config, nil, impersonate, identity)
return
return d.dial(ctx, &dialOptions{
network: network,
address: address,
shouldUseTLS: true,
shouldUseZTLS: false,
tlsconfig: config,
ztlsconfig: nil,
impersonateStrategy: impersonate,
impersonateIdentity: identity,
})
}

// DialZTLSWithConfig dials ztls with config
Expand All @@ -217,233 +236,27 @@ func (d *Dialer) DialZTLSWithConfig(ctx context.Context, network, address string
if err != nil {
return nil, err
}
return d.dial(ctx, network, address, true, false, stdTLSConfig, nil, impersonate.None, nil)
}
return d.dial(ctx, network, address, false, true, nil, config, impersonate.None, nil)
}

func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS, shouldUseZTLS bool, tlsconfig *tls.Config, ztlsconfig *ztls.Config, impersonateStrategy impersonate.Strategy, impersonateIdentity *impersonate.Identity) (conn net.Conn, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
var hostname, port, fixedIP string

if strings.HasPrefix(address, "[") {
closeBracketIndex := strings.Index(address, "]")
if closeBracketIndex == -1 {
return nil, MalformedIP6Error
}
hostname = address[:closeBracketIndex+1]
if len(address) < closeBracketIndex+2 {
return nil, NoPortSpecifiedError
}
port = address[closeBracketIndex+2:]
} else {
addressParts := strings.SplitN(address, ":", 3)
numberOfParts := len(addressParts)

if numberOfParts >= 2 {
// ip|host:port
hostname = addressParts[0]
port = addressParts[1]
// ip|host:port:ip => curl --resolve ip:port:ip
if numberOfParts > 2 {
fixedIP = addressParts[2]
}
// check if the ip is within the context
if ctxIP := ctx.Value(IP); ctxIP != nil {
fixedIP = fmt.Sprint(ctxIP)
}
} else {
// no port => error
return nil, NoPortSpecifiedError
}
}

// check if data is in cache
hostname = asAscii(hostname)
data, err := d.GetDNSData(hostname)
if err != nil {
// otherwise attempt to retrieve it
data, err = d.dnsclient.Resolve(hostname)
}
if data == nil {
return nil, ResolveHostError
}

if err != nil || len(data.A)+len(data.AAAA) == 0 {
return nil, NoAddressFoundError
}

var numInvalidIPS int
var IPS []string
// use fixed ip as first
if fixedIP != "" {
IPS = append(IPS, fixedIP)
} else {
IPS = append(IPS, append(data.A, data.AAAA...)...)
}

// Dial to the IPs finally.
for _, ip := range IPS {
// check if we have allow/deny list
if !d.networkpolicy.Validate(ip) {
if d.options.OnInvalidTarget != nil {
d.options.OnInvalidTarget(hostname, ip, port)
}
numInvalidIPS++
continue
}
if d.options.OnBeforeDial != nil {
d.options.OnBeforeDial(hostname, ip, port)
}
hostPort := net.JoinHostPort(ip, port)
if shouldUseTLS {
tlsconfigCopy := tlsconfig.Clone()
switch {
case d.options.SNIName != "":
tlsconfigCopy.ServerName = d.options.SNIName
case ctx.Value(SniName) != nil:
sniName := ctx.Value(SniName).(string)
tlsconfigCopy.ServerName = sniName
case !iputil.IsIP(hostname):
tlsconfigCopy.ServerName = hostname
}
if impersonateStrategy == impersonate.None {
conn, err = tls.DialWithDialer(d.dialer, network, hostPort, tlsconfigCopy)
} else {
nativeConn, err := d.dialer.DialContext(ctx, network, hostPort)
if err != nil {
return nativeConn, err
}
// clone existing tls config
uTLSConfig := &utls.Config{
InsecureSkipVerify: tlsconfigCopy.InsecureSkipVerify,
ServerName: tlsconfigCopy.ServerName,
MinVersion: tlsconfigCopy.MinVersion,
MaxVersion: tlsconfigCopy.MaxVersion,
CipherSuites: tlsconfigCopy.CipherSuites,
}
var uTLSConn *utls.UConn
if impersonateStrategy == impersonate.Random {
uTLSConn = utls.UClient(nativeConn, uTLSConfig, utls.HelloRandomized)
} else if impersonateStrategy == impersonate.Custom {
uTLSConn = utls.UClient(nativeConn, uTLSConfig, utls.HelloCustom)
clientHelloSpec := utls.ClientHelloSpec(ptrutil.Safe(impersonateIdentity))
if err := uTLSConn.ApplyPreset(&clientHelloSpec); err != nil {
return nil, err
}
}
if err := uTLSConn.Handshake(); err != nil {
return nil, err
}
conn = uTLSConn
}
} else if shouldUseZTLS {
ztlsconfigCopy := ztlsconfig.Clone()
switch {
case d.options.SNIName != "":
ztlsconfigCopy.ServerName = d.options.SNIName
case ctx.Value(SniName) != nil:
sniName := ctx.Value(SniName).(string)
ztlsconfigCopy.ServerName = sniName
case !iputil.IsIP(hostname):
ztlsconfigCopy.ServerName = hostname
}
conn, err = ztls.DialWithDialer(d.dialer, network, hostPort, ztlsconfigCopy)
} else {
if d.proxyDialer != nil {
dialer := *d.proxyDialer
// timeout not working for socks5 proxy dialer
// tying to handle it here
connectionCh := make(chan net.Conn, 1)
errCh := make(chan error, 1)
go func() {
conn, err = dialer.Dial(network, hostPort)
if err != nil {
errCh <- err
return
}
connectionCh <- conn
}()
// using timer as time.After is not recovered gy GC
dialerTime := time.NewTimer(d.options.DialerTimeout)
defer dialerTime.Stop()
select {
case <-dialerTime.C:
return nil, fmt.Errorf("timeout after %v", d.options.DialerTimeout)
case conn = <-connectionCh:
case err = <-errCh:
}
} else {
conn, err = d.dialer.DialContext(ctx, network, hostPort)
}
}
// fallback to ztls in case of handshake error with chrome ciphers
// ztls fallback can either be disabled by setting env variable DISABLE_ZTLS_FALLBACK=true or by setting DisableZtlsFallback=true in options
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) && !(d.options.DisableZtlsFallback && disableZTLSFallback) {
var ztlsconfigCopy *ztls.Config
if shouldUseZTLS {
ztlsconfigCopy = ztlsconfig.Clone()
} else {
if tlsconfig == nil {
tlsconfig = &tls.Config{
Renegotiation: tls.RenegotiateOnceAsClient,
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: true,
}
}
ztlsconfigCopy, err = AsZTLSConfig(tlsconfig)
if err != nil {
return nil, errorutil.NewWithErr(err).Msgf("could not convert tls config to ztls config")
}
}
ztlsconfigCopy.CipherSuites = ztls.ChromeCiphers
conn, err = ztls.DialWithDialer(d.dialer, network, hostPort, ztlsconfigCopy)
err = errorutil.WrapfWithNil(err, "ztls fallback failed")
}
if err == nil {
if d.options.WithDialerHistory && d.dialerHistory != nil {
setErr := d.dialerHistory.Set(hostname, []byte(ip))
if setErr != nil {
return nil, setErr
}
}
if d.options.OnDialCallback != nil {
d.options.OnDialCallback(hostname, ip)
}
if d.options.WithTLSData && shouldUseTLS {
if connTLS, ok := conn.(*tls.Conn); ok {
var data bytes.Buffer
connState := connTLS.ConnectionState()
err := json.NewEncoder(&data).Encode(cryptoutil.TLSGrab(&connState))
if err != nil {
return nil, err
}
setErr := d.dialerTLSData.Set(hostname, data.Bytes())
if setErr != nil {
return nil, setErr
}
}
}
break
}
}

if conn == nil {
if numInvalidIPS == len(IPS) {
return nil, NoAddressAllowedError
}
return nil, CouldNotConnectError
}

if err != nil {
return nil, err
}

return
return d.dial(ctx, &dialOptions{
network: network,
address: address,
shouldUseTLS: true,
shouldUseZTLS: false,
tlsconfig: stdTLSConfig,
ztlsconfig: nil,
impersonateStrategy: impersonate.None,
impersonateIdentity: nil,
})
}
return d.dial(ctx, &dialOptions{
network: network,
address: address,
shouldUseTLS: false,
shouldUseZTLS: true,
tlsconfig: nil,
ztlsconfig: config,
impersonateStrategy: impersonate.None,
impersonateIdentity: nil,
})
}

// Close instance and cleanups
Expand Down Expand Up @@ -569,8 +382,8 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) {
}

if d.hmDnsCache != nil {
b, _ := data.Marshal()
if err != nil {
b, errX := data.Marshal()
if errX != nil {
return nil, err
}
err := d.hmDnsCache.Set(hostname, b)
Expand Down
Loading

0 comments on commit 6a73dd4

Please sign in to comment.