Skip to content

Commit

Permalink
feat: use errkit & error related refactor:
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed May 22, 2024
1 parent 138ca78 commit fdf8f3d
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 63 deletions.
6 changes: 3 additions & 3 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
retryabledns "github.com/projectdiscovery/retryabledns"
cryptoutil "github.com/projectdiscovery/utils/crypto"
"github.com/projectdiscovery/utils/env"
"github.com/projectdiscovery/utils/errkit"
"github.com/zmap/zcrypto/encoding/asn1"
ztls "github.com/zmap/zcrypto/tls"
"golang.org/x/net/proxy"
Expand Down Expand Up @@ -196,8 +197,7 @@ func (d *Dialer) DialTLS(ctx context.Context, network, address string) (conn net

// DialZTLS with encrypted connection using ztls
func (d *Dialer) DialZTLS(ctx context.Context, network, address string) (conn net.Conn, err error) {
conn, err = d.DialZTLSWithConfig(ctx, network, address, &ztls.Config{InsecureSkipVerify: true})
return
return d.DialZTLSWithConfig(ctx, network, address, &ztls.Config{InsecureSkipVerify: true})
}

// DialTLS with encrypted connection
Expand Down Expand Up @@ -234,7 +234,7 @@ func (d *Dialer) DialZTLSWithConfig(ctx context.Context, network, address string
if IsTLS13(config) {
stdTLSConfig, err := AsTLSConfig(config)
if err != nil {
return nil, err
return nil, errkit.Wrap(err, "could not convert ztls config to tls config")
}
return d.dial(ctx, &dialOptions{
network: network,
Expand Down
65 changes: 46 additions & 19 deletions fastdialer/dialer_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"os"
"strings"
Expand All @@ -16,7 +16,7 @@ import (
"github.com/projectdiscovery/fastdialer/fastdialer/utils"
ctxutil "github.com/projectdiscovery/utils/context"
cryptoutil "github.com/projectdiscovery/utils/crypto"
errorutil "github.com/projectdiscovery/utils/errors"
"github.com/projectdiscovery/utils/errkit"
iputil "github.com/projectdiscovery/utils/ip"
ptrutil "github.com/projectdiscovery/utils/ptr"
utls "github.com/refraction-networking/utls"
Expand Down Expand Up @@ -48,7 +48,7 @@ type dialOptions struct {

func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, err error) {
// add global timeout to context
ctx, cancel := context.WithTimeoutCause(ctx, d.options.DialerTimeout, fmt.Errorf("fastdialer dial timeout"))
ctx, cancel := context.WithTimeoutCause(ctx, d.options.DialerTimeout, ErrDialTimeout)
defer cancel()

var hostname, port, fixedIP string
Expand Down Expand Up @@ -99,12 +99,10 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
if data == nil {
return nil, ResolveHostError
}

if err != nil || len(data.A)+len(data.AAAA) == 0 {
return nil, NoAddressFoundError
}

var numInvalidIPS int
// use fixed ip as first
if fixedIP != "" {
IPS = append(IPS, fixedIP)
Expand All @@ -121,7 +119,6 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
if d.options.OnInvalidTarget != nil {
d.options.OnInvalidTarget(hostname, ip, port)
}
numInvalidIPS++
continue
}
if d.options.OnBeforeDial != nil {
Expand Down Expand Up @@ -153,10 +150,10 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er

dw, err = utils.NewDialWrap(d.dialer, IPS, opts.network, opts.address, opts.port)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("could not create dialwrap"))
return nil, errkit.Wrap(err, "could not create dialwrap")
}
if err = d.dialCache.Set(connHash(opts.network, opts.address), dw); err != nil {
return nil, errors.Join(err, fmt.Errorf("could not set dialwrap"))
return nil, errkit.Wrap(err, "could not set dialwrap")
}
}
if dw != nil {
Expand All @@ -178,7 +175,7 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
return
}
if conn.RemoteAddr() == nil {
return nil, errors.New("remote address is nil")
return nil, errkit.New("remote address is nil")
}
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
if d.options.WithDialerHistory && d.dialerHistory != nil {
Expand All @@ -205,16 +202,22 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
}
}
}
// if conn == nil {
// return nil, CouldNotConnectError
// }
return
}

func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (conn net.Conn, err error) {
hostPort := net.JoinHostPort(opts.ips[0], opts.port)

// logAddress is the address that will be logged in case of error
logAddress := opts.hostname
if logAddress == "" {
logAddress = opts.ips[0]
}
logAddress += ":" + opts.port

if opts.shouldUseTLS {
tlsconfigCopy := opts.tlsconfig.Clone()

switch {
case d.options.SNIName != "":
tlsconfigCopy.ServerName = d.options.SNIName
Expand All @@ -224,20 +227,21 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
case !iputil.IsIP(opts.hostname):
tlsconfigCopy.ServerName = opts.hostname
}

if opts.impersonateStrategy == impersonate.None {
l4Conn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, err
return nil, handleDialError(err, logAddress)
}
TlsConn := tls.Client(l4Conn, tlsconfigCopy)
if err := TlsConn.HandshakeContext(ctx); err != nil {
return nil, errors.Join(err, fmt.Errorf("could not handshake"))
return nil, errkit.Wrap(err, "could not tls handshake")
}
conn = TlsConn
} else {
nativeConn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, err
return nil, handleDialError(err, logAddress)
}
// clone existing tls config
uTLSConfig := &utls.Config{
Expand Down Expand Up @@ -275,7 +279,7 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
}
l4Conn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, err
return nil, handleDialError(err, logAddress)
}
ztlsConn := ztls.Client(l4Conn, ztlsconfigCopy)
_, err = ctxutil.ExecFuncWithTwoReturns(ctx, func() (bool, error) {
Expand Down Expand Up @@ -310,13 +314,15 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
case conn = <-connectionCh:
case err = <-errCh:
}
err = handleDialError(err, logAddress)
} else {
conn, err = l4.DialContext(ctx, opts.network, hostPort)
err = handleDialError(err, logAddress)
}
}
// fallback to ztls in case of handshake error with chrome ciphers
// ztls fallback can either be disabled by setting env variable DISABLE_ZTLS_FALLBACK=true or by setting DisableZtlsFallback=true in options
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) && !(d.options.DisableZtlsFallback && disableZTLSFallback) {
if err != nil && !errkit.Is(err, os.ErrDeadlineExceeded) && !(d.options.DisableZtlsFallback && disableZTLSFallback) {
var ztlsconfigCopy *ztls.Config
if opts.shouldUseZTLS {
ztlsconfigCopy = opts.ztlsconfig.Clone()
Expand All @@ -330,13 +336,13 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
}
ztlsconfigCopy, err = AsZTLSConfig(opts.tlsconfig)
if err != nil {
return nil, errorutil.NewWithErr(err).Msgf("could not convert tls config to ztls config")
return nil, errkit.Wrap(err, "could not convert tls config to ztls config")
}
}
ztlsconfigCopy.CipherSuites = ztls.ChromeCiphers
l4Conn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, err
return nil, handleDialError(err, logAddress)
}
ztlsConn := ztls.Client(l4Conn, ztlsconfigCopy)
_, err = ctxutil.ExecFuncWithTwoReturns(ctx, func() (bool, error) {
Expand All @@ -355,3 +361,24 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
func connHash(network string, address string) string {
return fmt.Sprintf("%s-%s", network, address)
}

// handleDialError is a helper function to handle dial errors
// it also adds address attribute to the error
func handleDialError(err error, address string) error {
if err == nil {
return nil
}
errx := errkit.FromError(err)
errx = errx.SetAttr(slog.Any("address", address))
// if error kind is not set, if it is i/o timeout, set it to temporary
if errx.Kind() == nil {
if errx.Cause() != nil && strings.Contains(errx.Cause().Error(), "i/o timeout") {
// TODO: this is a tough call, i/o timeout happens in both cases
// it could be either temporary or permanent internally i/o timeout
// is actually a context.DeadlineExceeded error but std lib has decided to keep legacy/original error
errx = errx.SetKind(errkit.ErrKindNetworkTemporary)
}
}
// TODO: parse and mark permanent or temporary errors
return errx
}
25 changes: 12 additions & 13 deletions fastdialer/error.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package fastdialer

import (
"github.com/pkg/errors"
)
import "github.com/projectdiscovery/utils/errkit"

var (
CouldNotConnectError = errors.New("could not connect to any address found for host")
NoAddressFoundError = errors.New("no address found for host")
NoAddressAllowedError = errors.New("denied address found for host")
NoPortSpecifiedError = errors.New("port was not specified")
MalformedIP6Error = errors.New("malformed IPv6 address")
ResolveHostError = errors.New("could not resolve host")
NoTLSHistoryError = errors.New("no tls data history available")
NoTLSDataError = errors.New("no tls data found for the key")
NoDNSDataError = errors.New("no data found")
AsciiConversionError = errors.New("could not convert hostname to ASCII")
CouldNotConnectError = errkit.New("could not connect to any address found for host").SetKind(errkit.ErrKindNetworkPermanent)
NoAddressFoundError = errkit.New("no address found for host").SetKind(errkit.ErrKindNetworkPermanent)
NoAddressAllowedError = errkit.New("denied address found for host").SetKind(errkit.ErrKindNetworkPermanent)
NoPortSpecifiedError = errkit.New("port was not specified").SetKind(errkit.ErrKindNetworkPermanent)
MalformedIP6Error = errkit.New("malformed IPv6 address").SetKind(errkit.ErrKindNetworkPermanent)
ResolveHostError = errkit.New("could not resolve host").SetKind(errkit.ErrKindNetworkPermanent)
NoTLSHistoryError = errkit.New("no tls data history available")
NoTLSDataError = errkit.New("no tls data found for the key")
NoDNSDataError = errkit.New("no data found")
AsciiConversionError = errkit.New("could not convert hostname to ASCII")
ErrDialTimeout = errkit.New("dial timeout").SetKind(errkit.ErrKindNetworkTemporary)
)
6 changes: 3 additions & 3 deletions fastdialer/ja3/ja3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package ja3

import (
"crypto/sha256"
"errors"
"fmt"
"strconv"
"strings"

"github.com/projectdiscovery/utils/errkit"
utls "github.com/refraction-networking/utls"
)

Expand Down Expand Up @@ -72,7 +72,7 @@ func parseVersion(version string) (uint16, error) {
func parseCipherSuites(cipherToken string) ([]uint16, error) {
cipherToken = cleanup(cipherToken)
if cipherToken == "" {
return nil, errors.New("no cipher suites provided")
return nil, errkit.New("no cipher suites provided")
}
ciphers := strings.Split(cipherToken, "-")
var cipherSuites []uint16
Expand All @@ -90,7 +90,7 @@ func parseExtensions(extensionToken string) ([]utls.TLSExtension, error) {
var extensions []utls.TLSExtension
extensionToken = cleanup(extensionToken)
if extensionToken == "" {
return nil, errors.New("no extensions provided")
return nil, errkit.New("no extensions provided")
}
exts := strings.Split(extensionToken, "-")
for _, ext := range exts {
Expand Down
29 changes: 17 additions & 12 deletions fastdialer/utils/dialwrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"sync/atomic"
"time"

"github.com/pkg/errors"
"go.uber.org/multierr"

"github.com/projectdiscovery/utils/errkit"
iputil "github.com/projectdiscovery/utils/ip"
)

Expand Down Expand Up @@ -37,9 +35,10 @@ import (
// Error constants
var (
// errGotConnection has already been established
ErrInflightCancel = errors.New("context cancelled before establishing connection")
ErrNoIPs = errors.New("no ips provided in dialWrap")
ExpireConnAfter = time.Duration(5) * time.Second
ErrInflightCancel = errkit.New("context cancelled before establishing connection")
ErrNoIPs = errkit.New("no ips provided in dialWrap")
ExpireConnAfter = time.Duration(5) * time.Second
ErrPortClosedOrFiltered = errkit.New("port closed or filtered").SetKind(errkit.ErrKindNetworkPermanent)
)

// dialResult represents the result of a dial operation
Expand Down Expand Up @@ -105,14 +104,14 @@ func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Con
if d.completedFirstFlight.Load() {
// if first flight completed and it failed due to other reasons
// and not due to context cancellation
if d.err != nil && !errors.Is(d.err, ErrInflightCancel) && !errors.Is(d.err, context.Canceled) {
if d.err != nil && !errkit.Is(d.err, ErrInflightCancel) && !errkit.Is(d.err, context.Canceled) {
return nil, d.err
}
return d.dial(ctx)
}
select {
case <-ctx.Done():
return nil, multierr.Append(ErrInflightCancel, ctx.Err())
return nil, errkit.Append(ErrInflightCancel, ctx.Err())
case res, ok := <-d.firstFlight(ctx):
if !ok {
// closed channel so depending on the error
Expand Down Expand Up @@ -193,7 +192,7 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) {
defer wg.Done()
select {
case <-ctx.Done():
rec <- &dialResult{error: multierr.Append(ErrInflightCancel, ctx.Err())}
rec <- &dialResult{error: errkit.Append(ErrInflightCancel, ctx.Err())}
default:
c, err := d.dialer.DialContext(ctx, d.network, net.JoinHostPort(ipx.String(), d.port))
rec <- &dialResult{Conn: c, error: err, expiry: time.Now().Add(ExpireConnAfter)}
Expand All @@ -209,7 +208,7 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) {
if result.Conn != nil {
conns = append(conns, result)
} else {
if !errors.Is(result.error, ErrInflightCancel) {
if !errkit.Is(result.error, ErrInflightCancel) {
errs = append(errs, result)
}
}
Expand All @@ -227,7 +226,13 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) {
// and blacklist those ips permanently
var finalErr error
for _, v := range errs {
finalErr = multierr.Append(finalErr, v.error)
finalErr = errkit.Append(finalErr, v.error)
}
// if this is the case then most likely the port is closed or filtered
// so return appropriate error
if !errkit.Is(finalErr, ErrInflightCancel) {
// if it not inflight cancel then it is a permanent error
return nil, errkit.Append(ErrPortClosedOrFiltered, finalErr)
}
return nil, finalErr
}
Expand Down Expand Up @@ -393,7 +398,7 @@ func (d *DialWrap) dialSerial(ctx context.Context, ras []net.IP, network, port s
}

if firstErr == nil {
firstErr = errors.Wrap(net.UnknownNetworkError(network), "dialSerial")
firstErr = errkit.Wrap(net.UnknownNetworkError(network), "dialSerial")
}
return nil, firstErr
}
Expand Down
10 changes: 5 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@ require (
github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057
github.com/dimchansky/utfbom v1.1.1
github.com/docker/go-units v0.5.0
github.com/pkg/errors v0.9.1
github.com/projectdiscovery/hmap v0.0.42
github.com/projectdiscovery/hmap v0.0.43
github.com/projectdiscovery/networkpolicy v0.0.8
github.com/projectdiscovery/retryabledns v1.0.59
github.com/projectdiscovery/utils v0.0.93-0.20240519190012-c4bf7513228c
github.com/projectdiscovery/retryabledns v1.0.60
github.com/projectdiscovery/utils v0.0.95-0.20240522204248-10ef59b98abe
github.com/refraction-networking/utls v1.5.4
github.com/stretchr/testify v1.9.0
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9
github.com/zmap/zcrypto v0.0.0-20230422215203-9a665e1e9968
go.uber.org/multierr v1.11.0
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/net v0.23.0
)
Expand All @@ -36,6 +34,7 @@ require (
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/microcosm-cc/bluemonday v1.0.25 // indirect
github.com/miekg/dns v1.1.56 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/projectdiscovery/blackrock v0.0.1 // indirect
github.com/quic-go/quic-go v0.42.0 // indirect
Expand All @@ -53,6 +52,7 @@ require (
github.com/yl2chen/cidranger v1.0.2 // indirect
github.com/zmap/rc2 v0.0.0-20190804163417-abaa70531248 // indirect
go.etcd.io/bbolt v1.3.7 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/sys v0.18.0 // indirect
Expand Down
Loading

0 comments on commit fdf8f3d

Please sign in to comment.