diff --git a/pkg/dns/resolver/resolver.go b/pkg/dns/resolver/resolver.go index 8f6c027ea286..fc52595b770c 100644 --- a/pkg/dns/resolver/resolver.go +++ b/pkg/dns/resolver/resolver.go @@ -1,6 +1,7 @@ package resolver import ( + "strings" "sync" "github.com/miekg/dns" @@ -125,7 +126,10 @@ func (s *dnsResolver) serviceFromName(name string) (string, error) { return "", errors.Errorf("wrong DNS name: %s", name) } - service := split[0] + // If it terminates with the domain we remove it. + if split[len(split)-1] == s.domain { + split = split[0 : len(split)-1] + } - return service, nil + return strings.Join(split, "."), nil } diff --git a/pkg/dns/server_test.go b/pkg/dns/server_test.go index 5b0104cff551..b3ecc9c4c470 100644 --- a/pkg/dns/server_test.go +++ b/pkg/dns/server_test.go @@ -3,6 +3,8 @@ package dns_test import ( "fmt" + "github.com/kumahq/kuma/pkg/dns/vips" + "github.com/miekg/dns" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -23,6 +25,7 @@ var _ = Describe("DNS server", func() { stop := make(chan struct{}) done := make(chan struct{}) var metrics core_metrics.Metrics + var dnsResolver resolver.DNSResolver BeforeEach(func() { // setup @@ -30,19 +33,14 @@ var _ = Describe("DNS server", func() { port = uint32(p) Expect(err).ToNot(HaveOccurred()) - resolver := resolver.NewDNSResolver("mesh") + dnsResolver = resolver.NewDNSResolver("mesh") m, err := core_metrics.NewMetrics("Standalone") metrics = m Expect(err).ToNot(HaveOccurred()) - server, err := NewDNSServer(port, resolver, metrics) + server, err := NewDNSServer(port, dnsResolver, metrics) Expect(err).ToNot(HaveOccurred()) - resolver.SetVIPs(map[string]string{ - "service": "240.0.0.1", - }) // given - ip, err = resolver.ForwardLookupFQDN("service.mesh") - Expect(err).ToNot(HaveOccurred()) go func() { err := server.Start(stop) @@ -58,12 +56,19 @@ var _ = Describe("DNS server", func() { }) It("should resolve", func() { + // given + var err error + dnsResolver.SetVIPs(map[string]string{ + "service": "240.0.0.1", + }) + ip, err = dnsResolver.ForwardLookupFQDN("service.mesh") + Expect(err).ToNot(HaveOccurred()) + // when client := new(dns.Client) message := new(dns.Msg) _ = message.SetQuestion("service.mesh.", dns.TypeA) var response *dns.Msg - var err error Eventually(func() error { response, _, err = client.Exchange(message, fmt.Sprintf("127.0.0.1:%d", port)) return err @@ -78,6 +83,13 @@ var _ = Describe("DNS server", func() { }) It("should resolve concurrent", func() { + // given + dnsResolver.SetVIPs(map[string]string{ + "service": "240.0.0.1", + }) + ip, err := dnsResolver.ForwardLookupFQDN("service.mesh") + Expect(err).ToNot(HaveOccurred()) + resolved := make(chan struct{}) for i := 0; i < 100; i++ { go func() { @@ -103,11 +115,41 @@ var _ = Describe("DNS server", func() { }) It("should not resolve", func() { + // given + var err error + dnsResolver.SetVIPs(map[string]string{ + "service": "240.0.0.1", + }) + ip, err = dnsResolver.ForwardLookupFQDN("service.mesh") + Expect(err).ToNot(HaveOccurred()) + // when client := new(dns.Client) message := new(dns.Msg) _ = message.SetQuestion("backend.mesh.", dns.TypeA) var response *dns.Msg + Eventually(func() error { + response, _, err = client.Exchange(message, fmt.Sprintf("127.0.0.1:%d", port)) + return err + }).ShouldNot(HaveOccurred()) + // then + Expect(err).ToNot(HaveOccurred()) + // and + Expect(len(response.Answer)).To(Equal(0)) + + // and metrics are published + Expect(test_metrics.FindMetric(metrics, "dns_server_resolution", "result", "unresolved").Counter.GetValue()).To(Equal(1.0)) + }) + + It("should not resolve when no vips", func() { + // given + dnsResolver.SetVIPs(map[string]string{}) + + // when + client := new(dns.Client) + message := new(dns.Msg) + _ = message.SetQuestion("service.mesh.", dns.TypeA) + var response *dns.Msg var err error Eventually(func() error { response, _, err = client.Exchange(message, fmt.Sprintf("127.0.0.1:%d", port)) @@ -121,5 +163,32 @@ var _ = Describe("DNS server", func() { // and metrics are published Expect(test_metrics.FindMetric(metrics, "dns_server_resolution", "result", "unresolved").Counter.GetValue()).To(Equal(1.0)) }) + + It("should resolve services with '.'", func() { + // given + var err error + dnsResolver.SetVIPs(vips.List{ + "my.service": "240.0.0.1", + }) + ip, err = dnsResolver.ForwardLookupFQDN("my.service.mesh") + Expect(err).ToNot(HaveOccurred()) + + // when + client := new(dns.Client) + message := new(dns.Msg) + _ = message.SetQuestion("my.service.mesh.", dns.TypeA) + var response *dns.Msg + Eventually(func() error { + response, _, err = client.Exchange(message, fmt.Sprintf("127.0.0.1:%d", port)) + return err + }).ShouldNot(HaveOccurred()) + + // then + Expect(response.Answer[0].String()).To(Equal(fmt.Sprintf("my.service.mesh.\t60\tIN\tA\t%s", ip))) + + // and metrics are published + Expect(test_metrics.FindMetric(metrics, "dns_server")).ToNot(BeNil()) + Expect(test_metrics.FindMetric(metrics, "dns_server_resolution", "result", "resolved").Counter.GetValue()).To(Equal(1.0)) + }) }) })