diff --git a/contrib/registry/consul/registry.go b/contrib/registry/consul/registry.go index 5b098384f97..42fa4a37091 100644 --- a/contrib/registry/consul/registry.go +++ b/contrib/registry/consul/registry.go @@ -215,10 +215,17 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er } func (r *Registry) resolve(ctx context.Context, ss *serviceSet) error { - timeoutCtx, cancel := context.WithTimeout(ctx, r.timeout) - defer cancel() + listServices := r.cli.Service + if r.timeout > 0 { + listServices = func(ctx context.Context, service string, index uint64, passingOnly bool) ([]*registry.ServiceInstance, uint64, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, r.timeout) + defer cancel() - services, idx, err := r.cli.Service(timeoutCtx, ss.serviceName, 0, true) + return r.cli.Service(timeoutCtx, service, index, passingOnly) + } + } + + services, idx, err := listServices(ctx, ss.serviceName, 0, true) if err != nil { return err } @@ -232,9 +239,7 @@ func (r *Registry) resolve(ctx context.Context, ss *serviceSet) error { for { select { case <-ticker.C: - timeoutCtx, cancel := context.WithTimeout(context.Background(), r.timeout) - tmpService, tmpIdx, err := r.cli.Service(timeoutCtx, ss.serviceName, idx, true) - cancel() + tmpService, tmpIdx, err := listServices(context.Background(), ss.serviceName, idx, true) if err != nil { time.Sleep(time.Second) continue diff --git a/transport/grpc/resolver/discovery/builder.go b/transport/grpc/resolver/discovery/builder.go index e8e98726f79..dfed56bf7fd 100644 --- a/transport/grpc/resolver/discovery/builder.go +++ b/transport/grpc/resolver/discovery/builder.go @@ -94,11 +94,16 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolv }() var err error - select { - case <-done: + if b.timeout > 0 { + select { + case <-done: + err = watchRes.err + case <-time.After(b.timeout): + err = ErrWatcherCreateTimeout + } + } else { + <-done err = watchRes.err - case <-time.After(b.timeout): - err = ErrWatcherCreateTimeout } if err != nil { cancel() diff --git a/transport/grpc/resolver/discovery/builder_test.go b/transport/grpc/resolver/discovery/builder_test.go index 0f27078acb1..1c855830115 100644 --- a/transport/grpc/resolver/discovery/builder_test.go +++ b/transport/grpc/resolver/discovery/builder_test.go @@ -107,7 +107,7 @@ func TestBuilder_Build(t *testing.T) { &mockConn{}, resolver.BuildOptions{}, ) - if err == nil { - t.Errorf("expected error, got %v", err) + if err != nil { + t.Errorf("expected no error, got %v", err) } }