diff --git a/fastdialer/dialer.go b/fastdialer/dialer.go index 5daef0a..9393ec9 100644 --- a/fastdialer/dialer.go +++ b/fastdialer/dialer.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "github.com/Mzack9999/gcache" + gounit "github.com/docker/go-units" "github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate" "github.com/projectdiscovery/fastdialer/fastdialer/metafiles" "github.com/projectdiscovery/hmap/store/hybrid" @@ -32,21 +34,31 @@ import ( // reads from env variable DISABLE_ZTLS_FALLBACK var ( disableZTLSFallback = false - MaxDNSCacheSize = 10 * 1024 * 1024 // 10 MB + MaxDNSCacheSize int64 + MaxDNSItems = 1024 ) func init() { // enable permissive parsing for ztls, so that it can allow permissive parsing for X509 certificates asn1.AllowPermissiveParsing = true disableZTLSFallback = env.GetEnvOrDefault("DISABLE_ZTLS_FALLBACK", false) - MaxDNSCacheSize = env.GetEnvOrDefault("MAX_DNS_CACHE_SIZE", 10*1024*1024) + maxCacheSize := env.GetEnvOrDefault("MAX_DNS_CACHE_SIZE", "10mb") + maxDnsCacheSize, err := gounit.FromHumanSize(maxCacheSize) + if err != nil { + panic(err) + } + MaxDNSCacheSize = maxDnsCacheSize + MaxDNSItems = env.GetEnvOrDefault("MAX_DNS_ITEMS", 1024) } // Dialer structure containing data information type Dialer struct { - options *Options - dnsclient *retryabledns.Client - dnsCache *hybrid.HybridMap + options *Options + dnsclient *retryabledns.Client + // memory typed cache + mDnsCache gcache.Cache[string, *retryabledns.DNSData] + // memory/disk untyped ([]byte) cache + hmDnsCache *hybrid.HybridMap hostsFileData *hybrid.HybridMap dialerHistory *hybrid.HybridMap dialerTLSData *hybrid.HybridMap @@ -79,16 +91,15 @@ func NewDialer(options Options) (*Dialer, error) { } } // when loading in memory set max size to 10 MB - var dnsCache *hybrid.HybridMap + var ( + hmDnsCache *hybrid.HybridMap + dnsCache gcache.Cache[string, *retryabledns.DNSData] + ) + options.CacheType = Memory if options.CacheType == Memory { - opts := hybrid.DefaultMemoryOptions - opts.MaxMemorySize = MaxDNSCacheSize - dnsCache, err = hybrid.New(opts) - if err != nil { - return nil, err - } + dnsCache = gcache.New[string, *retryabledns.DNSData](MaxDNSItems).Build() } else { - dnsCache, err = hybrid.New(hybrid.DefaultHybridOptions) + hmDnsCache, err = hybrid.New(hybrid.DefaultHybridOptions) if err != nil { return nil, err } @@ -150,7 +161,8 @@ func NewDialer(options Options) (*Dialer, error) { return &Dialer{ dnsclient: dnsclient, - dnsCache: dnsCache, + mDnsCache: dnsCache, + hmDnsCache: hmDnsCache, hostsFileData: hostsFileData, dialerHistory: dialerHistory, dialerTLSData: dialerTLSData, @@ -247,7 +259,6 @@ func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS if err != nil { // otherwise attempt to retrieve it data, err = d.dnsclient.Resolve(hostname) - } if data == nil { return nil, ResolveHostError @@ -428,8 +439,11 @@ func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS // Close instance and cleanups func (d *Dialer) Close() { - if d.dnsCache != nil { - d.dnsCache.Close() + if d.mDnsCache != nil { + d.mDnsCache.Purge() + } + if d.hmDnsCache != nil { + d.hmDnsCache.Close() } if d.options.WithDialerHistory && d.dialerHistory != nil { d.dialerHistory.Close() @@ -484,7 +498,11 @@ func (d *Dialer) GetDNSDataFromCache(hostname string) (*retryabledns.DNSData, er dataBytes, ok = d.hostsFileData.Get(hostname) } if !ok { - dataBytes, ok = d.dnsCache.Get(hostname) + if d.mDnsCache != nil { + return d.mDnsCache.GetIFPresent(hostname) + } + + dataBytes, ok = d.hmDnsCache.Get(hostname) if !ok { return nil, NoDNSDataError } @@ -534,11 +552,24 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) { return nil, ResolveHostError } if len(data.A)+len(data.AAAA) > 0 { - b, _ := data.Marshal() - err = d.dnsCache.Set(hostname, b) - } - if err != nil { - return nil, err + if d.mDnsCache != nil { + err := d.mDnsCache.Set(hostname, data) + if err != nil { + return nil, err + } + } + + if d.hmDnsCache != nil { + b, _ := data.Marshal() + if err != nil { + return nil, err + } + err := d.hmDnsCache.Set(hostname, b) + if err != nil { + return nil, err + } + } + } return data, nil } diff --git a/go.mod b/go.mod index e95d6a9..9be1761 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/projectdiscovery/fastdialer go 1.21 require ( + github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057 github.com/dimchansky/utfbom v1.1.1 + github.com/docker/go-units v0.5.0 github.com/pkg/errors v0.9.1 github.com/projectdiscovery/hmap v0.0.34 github.com/projectdiscovery/networkpolicy v0.0.7 diff --git a/go.sum b/go.sum index 9dd2bda..2bf5567 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057 h1:KFac3SiGbId8ub47e7kd2PLZeACxc1LkiiNoDOFRClE= +github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057/go.mod h1:iLB2pivrPICvLOuROKmlqURtFIEsoJZaMidQfCG1+D4= github.com/akrylysov/pogreb v0.10.1 h1:FqlR8VR7uCbJdfUob916tPM+idpKgeESDXOA1K0DK4w= github.com/akrylysov/pogreb v0.10.1/go.mod h1:pNs6QmpQ1UlTJKDezuRWmaqkgUE2TuU0YTWyqJZ7+lI= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= @@ -18,6 +20,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U= github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=