Skip to content

Commit

Permalink
Move predefined DNS server to rule action
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Feb 26, 2025
1 parent 1bb1a47 commit e2d6b78
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 181 deletions.
29 changes: 14 additions & 15 deletions common/dialer/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ type resolveDialer struct {
}

func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
return &resolveParallelNetworkDialer{
resolveDialer{
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
},
parallelDialer,
}
}
return &resolveDialer{
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
Expand All @@ -60,21 +74,6 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer
}

func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer {
return &resolveParallelNetworkDialer{
resolveDialer{
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
},
dialer,
}
}

func (d *resolveDialer) initialize() error {
d.initOnce.Do(d.initServer)
return d.initErr
Expand Down
26 changes: 13 additions & 13 deletions constant/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ const (
)

const (
DNSTypeLegacy = "legacy"
DNSTypeUDP = "udp"
DNSTypeTCP = "tcp"
DNSTypeTLS = "tls"
DNSTypeHTTPS = "https"
DNSTypeQUIC = "quic"
DNSTypeHTTP3 = "h3"
DNSTypeHosts = "hosts"
DNSTypeLocal = "local"
DNSTypePreDefined = "predefined"
DNSTypeFakeIP = "fakeip"
DNSTypeDHCP = "dhcp"
DNSTypeTailscale = "tailscale"
DNSTypeLegacy = "legacy"
DNSTypeLegacyRcode = "legacy_rcode"
DNSTypeUDP = "udp"
DNSTypeTCP = "tcp"
DNSTypeTLS = "tls"
DNSTypeHTTPS = "https"
DNSTypeQUIC = "quic"
DNSTypeHTTP3 = "h3"
DNSTypeLocal = "local"
DNSTypeHosts = "hosts"
DNSTypeFakeIP = "fakeip"
DNSTypeDHCP = "dhcp"
DNSTypeTailscale = "tailscale"
)

const (
Expand Down
1 change: 1 addition & 0 deletions constant/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
RuleActionTypeHijackDNS = "hijack-dns"
RuleActionTypeSniff = "sniff"
RuleActionTypeResolve = "resolve"
RuleActionTypePredefined = "predefined"
)

const (
Expand Down
32 changes: 32 additions & 0 deletions dns/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
}
case *R.RuleActionReject:
return nil, currentRule, currentRuleIndex
case *R.RuleActionPredefined:
return nil, currentRule, currentRuleIndex
}
}
}
Expand Down Expand Up @@ -260,6 +262,21 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
case *R.RuleActionPredefined:
return &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: message.Id,
Response: true,
Authoritative: true,
RecursionDesired: true,
RecursionAvailable: true,
Rcode: action.Rcode,
},
Question: message.Question,
Answer: action.Answer,
Ns: action.Ns,
Extra: action.Extra,
}, nil
}
}
var responseCheck func(responseAddrs []netip.Addr) bool
Expand Down Expand Up @@ -376,6 +393,20 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
case *R.RuleActionPredefined:
if action.Rcode != mDNS.RcodeSuccess {
err = RcodeError(action.Rcode)
} else {
for _, answer := range action.Answer {
switch record := answer.(type) {
case *mDNS.A:
responseAddrs = append(responseAddrs, M.AddrFromIP(record.A))
case *mDNS.AAAA:
responseAddrs = append(responseAddrs, M.AddrFromIP(record.AAAA))
}
}
}
goto response
}
}
var responseCheck func(responseAddrs []netip.Addr) bool
Expand All @@ -395,6 +426,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
printResult()
}
}
response:
printResult()
if len(responseAddrs) > 0 {
r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
Expand Down
83 changes: 0 additions & 83 deletions dns/transport/predefined.go

This file was deleted.

1 change: 0 additions & 1 deletion include/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ func DNSTransportRegistry() *dns.TransportRegistry {
transport.RegisterUDP(registry)
transport.RegisterTLS(registry)
transport.RegisterHTTPS(registry)
transport.RegisterPredefined(registry)
hosts.RegisterTransport(registry)
local.RegisterTransport(registry)
fakeip.RegisterTransport(registry)
Expand Down
51 changes: 42 additions & 9 deletions option/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,46 @@ func (o *DNSOptions) UnmarshalJSONContext(ctx context.Context, content []byte) e
}
legacyOptions := o.LegacyDNSOptions
o.LegacyDNSOptions = LegacyDNSOptions{}
return badjson.UnmarshallExcludedContext(ctx, content, legacyOptions, &o.RawDNSOptions)
err = badjson.UnmarshallExcludedContext(ctx, content, legacyOptions, &o.RawDNSOptions)
if err != nil {
return err
}
rcodeMap := make(map[string]int)
o.Servers = common.Filter(o.Servers, func(it NewDNSServerOptions) bool {
if it.Type == C.DNSTypeLegacyRcode {
rcodeMap[it.Tag] = it.Options.(int)
return false
}
return true
})
if len(rcodeMap) > 0 {
for i := 0; i < len(o.Rules); i++ {
rewriteRcode(rcodeMap, &o.Rules[i])
}
}
return nil
}

func rewriteRcode(rcodeMap map[string]int, rule *DNSRule) {
switch rule.Type {
case C.RuleTypeDefault:
rewriteRcodeAction(rcodeMap, &rule.DefaultOptions.DNSRuleAction)
case C.RuleTypeLogical:
rewriteRcodeAction(rcodeMap, &rule.LogicalOptions.DNSRuleAction)
}
}

func rewriteRcodeAction(rcodeMap map[string]int, ruleAction *DNSRuleAction) {
if ruleAction.Action != C.RuleActionTypeRoute {
return
}
rcode, loaded := rcodeMap[ruleAction.RouteOptions.Server]
if !loaded {
return
}
ruleAction.Action = C.RuleActionTypePredefined
ruleAction.PredefinedOptions.Rcode = common.Ptr(DNSRCode(rcode))
return
}

type DNSClientOptions struct {
Expand Down Expand Up @@ -243,14 +282,8 @@ func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error {
default:
return E.New("unknown rcode: ", serverURL.Host)
}
o.Type = C.DNSTypePreDefined
o.Options = &PredefinedDNSServerOptions{
Responses: []DNSResponseOptions{
{
RCode: common.Ptr(DNSRCode(rcode)),
},
},
}
o.Type = C.DNSTypeLegacyRcode
o.Options = rcode
case C.DNSTypeDHCP:
o.Type = C.DNSTypeDHCP
dhcpOptions := DHCPDNSServerOptions{}
Expand Down
61 changes: 1 addition & 60 deletions option/dns_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,14 @@ package option
import (
"encoding/base64"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"

"github.com/miekg/dns"
)

type PredefinedDNSServerOptions struct {
Responses []DNSResponseOptions `json:"responses,omitempty"`
}

type DNSResponseOptions struct {
Query badoption.Listable[string] `json:"query,omitempty"`
QueryType badoption.Listable[DNSQueryType] `json:"query_type,omitempty"`

RCode *DNSRCode `json:"rcode,omitempty"`
Answer badoption.Listable[DNSRecordOptions] `json:"answer,omitempty"`
Ns badoption.Listable[DNSRecordOptions] `json:"ns,omitempty"`
Extra badoption.Listable[DNSRecordOptions] `json:"extra,omitempty"`
}

type DNSRCode int

func (r DNSRCode) MarshalJSON() ([]byte, error) {
Expand Down Expand Up @@ -64,49 +48,6 @@ func (r *DNSRCode) Build() int {
return int(*r)
}

func (o DNSResponseOptions) Build() ([]dns.Question, *dns.Msg, error) {
var questions []dns.Question
if len(o.Query) == 0 && len(o.QueryType) == 0 {
questions = []dns.Question{{Qclass: dns.ClassINET}}
} else if len(o.Query) == 0 {
for _, queryType := range o.QueryType {
questions = append(questions, dns.Question{
Qtype: uint16(queryType),
Qclass: dns.ClassINET,
})
}
} else if len(o.QueryType) == 0 {
for _, domain := range o.Query {
questions = append(questions, dns.Question{
Name: dns.Fqdn(domain),
Qclass: dns.ClassINET,
})
}
} else {
for _, queryType := range o.QueryType {
for _, domain := range o.Query {
questions = append(questions, dns.Question{
Name: dns.Fqdn(domain),
Qtype: uint16(queryType),
Qclass: dns.ClassINET,
})
}
}
}
return questions, &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Rcode: o.RCode.Build(),
Authoritative: true,
RecursionDesired: true,
RecursionAvailable: true,
},
Answer: common.Map(o.Answer, DNSRecordOptions.build),
Ns: common.Map(o.Ns, DNSRecordOptions.build),
Extra: common.Map(o.Extra, DNSRecordOptions.build),
}, nil
}

type DNSRecordOptions struct {
dns.RR
fromBase64 bool
Expand Down Expand Up @@ -156,6 +97,6 @@ func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
return nil
}

func (o DNSRecordOptions) build() dns.RR {
func (o DNSRecordOptions) Build() dns.RR {
return o.RR
}
Loading

0 comments on commit e2d6b78

Please sign in to comment.