Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix memory leaks #241

Merged
merged 2 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 57 additions & 44 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
"time"

"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
"github.com/projectdiscovery/fastdialer/fastdialer/metafiles"
"github.com/projectdiscovery/hmap/store/hybrid"
"github.com/projectdiscovery/networkpolicy"
retryabledns "github.com/projectdiscovery/retryabledns"
cryptoutil "github.com/projectdiscovery/utils/crypto"
"github.com/projectdiscovery/utils/env"
errorutil "github.com/projectdiscovery/utils/errors"
iputil "github.com/projectdiscovery/utils/ip"
ptrutil "github.com/projectdiscovery/utils/ptr"
Expand All @@ -28,22 +30,24 @@ import (

// option to disable ztls fallback in case of handshake error
// reads from env variable DISABLE_ZTLS_FALLBACK
var disableZTLSFallback = false
var (
disableZTLSFallback = false
MaxDNSCacheSize = 10 * 1024 * 1024 // 10 MB
)

func init() {
// enable permissive parsing for ztls, so that it can allow permissive parsing for X509 certificates
asn1.AllowPermissiveParsing = true
value := os.Getenv("DISABLE_ZTLS_FALLBACK")
if strings.EqualFold(value, "true") {
disableZTLSFallback = true
}
disableZTLSFallback = env.GetEnvOrDefault("DISABLE_ZTLS_FALLBACK", false)
MaxDNSCacheSize = env.GetEnvOrDefault("MAX_DNS_CACHE_SIZE", 10*1024*1024)
}

// Dialer structure containing data information
type Dialer struct {
options *Options
dnsclient *retryabledns.Client
hm *hybrid.HybridMap
dnsCache *hybrid.HybridMap
hostsFileData *hybrid.HybridMap
dialerHistory *hybrid.HybridMap
dialerTLSData *hybrid.HybridMap
dialer *net.Dialer
Expand All @@ -62,12 +66,8 @@ func NewDialer(options Options) (*Dialer, error) {
}
}

cacheOptions := getHMapConfiguration(options)
resolvers = append(resolvers, options.BaseResolvers...)
hm, err := hybrid.New(cacheOptions)
if err != nil {
return nil, err
}
var err error
var dialerHistory *hybrid.HybridMap
if options.WithDialerHistory {
// we need to use disk to store all the dialed ips
Expand All @@ -78,6 +78,22 @@ func NewDialer(options Options) (*Dialer, error) {
return nil, err
}
}
// when loading in memory set max size to 10 MB
var dnsCache *hybrid.HybridMap
if options.CacheType == Memory {
opts := hybrid.DefaultMemoryOptions
opts.MaxMemorySize = MaxDNSCacheSize
dnsCache, err = hybrid.New(opts)
if err != nil {
return nil, err
}
} else {
dnsCache, err = hybrid.New(hybrid.DefaultHybridOptions)
if err != nil {
return nil, err
}
}

var dialerTLSData *hybrid.HybridMap
if options.WithTLSData {
dialerTLSData, err = hybrid.New(hybrid.DefaultDiskOptions)
Expand All @@ -97,10 +113,14 @@ func NewDialer(options Options) (*Dialer, error) {
}
}

var hostsFileData *hybrid.HybridMap
// load hardcoded values from host file
if options.HostsFile {
// nolint:errcheck // if they cannot be loaded it's not a hard failure
loadHostsFile(hm)
if options.CacheType == Memory {
hostsFileData, _ = metafiles.GetHostsFileDnsData(metafiles.InMemory)
} else {
hostsFileData, _ = metafiles.GetHostsFileDnsData(metafiles.Hybrid)
}
}
dnsclient, err := retryabledns.New(resolvers, options.MaxRetries)
if err != nil {
Expand Down Expand Up @@ -128,7 +148,17 @@ func NewDialer(options Options) (*Dialer, error) {
return nil, err
}

return &Dialer{dnsclient: dnsclient, hm: hm, dialerHistory: dialerHistory, dialerTLSData: dialerTLSData, dialer: dialer, proxyDialer: options.ProxyDialer, options: &options, networkpolicy: np}, nil
return &Dialer{
dnsclient: dnsclient,
dnsCache: dnsCache,
hostsFileData: hostsFileData,
dialerHistory: dialerHistory,
dialerTLSData: dialerTLSData,
dialer: dialer,
proxyDialer: options.ProxyDialer,
options: &options,
networkpolicy: np,
}, nil
}

// Dial function compatible with net/http
Expand Down Expand Up @@ -398,15 +428,16 @@ func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS

// Close instance and cleanups
func (d *Dialer) Close() {
if d.hm != nil {
d.hm.Close()
if d.dnsCache != nil {
d.dnsCache.Close()
}
if d.options.WithDialerHistory && d.dialerHistory != nil {
d.dialerHistory.Close()
}
if d.options.WithTLSData {
d.dialerTLSData.Close()
}
// do not close the hosts file data, as it is meant to be shared
}

// GetDialedIP returns the ip dialed by the HTTP client
Expand Down Expand Up @@ -447,11 +478,17 @@ func (d *Dialer) GetTLSData(hostname string) (*cryptoutil.TLSData, error) {
func (d *Dialer) GetDNSDataFromCache(hostname string) (*retryabledns.DNSData, error) {
hostname = asAscii(hostname)
var data retryabledns.DNSData
dataBytes, ok := d.hm.Get(hostname)
var dataBytes []byte
var ok bool
if d.hostsFileData != nil {
dataBytes, ok = d.hostsFileData.Get(hostname)
}
if !ok {
return nil, NoDNSDataError
dataBytes, ok = d.dnsCache.Get(hostname)
if !ok {
return nil, NoDNSDataError
}
}

err := data.Unmarshal(dataBytes)
return &data, err
}
Expand Down Expand Up @@ -498,7 +535,7 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) {
}
if len(data.A)+len(data.AAAA) > 0 {
b, _ := data.Marshal()
err = d.hm.Set(hostname, b)
err = d.dnsCache.Set(hostname, b)
}
if err != nil {
return nil, err
Expand All @@ -508,30 +545,6 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) {
return data, nil
}

// getHMapConfiguration translates the dialer options into the hybrid map
// options used for the cache.
//
// The base options are picked from options.CacheType. When
// options.WithCleanup is set, cleanup is enabled and the memory limit and
// database type are applied on top, whatever the cache type.
func getHMapConfiguration(options Options) hybrid.Options {
	var cfg hybrid.Options
	memLimit := options.CacheMemoryMaxItems

	switch options.CacheType {
	case Memory:
		cfg = hybrid.DefaultMemoryOptions
		if memLimit > 0 {
			cfg.MaxMemorySize = memLimit
		}
	case Disk:
		cfg = hybrid.DefaultDiskOptions
		cfg.DBType = getHMAPDBType(options)
	case Hybrid:
		cfg = hybrid.DefaultHybridOptions
	}

	// without cleanup there is nothing further to override
	if !options.WithCleanup {
		return cfg
	}

	cfg.Cleanup = true
	if memLimit > 0 {
		cfg.MaxMemorySize = memLimit
	}
	cfg.DBType = getHMAPDBType(options)
	return cfg
}

func getHMAPDBType(options Options) hybrid.DBType {
switch options.DiskDbType {
case Pogreb:
Expand Down
3 changes: 3 additions & 0 deletions fastdialer/metafiles/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Package metafiles handles metadata files related to networking,
// such as /etc/hosts.
package metafiles
22 changes: 16 additions & 6 deletions fastdialer/hostsfile.go → fastdialer/metafiles/hostsfile.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
package fastdialer
package metafiles

import (
"bufio"
"net"
"os"
"path/filepath"
"runtime/debug"
"strings"

"github.com/dimchansky/utfbom"
"github.com/projectdiscovery/hmap/store/hybrid"
"github.com/projectdiscovery/retryabledns"
)

func loadHostsFile(hm *hybrid.HybridMap) error {
// loadHostsFile loads entries from the hosts file into the given hybrid map; if max is -1, all entries are loaded
func loadHostsFile(hm *hybrid.HybridMap, max int) error {
osHostsFilePath := os.ExpandEnv(filepath.FromSlash(HostsFilePath))

if env, isset := os.LookupEnv("HOSTS_PATH"); isset && len(env) > 0 {
Expand All @@ -28,6 +30,9 @@ func loadHostsFile(hm *hybrid.HybridMap) error {
dnsDatas := make(map[string]retryabledns.DNSData)
scanner := bufio.NewScanner(utfbom.SkipOnly(file))
for scanner.Scan() {
if max > 0 && len(dnsDatas) == MaxHostsEntires {
break
}
ip, hosts := HandleHostLine(scanner.Text())
if ip == "" || len(hosts) == 0 {
continue
Expand All @@ -53,10 +58,15 @@ func loadHostsFile(hm *hybrid.HybridMap) error {
dnsdataBytes, _ := dnsdata.Marshal()
_ = hm.Set(host, dnsdataBytes)
}
if len(dnsDatas) > 10000 && max < 0 {
// this frees up memory when loading large hosts files
// useful when loading all entries to hybrid storage
debug.FreeOSMemory()
}
return nil
}

const commentChar string = "#"
const CommentChar string = "#"

// HandleHostLine parses a hosts file line
func HandleHostLine(raw string) (ip string, hosts []string) {
Expand All @@ -67,7 +77,7 @@ func HandleHostLine(raw string) (ip string, hosts []string) {

// trim comment
if HasComment(raw) {
commentSplit := strings.Split(raw, commentChar)
commentSplit := strings.Split(raw, CommentChar)
raw = commentSplit[0]
}

Expand All @@ -88,10 +98,10 @@ func HandleHostLine(raw string) (ip string, hosts []string) {

// IsComment checks if the line is a comment
func IsComment(raw string) bool {
return strings.HasPrefix(strings.TrimSpace(raw), commentChar)
return strings.HasPrefix(strings.TrimSpace(raw), CommentChar)
}

// HasComment checks if the line has a comment
func HasComment(raw string) bool {
return strings.Contains(raw, commentChar)
return strings.Contains(raw, CommentChar)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//go:build !windows
// +build !windows

package fastdialer
package metafiles

// HostsFilePath is the hosts file path on Unix-like operating systems
const HostsFilePath = "/etc/hosts"
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//go:build windows
// +build windows

package fastdialer
package metafiles

const HostsFilePath = "${SystemRoot}/System32/drivers/etc/hosts"
89 changes: 89 additions & 0 deletions fastdialer/metafiles/shared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package metafiles

import (
	"fmt"
	"runtime"
	"sync"

	"github.com/projectdiscovery/hmap/store/hybrid"
	"github.com/projectdiscovery/utils/env"
)

// StorageType selects the storage backend used to hold the hosts file data.
type StorageType int

const (
	// InMemory requests the in-memory backend.
	InMemory StorageType = iota
	// Hybrid requests the hybrid backend.
	Hybrid
)

var (
	// MaxHostsEntires is the number of hosts file entries loaded when the
	// in-memory backend is used. It can be overridden through the
	// HF_MAX_HOSTS environment variable.
	MaxHostsEntires = 4096
	// LoadAllEntries, when true, loads every hosts file entry into the hybrid
	// storage backend and uses it even if the in-memory backend was requested.
	// It can be overridden through the HF_LOAD_ALL environment variable.
	LoadAllEntries = false
)

// init replaces the defaults declared above with values taken from the
// environment, when present.
func init() {
	// the declared values double as the fallback defaults
	LoadAllEntries = env.GetEnvOrDefault("HF_LOAD_ALL", LoadAllEntries)
	MaxHostsEntires = env.GetEnvOrDefault("HF_MAX_HOSTS", MaxHostsEntires)
}

// GetHostsFileDnsData returns the immutable DNS data read from the hosts file.
// The data is constant throughout the program lifecycle and shouldn't be
// purged by caches etc.
//
// storage selects the backend; it is overridden with Hybrid when
// LoadAllEntries is set. An unsupported storage value yields a nil map and a
// non-nil error rather than a silent (nil, nil).
func GetHostsFileDnsData(storage StorageType) (*hybrid.HybridMap, error) {
	if LoadAllEntries {
		storage = Hybrid
	}
	switch storage {
	case InMemory:
		return getHFInMemory()
	case Hybrid:
		return getHFHybridStorage()
	default:
		return nil, fmt.Errorf("metafiles: unsupported storage type %d", storage)
	}
}

var (
	hostsMemOnce = &sync.Once{}
	// hostsMemMap and hostsMemErr keep the outcome of the one-time load so
	// that every caller, not only the first one, receives it.
	hostsMemMap *hybrid.HybridMap
	hostsMemErr error
)

// getHFInMemory returns the shared in-memory hosts file map, loading the
// hosts file (limited by MaxHostsEntires) on the first call only.
//
// All calls return the same map and the same error. If creating the map or
// loading the hosts file fails, the map is nil and the error is returned;
// a closed map is never handed out.
func getHFInMemory() (*hybrid.HybridMap, error) {
	hostsMemOnce.Do(func() {
		hm, err := hybrid.New(hybrid.DefaultMemoryOptions)
		if err != nil {
			hostsMemErr = err
			return
		}
		if err := loadHostsFile(hm, MaxHostsEntires); err != nil {
			// release the map and report the failure instead of exposing a closed map
			_ = hm.Close()
			hostsMemErr = err
			return
		}
		hostsMemMap = hm
	})
	return hostsMemMap, hostsMemErr
}

var (
	hostsHybridOnce = &sync.Once{}
	// hostsHybridMap and hostsHybridErr keep the outcome of the one-time load
	// so that every caller, not only the first one, receives it.
	hostsHybridMap *hybrid.HybridMap
	hostsHybridErr error
)

// getHFHybridStorage returns the shared hybrid-storage hosts file map,
// loading all hosts file entries on the first call only.
//
// All calls return the same map and the same error. If creating the map or
// loading the hosts file fails, the map is nil and the error is returned;
// a closed map is never handed out.
func getHFHybridStorage() (*hybrid.HybridMap, error) {
	hostsHybridOnce.Do(func() {
		opts := hybrid.DefaultHybridOptions
		opts.Cleanup = true
		hm, err := hybrid.New(opts)
		if err != nil {
			hostsHybridErr = err
			return
		}
		// -1 loads every entry of the hosts file
		if err := loadHostsFile(hm, -1); err != nil {
			_ = hm.Close()
			hostsHybridErr = err
			return
		}
		// set finalizer for cleanup
		// NOTE(review): the map is now retained by a package-level variable,
		// so this finalizer is not expected to fire while the process runs;
		// consider an explicit shutdown hook if the backing data must be removed.
		runtime.SetFinalizer(hm, func(hm *hybrid.HybridMap) {
			_ = hm.Close()
		})
		hostsHybridMap = hm
	})
	return hostsHybridMap, hostsHybridErr
}
Loading
Loading