From ab0e0a010097ca593ab898c83fbf28832ac24687 Mon Sep 17 00:00:00 2001 From: Jarno Rajahalme Date: Thu, 2 Nov 2023 16:52:24 +0200 Subject: [PATCH] dns: Add SendContext() Implement exchangeContext() with exported SendContext(). Signed-off-by: Jarno Rajahalme --- client.go | 49 +++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index 9aa658530..f689b37a0 100644 --- a/client.go +++ b/client.go @@ -210,6 +210,32 @@ func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn } func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { + start := time.Now() + err = c.SendContext(ctx, m, co, start) + if err != nil { + return nil, 0, err + } + + if isPacketConn(co.Conn) { + for { + r, err = co.ReadMsg() + // Ignore replies with mismatched IDs because they might be + // responses to earlier queries that timed out. + if err != nil || r.Id == m.Id { + break + } + } + } else { + r, err = co.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } + } + + return r, time.Since(start), err +} + +func (c *Client) SendContext(ctx context.Context, m *Msg, co *Conn, t time.Time) error { opt := m.IsEdns0() // If EDNS0 is used use that for size. if opt != nil && opt.UDPSize() >= MinMsgSize { @@ -221,7 +247,6 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, } // write with the appropriate write timeout - t := time.Now() writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout())) readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout())) if deadline, ok := ctx.Deadline(); ok { @@ -237,27 +262,7 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider - if err = co.WriteMsg(m); err != nil { - return nil, 0, err - } - - if isPacketConn(co.Conn) { - for { - r, err = co.ReadMsg() - // Ignore replies with mismatched IDs because they might be - // responses to earlier queries that timed out. - if err != nil || r.Id == m.Id { - break - } - } - } else { - r, err = co.ReadMsg() - if err == nil && r.Id != m.Id { - err = ErrId - } - } - rtt = time.Since(t) - return r, rtt, err + return co.WriteMsg(m) } // ReadMsg reads a message from the connection co.