diff --git a/dnsserver.go b/dnsserver.go index 51a6153..f0c91c4 100644 --- a/dnsserver.go +++ b/dnsserver.go @@ -91,6 +91,20 @@ func NewDNSConfig() *DNSConfig { } } +// Clone clones a [DNSConfig]. +func (dc *DNSConfig) Clone() *DNSConfig { + defer dc.mu.Unlock() + dc.mu.Lock() + out := NewDNSConfig() + for key, value := range dc.r { + out.r[key] = &DNSRecord{ + A: append([]net.IP{}, value.A...), + CNAME: value.CNAME, + } + } + return out +} + // ErrNotIPAddress indicates that a string is not a serialized IP address. var ErrNotIPAddress = errors.New("netem: not a valid IP address") @@ -132,6 +146,21 @@ func (dc *DNSConfig) Lookup(name string) (*DNSRecord, bool) { return record, found } +func dnsConfigWithWhoami(config *DNSConfig, endpoint net.Addr) *DNSConfig { + // make sure we operate on a copy + config = config.Clone() + + // extract the endpoint address + ipAddr, _, err := net.SplitHostPort(endpoint.String()) + if err != nil { + return config + } + + // add whoami.v4.powerdns.org record + _ = config.AddRecord("whoami.v4.powerdns.org", "", ipAddr) + return config +} + // dnsServerWorker is the [DNSServer] worker. func dnsServerWorker( logger Logger, @@ -156,7 +185,7 @@ func dnsServerWorker( } rawQuery := buffer[:count] - rawResponse, err := DNSServerRoundTrip(config, rawQuery) + rawResponse, err := DNSServerRoundTrip(dnsConfigWithWhoami(config, addr), rawQuery) if err != nil { logger.Warnf("netem: dnsServerRoundTrip: %s", err.Error()) continue diff --git a/dnsserver_test.go b/dnsserver_test.go index 9adbdfb..2d22f87 100644 --- a/dnsserver_test.go +++ b/dnsserver_test.go @@ -47,4 +47,25 @@ func TestDNSConfig(t *testing.T) { }) }) }) + + t.Run("we can clone a DNSConfig", func(t *testing.T) { + config := NewDNSConfig() + config.AddRecord("www.example.com", "", "130.192.91.211") + other := config.Clone() + config.RemoveRecord("www.example.com") + if _, good := config.Lookup("www.example.com"); good { + t.Fatal("expected record to be missing") + } + record, good := other.Lookup("www.example.com") + if !good { + t.Fatal("expected to see record") + } + expect := &DNSRecord{ + A: []net.IP{net.IPv4(130, 192, 91, 211)}, + CNAME: "", + } + if diff := cmp.Diff(expect, record); diff != "" { + t.Fatal(diff) + } + }) } diff --git a/example_star_test.go b/example_star_test.go index 11e309a..146010b 100644 --- a/example_star_test.go +++ b/example_star_test.go @@ -1,6 +1,7 @@ package netem_test import ( + "context" "fmt" "io" "log" @@ -114,3 +115,65 @@ func Example_starTopologyHTTPSAndDNS() { // Bonsoir, Elliot! // } + +// This example shows how DNS servers implement whoami.v4.powerdns.org +// an HTTPS server. Then we create an HTTPS client and we use such a +// client to fetch a very important message from the server. +func Example_starTopologyDNSWhoami() { + // Create a star topology for our hosts. + topology := netem.MustNewStarTopology(&netem.NullLogger{}) + defer topology.Close() + + // Add client stack to topology. Note that we don't need to + // close the clientStack: the topology will do that. + clientStack, err := topology.AddHost( + "130.192.91.211", // host IP address + "8.8.8.8", // host DNS resolver IP address + &netem.LinkConfig{}, // link with no PLR, RTT, DPI + ) + if err != nil { + log.Fatal(err) + } + + // Add DNS server stack to topology. + dnsServerStack, err := topology.AddHost( + "8.8.8.8", + "8.8.8.8", // this host is its own DNS resolver + &netem.LinkConfig{}, + ) + if err != nil { + log.Fatal(err) + } + + // spawn a DNS server with empty configuration. + dnsConfig := netem.NewDNSConfig() + dnsServer, err := netem.NewDNSServer( + &netem.NullLogger{}, + dnsServerStack, + "8.8.8.8", + dnsConfig, + ) + if err != nil { + log.Fatal(err) + } + defer dnsServer.Close() + + // create the DNS query to use + query := netem.NewDNSRequestA("whoami.v4.powerdns.org") + + // issue a DNS request to the server + response, err := netem.DNSRoundTrip(context.Background(), clientStack, "8.8.8.8", query) + if err != nil { + log.Fatal(err) + } + + // parse the DNS response + addrs, _, err := netem.DNSParseResponse(query, response) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s\n", addrs) + // Output: + // [130.192.91.211] + // +}