Skip to content

Commit

Permalink
Merge pull request #6 from cilium/add-send-context
Browse files Browse the repository at this point in the history
dns: Add SendContext()
  • Loading branch information
jrajahalme authored Nov 2, 2023
2 parents e4fe466 + ab0e0a0 commit 6fba7e4
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,32 @@ func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn
}

func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
start := time.Now()
err = c.SendContext(ctx, m, co, start)
if err != nil {
return nil, 0, err
}

if isPacketConn(co.Conn) {
for {
r, err = co.ReadMsg()
// Ignore replies with mismatched IDs because they might be
// responses to earlier queries that timed out.
if err != nil || r.Id == m.Id {
break
}
}
} else {
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
}

return r, time.Since(start), err
}

func (c *Client) SendContext(ctx context.Context, m *Msg, co *Conn, t time.Time) error {
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
Expand All @@ -221,7 +247,6 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg,
}

// write with the appropriate write timeout
t := time.Now()
writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -237,27 +262,7 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg,

co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider

if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}

if isPacketConn(co.Conn) {
for {
r, err = co.ReadMsg()
// Ignore replies with mismatched IDs because they might be
// responses to earlier queries that timed out.
if err != nil || r.Id == m.Id {
break
}
}
} else {
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
}
rtt = time.Since(t)
return r, rtt, err
return co.WriteMsg(m)
}

// ReadMsg reads a message from the connection co.
Expand Down

0 comments on commit 6fba7e4

Please sign in to comment.