diff --git a/routing/dht/dht_net.go b/routing/dht/dht_net.go index 0152dab4a91..8ad4286ce5f 100644 --- a/routing/dht/dht_net.go +++ b/routing/dht/dht_net.go @@ -214,7 +214,7 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb log.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes) mes := new(pb.Message) - if err := ms.r.ReadMsg(mes); err != nil { + if err := ms.ctxReadMsg(ctx, mes); err != nil { ms.s.Close() ms.s = nil return nil, err @@ -227,3 +227,17 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb return mes, nil } + +func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { + errc := make(chan error, 1) + go func() { + errc <- ms.r.ReadMsg(mes) + }() + + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() + } +}