Skip to content

Commit

Permalink
feat: sniff add skip-src-address and skip-dst-address
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Aug 27, 2024
1 parent 3e2c9ce commit 8483178
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 117 deletions.
4 changes: 4 additions & 0 deletions component/cidr/ipcidr_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func (set *IpCidrSet) Merge() error {
return nil
}

func (set *IpCidrSet) IsEmpty() bool {
return set == nil || len(set.rr) == 0
}

func (set *IpCidrSet) Foreach(f func(prefix netip.Prefix) bool) {
for _, r := range set.rr {
for _, prefix := range r.Prefixes() {
Expand Down
106 changes: 59 additions & 47 deletions component/sniffer/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sniffer

import (
"errors"
"fmt"
"net"
"net/netip"
"time"
Expand All @@ -20,19 +19,29 @@ var (
ErrNoClue = errors.New("not enough information for making a decision")
)

var Dispatcher *SnifferDispatcher

type SnifferDispatcher struct {
type Dispatcher struct {
enable bool
sniffers map[sniffer.Sniffer]SnifferConfig
forceDomain []C.Rule
skipSrcAddress []C.Rule
skipDstAddress []C.Rule
skipDomain []C.Rule
skipList *lru.LruCache[string, uint8]
skipList *lru.LruCache[netip.AddrPort, uint8]
forceDnsMapping bool
parsePureIp bool
}

func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
func (sd *Dispatcher) shouldOverride(metadata *C.Metadata) bool {
for _, rule := range sd.skipDstAddress {
if ok, _ := rule.Match(&C.Metadata{DstIP: metadata.DstIP}); ok {
return false
}
}
for _, rule := range sd.skipSrcAddress {
if ok, _ := rule.Match(&C.Metadata{DstIP: metadata.SrcIP}); ok {
return false
}
}
if metadata.Host == "" && sd.parsePureIp {
return true
}
Expand All @@ -47,10 +56,9 @@ func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
return false
}

func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
func (sd *Dispatcher) UDPSniff(packet C.PacketAdapter) bool {
metadata := packet.Metadata()

if sd.shouldOverride(packet.Metadata()) {
if sd.shouldOverride(metadata) {
for sniffer, config := range sd.sniffers {
if sniffer.SupportNetwork() == C.UDP || sniffer.SupportNetwork() == C.ALLNet {
inWhitelist := sniffer.SupportPort(metadata.DstPort)
Expand All @@ -73,7 +81,7 @@ func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
}

// TCPSniff returns true if the connection is sniffed to have a domain
func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
func (sd *Dispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
if sd.shouldOverride(metadata) {
inWhitelist := false
overrideDest := false
Expand All @@ -91,34 +99,35 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false
}

dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
dst := metadata.AddrPort()
if count, ok := sd.skipList.Get(dst); ok && count > 5 {
log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
return false
}

if host, err := sd.sniffDomain(conn, metadata); err != nil {
host, err := sd.sniffDomain(conn, metadata)
if err != nil {
sd.cacheSniffFailed(metadata)
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return false
} else {
for _, rule := range sd.skipDomain {
if ok, _ := rule.Match(&C.Metadata{Host: host}); ok {
log.Debugln("[Sniffer] Skip sni[%s]", host)
return false
}
}

for _, rule := range sd.skipDomain {
if ok, _ := rule.Match(&C.Metadata{Host: host}); ok {
log.Debugln("[Sniffer] Skip sni[%s]", host)
return false
}
}

sd.skipList.Delete(dst)
sd.skipList.Delete(dst)

sd.replaceDomain(metadata, host, overrideDest)
return true
}
sd.replaceDomain(metadata, host, overrideDest)
return true
}
return false
}

func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
func (sd *Dispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
metadata.SniffHost = host
if overrideDest {
log.Debugln("[Sniffer] Sniff %s [%s]-->[%s] success, replace domain [%s]-->[%s]",
Expand All @@ -131,11 +140,11 @@ func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, ov
metadata.DNSMode = C.DNSNormal
}

func (sd *SnifferDispatcher) Enable() bool {
return sd.enable
func (sd *Dispatcher) Enable() bool {
return sd != nil && sd.enable
}

func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
func (sd *Dispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
for s := range sd.sniffers {
if s.SupportNetwork() == C.TCP {
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
Expand Down Expand Up @@ -178,8 +187,8 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
return "", ErrorSniffFailed
}

func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
func (sd *Dispatcher) cacheSniffFailed(metadata *C.Metadata) {
dst := metadata.AddrPort()
sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
if oldValue <= 5 {
oldValue++
Expand All @@ -188,32 +197,35 @@ func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
})
}

func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: false,
}

return &dispatcher, nil
type Config struct {
Enable bool
Sniffers map[sniffer.Type]SnifferConfig
ForceDomain []C.Rule
SkipSrcAddress []C.Rule
SkipDstAddress []C.Rule
SkipDomain []C.Rule
ForceDnsMapping bool
ParsePureIp bool
}

func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig,
forceDomain []C.Rule, skipDomain []C.Rule,
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
forceDomain: forceDomain,
skipDomain: skipDomain,
skipList: lru.New(lru.WithSize[string, uint8](128), lru.WithAge[string, uint8](600)),
forceDnsMapping: forceDnsMapping,
parsePureIp: parsePureIp,
sniffers: make(map[sniffer.Sniffer]SnifferConfig, 0),
func NewDispatcher(snifferConfig *Config) (*Dispatcher, error) {
dispatcher := Dispatcher{
enable: snifferConfig.Enable,
forceDomain: snifferConfig.ForceDomain,
skipSrcAddress: snifferConfig.SkipSrcAddress,
skipDstAddress: snifferConfig.SkipDstAddress,
skipDomain: snifferConfig.SkipDomain,
skipList: lru.New(lru.WithSize[netip.AddrPort, uint8](128), lru.WithAge[netip.AddrPort, uint8](600)),
forceDnsMapping: snifferConfig.ForceDnsMapping,
parsePureIp: snifferConfig.ParsePureIp,
sniffers: make(map[sniffer.Sniffer]SnifferConfig, len(snifferConfig.Sniffers)),
}

for snifferName, config := range snifferConfig {
for snifferName, config := range snifferConfig.Sniffers {
s, err := NewSniffer(snifferName, config)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return &SnifferDispatcher{enable: false}, err
return &Dispatcher{enable: false}, err
}
dispatcher.sniffers[s] = config
}
Expand Down
Loading

0 comments on commit 8483178

Please sign in to comment.