From 42a589848da8b8d90eec67eb57cb3ea88ee5cdd4 Mon Sep 17 00:00:00 2001 From: Jarno Rajahalme Date: Tue, 7 Nov 2023 20:36:53 +0200 Subject: [PATCH] client: Add shared client support Add SharedClient and SharedClients types. Signed-off-by: Jarno Rajahalme --- shared_client.go | 207 ++++++++++++++++++++ shared_client_test.go | 433 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 640 insertions(+) create mode 100644 shared_client.go create mode 100644 shared_client_test.go diff --git a/shared_client.go b/shared_client.go new file mode 100644 index 0000000000..7d23774c9a --- /dev/null +++ b/shared_client.go @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +//go:build go1.18 +// +build go1.18 + +package dns + +import ( + "context" + "errors" + "fmt" + "sync" + "time" +) + +// SharedClients holds a set of SharedClient instances. +type SharedClients struct { + sync.Mutex + // clients are created and destroyed on demand, hence 'Mutex' needs to be taken. + clients map[string]*SharedClient +} + +func NewSharedClients() *SharedClients { + return &SharedClients{} +} + +// GetSharedClient gets or creates an instance of SharedClient keyed with 'key'. +// if 'key' is an empty sting, a new client is always created and it is not actually shared. +// The returned 'closer' must be called once the client is no longer needed. +func (s *SharedClients) GetSharedClient(key string, conf *Client, serverAddrStr string) (client *SharedClient, closer func()) { + s.Lock() + defer s.Unlock() + + if key != "" { + // locate client to re-use if possible. + client = s.clients[key] + } + if client == nil { + client = newSharedClient(conf, serverAddrStr) + if key != "" { + s.clients[key] = client + } + } + client.refcount++ + + return client, func() { + s.Lock() + defer s.Unlock() + + client.refcount-- + if client.refcount == 0 { + // Make client unreachable and close it's connection. + // Must hold the proxy mutex for this. + if key != "" { + delete(s.clients, key) + } + // connection must be closed while holding the proxy lock to avoid a race + // where a new client is created with the same 5-tuple before this one is + // closed, which could happen if the proxy lock is released before this + // Close call. + if client.conn != nil { + client.conn.Close() + } + } + } +} + +var errNoReader = errors.New("Reader stopped") + +type Response struct { + *Msg + err error +} + +// A SharedClient keeps state for concurrent transactions on the same upstream client/connection. +type SharedClient struct { + serverAddr string + + *Client + + refcount int // protected by SharedClient's lock + + // this mutex protects writes on 'conn' and all access to 'reqs' + sync.Mutex + reqs map[uint16]chan Response // outstanding requests + + // 'readerLock' mutex is used to serialize reads on 'conn'. It is always taken and released + // while holding the main lock but the main lock can be released and re-acquired while + // holding 'readerLock' mutex. + readerLock sync.Mutex + + // Client's connection shared among all requests from the same source address/port. The + // locks above are used to serialize reads and writes on this connection, but reads and + // writes can happen at the same time. + conn *Conn +} + +func newSharedClient(conf *Client, serverAddr string) *SharedClient { + return &SharedClient{ + serverAddr: serverAddr, + Client: conf, + reqs: make(map[uint16]chan Response), + } +} + +// ExchangeShared writes the request to the Client's connection and co-operatively +// reads responses from the connection and distributes them to the requestors. +// At most one caller is reading from Client's connection at any time. +func (c *SharedClient) ExchangeShared(m *Msg) (r *Msg, rtt time.Duration, err error) { + return c.ExchangeSharedContext(context.Background(), m) +} + +// ExchangeSharedContext writes the request to the Client's connection and co-operatively +// reads responses from the connection and distributes them to the requestors. +// At most one caller is reading from Client's connection at any time. +func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Msg, rtt time.Duration, err error) { + // Lock allows only one request to be written at a time, but that can happen + // concurrently with reading. + c.Lock() + defer c.Unlock() + if _, exists := c.reqs[m.Id]; exists { + return nil, 0, fmt.Errorf("duplicate request: %d", m.Id) + } + + // Dial if needed + if c.conn == nil { + c.conn, err = c.DialContext(ctx, c.serverAddr) + if err != nil { + return nil, 0, fmt.Errorf("failed to dial connection to %v: %w", c.serverAddr, err) + } + } + + // Send while holding the client lock, as Client is not made to be usable from + // concurrent goroutines. + start := time.Now() + err = c.SendContext(ctx, m, c.conn, start) + if err != nil { + return nil, 0, err + } + + // Create channel for the response with buffer of one, so that write to it + // does not block if we happen to do it ourselves. + ch := make(chan Response, 1) + c.reqs[m.Id] = ch + + // Wait for the response + var resp Response + for { + // Try taking the reader lock + if c.readerLock.TryLock() { + // We are responsible for reading responses for all users + // of this client until we get our own response or an error occurs. + var err error + for err == nil { + // Release the client lock for the duration of the blocking read + // operation to allow concurrent writes to the underlying + // connection. + var r *Msg + c.Unlock() + // This ReadMsg() will eventually fail due to the read deadline set + // by 'Client' on the underlying connection when sending the + // (last) request. + r, err = c.conn.ReadMsg() + c.Lock() + if err != nil { + break + } + // Locate the request for this response, skipping if not found + ch, exists := c.reqs[r.Id] + if !exists { + continue + } + // Pass the response to the waiting requester + delete(c.reqs, r.Id) + ch <- Response{Msg: r} + if r.Id == m.Id { + // Got our response, quit reading and tell others that + // its their turn to read. + err = errNoReader + } + } + // Releasing the reader lock before sending errors on waiter's channels + // so that when they get them, one of them can take the reader lock. + c.readerLock.Unlock() + for id, ch := range c.reqs { + // Another reader will pick up if any errNoReader errors are sent. + // Only delete the pending request in other error cases. + if !errors.Is(err, errNoReader) { + delete(c.reqs, id) + } + ch <- Response{err: err} + } + } + // Get the response of error from the current reader. + // Unlock for the blocking duration to allow concurrent writes + // on the client's connection. + c.Unlock() + resp = <-ch + c.Lock() + if !errors.Is(resp.err, errNoReader) { + break + } + // Trying again + } + return resp.Msg, time.Since(start), resp.err +} diff --git a/shared_client_test.go b/shared_client_test.go new file mode 100644 index 0000000000..890d6e2b35 --- /dev/null +++ b/shared_client_test.go @@ -0,0 +1,433 @@ +//go:build go1.18 +// +build go1.18 + +package dns + +import ( + "context" + "crypto/tls" + "net" + "strconv" + "strings" + "testing" + "time" +) + +var ( + clients = NewSharedClients() +) + +func TestSharedClientSync(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + r, _, err := c.ExchangeShared(m) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + if r == nil { + t.Fatal("response is nil") + } + if r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } + // And now another ExchangeAsync on the same shared client. + r, _, err = c.ExchangeShared(m) + if err != nil { + t.Errorf("failed to exchange: %v", err) + } + if r == nil || r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } +} + +func TestSharedClientLocalAddress(t *testing.T) { + HandleFunc("miek.nl.", HelloServerEchoAddrPort) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + laddr := net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 12345, Zone: ""} + c.Dialer = &net.Dialer{LocalAddr: &laddr} + + r, _, err := c.ExchangeShared(m) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + if r != nil && r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } + if len(r.Extra) != 1 { + t.Fatalf("failed to get additional answers\n%v", r) + } + txt := r.Extra[0].(*TXT) + if txt == nil { + t.Errorf("invalid TXT response\n%v", txt) + } + if len(txt.Txt) != 1 || !strings.Contains(txt.Txt[0], ":12345") { + t.Errorf("invalid TXT response\n%v", txt.Txt) + } +} + +func TestSharedClientTLSSyncV4(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, addrstr, _, err := RunLocalTLSServer(":0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + // test tcp-tls + c.Net = "tcp-tls" + c.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + r, _, err := c.ExchangeShared(m) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + if r == nil { + t.Fatal("response is nil") + } + if r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } + + // test tcp4-tls + c.Net = "tcp4-tls" + c.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + r, _, err = c.ExchangeShared(m) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + if r == nil { + t.Fatal("response is nil") + } + if r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } +} + +func TestSharedClientSyncBadID(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadID) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + // Test with client.Exchange, the plain Exchange function is just a wrapper, so + // we don't need to test that separately. + conf := &Client{ + Timeout: 10 * time.Millisecond, + } + + c, closer := clients.GetSharedClient("", conf, addrstr) + defer closer() + + if _, _, err := c.ExchangeShared(m); err == nil || !isNetworkTimeout(err) { + t.Errorf("query did not time out") + } +} + +func TestSharedClientSyncBadThenGoodID(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadThenGoodID) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + r, _, err := c.ExchangeShared(m) + if err != nil { + t.Errorf("failed to exchange: %v", err) + } + if r.Id != m.Id { + t.Errorf("failed to get response with expected Id") + } +} + +func TestSharedClientSyncTCPBadID(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadID) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + c.Net = "tcp" + c.Timeout = 10 * time.Millisecond + + // ExchangeShared does not pass through bad IDs, they will be filtered out just like + // for UDP and the request should time out + if _, _, err := c.ExchangeShared(m); err == nil || !isNetworkTimeout(err) { + t.Errorf("query did not time out") + } +} + +func TestSharedClientEDNS0(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeDNSKEY) + + m.SetEdns0(2048, true) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + r, _, err := c.ExchangeShared(m) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + + if r != nil && r.Rcode != RcodeSuccess { + t.Errorf("failed to get a valid answer\n%v", r) + } +} + +// Validates the transmission and parsing of local EDNS0 options. +func TestSharedClientEDNS0Local(t *testing.T) { + optStr1 := "1979:0x0707" + optStr2 := strconv.Itoa(EDNS0LOCALSTART) + ":0x0601" + + handler := func(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + + m.Extra = make([]RR, 1, 2) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello local edns"}} + + // If the local options are what we expect, then reflect them back. + ec1 := req.Extra[0].(*OPT).Option[0].(*EDNS0_LOCAL).String() + ec2 := req.Extra[0].(*OPT).Option[1].(*EDNS0_LOCAL).String() + if ec1 == optStr1 && ec2 == optStr2 { + m.Extra = append(m.Extra, req.Extra[0]) + } + + w.WriteMsg(m) + } + + HandleFunc("miek.nl.", handler) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + + // Add two local edns options to the query. + ec1 := &EDNS0_LOCAL{Code: 1979, Data: []byte{7, 7}} + ec2 := &EDNS0_LOCAL{Code: EDNS0LOCALSTART, Data: []byte{6, 1}} + o := &OPT{Hdr: RR_Header{Name: ".", Rrtype: TypeOPT}, Option: []EDNS0{ec1, ec2}} + m.Extra = append(m.Extra, o) + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + r, _, err := c.ExchangeShared(m) + if err != nil { + t.Fatalf("failed to exchange: %s", err) + } + + if r == nil { + t.Fatal("response is nil") + } + if r.Rcode != RcodeSuccess { + t.Fatal("failed to get a valid answer") + } + + txt := r.Extra[0].(*TXT).Txt[0] + if txt != "Hello local edns" { + t.Error("Unexpected result for miek.nl", txt, "!= Hello local edns") + } + + // Validate the local options in the reply. + got := r.Extra[1].(*OPT).Option[0].(*EDNS0_LOCAL).String() + if got != optStr1 { + t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr1) + } + + got = r.Extra[1].(*OPT).Option[1].(*EDNS0_LOCAL).String() + if got != optStr2 { + t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr2) + } +} + +func TestSharedTimeout(t *testing.T) { + // Set up a dummy UDP server that won't respond + addr, err := net.ResolveUDPAddr("udp", ":0") + if err != nil { + t.Fatalf("unable to resolve local udp address: %v", err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer conn.Close() + addrstr := conn.LocalAddr().String() + + // Message to send + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + + runTest := func(name string, exchange func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error)) { + t.Run(name, func(t *testing.T) { + start := time.Now() + + timeout := time.Millisecond + allowable := timeout + 10*time.Millisecond + + _, _, err := exchange(m, addrstr, timeout) + if err == nil { + t.Errorf("no timeout using Client.%s", name) + } + + length := time.Since(start) + if length > allowable { + t.Errorf("exchange took longer %v than specified Timeout %v", length, allowable) + } + }) + } + runTest("ExchangeShared", func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error) { + c, closer := clients.GetSharedClient("", &Client{Timeout: timeout}, addrstr) + defer closer() + + return c.ExchangeShared(m) + }) + runTest("ExchangeSharedContext", func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + c, closer := clients.GetSharedClient("", new(Client), addrstr) + defer closer() + + return c.ExchangeSharedContext(ctx, m) + }) +} + +// Check that responses from deduplicated requests aren't shared between callers +func TestSharedConcurrentExchanges(t *testing.T) { + cases := make([]*Msg, 2) + cases[0] = new(Msg) + cases[1] = new(Msg) + cases[1].Truncated = true + + for _, m := range cases { + mm := m // redeclare m so as not to trip the race detector + handler := func(w ResponseWriter, req *Msg) { + r := mm.Copy() + r.SetReply(req) + + w.WriteMsg(r) + } + + HandleFunc("miek.nl.", handler) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSRV) + + c, closer := clients.GetSharedClient("", &Client{SingleInflight: true}, addrstr) + defer closer() + + // Force this client to always return the same request, + // even though we're querying sequentially. Running the + // Exchange calls below concurrently can fail due to + // goroutine scheduling, but this simulates the same + // outcome. + c.group.dontDeleteForTesting = true + + r := make([]*Msg, 2) + for i := range r { + r[i], _, _ = c.ExchangeShared(m.Copy()) + if r[i] == nil { + t.Errorf("response %d is nil", i) + } + } + + if r[0] == r[1] { + t.Errorf("got same response, expected non-shared responses") + } + } +}