diff --git a/bench_test.go b/bench_test.go index 8bab7421..60eb94d9 100644 --- a/bench_test.go +++ b/bench_test.go @@ -48,14 +48,13 @@ func BenchmarkConnect(b *testing.B) { assert.True(b, ok) httpTransport.DisableCompression = true - client, err := pingv1connect.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( httpClient, server.URL, connect.WithGRPC(), connect.WithGzip(), connect.WithGzipRequests(), ) - assert.Nil(b, err) twoMiB := strings.Repeat("a", 2*1024*1024) b.ResetTimer() diff --git a/client.go b/client.go index c8767b56..a35d7e3f 100644 --- a/client.go +++ b/client.go @@ -34,6 +34,7 @@ type Client[Req, Res any] struct { config *clientConfiguration callUnary func(context.Context, *Request[Req]) (*Response[Res], error) protocolClient protocolClient + err error } // NewClient constructs a new Client. @@ -41,12 +42,15 @@ func NewClient[Req, Res any]( httpClient HTTPClient, url string, options ...ClientOption, -) (*Client[Req, Res], error) { +) *Client[Req, Res] { + client := &Client[Req, Res]{} config, err := newClientConfiguration(url, options) if err != nil { - return nil, err + client.err = err + return client } - protocolClient, protocolErr := config.Protocol.NewClient(&protocolClientParams{ + client.config = config + protocolClient, protocolErr := client.config.Protocol.NewClient(&protocolClientParams{ CompressionName: config.RequestCompressionName, CompressionPools: newReadOnlyCompressionPools(config.CompressionPools), Codec: config.Codec, @@ -57,8 +61,10 @@ func NewClient[Req, Res any]( BufferPool: config.BufferPool, }) if protocolErr != nil { - return nil, protocolErr + client.err = protocolErr + return client } + client.protocolClient = protocolClient // Rather than applying unary interceptors along the hot path, we can do it // once at client creation. unarySpec := config.newSpecification(StreamTypeUnary) @@ -86,7 +92,7 @@ func NewClient[Req, Res any]( if ic := config.Interceptor; ic != nil { unaryFunc = ic.WrapUnary(unaryFunc) } - callUnary := func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { + client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { // To make the specification and RPC headers visible to the full interceptor // chain (as though they were supplied by the caller), we'll add them here. request.spec = unarySpec @@ -101,11 +107,7 @@ func NewClient[Req, Res any]( } return typed, nil } - return &Client[Req, Res]{ - config: config, - callUnary: callUnary, - protocolClient: protocolClient, - }, nil + return client } // CallUnary calls a request-response procedure. @@ -113,11 +115,17 @@ func (c *Client[Req, Res]) CallUnary( ctx context.Context, req *Request[Req], ) (*Response[Res], error) { + if c.err != nil { + return nil, c.err + } return c.callUnary(ctx, req) } // CallClientStream calls a client streaming procedure. func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] { + if c.err != nil { + return &ClientStreamForClient[Req, Res]{err: c.err} + } sender, receiver := c.newStream(ctx, StreamTypeClient) return &ClientStreamForClient[Req, Res]{sender: sender, receiver: receiver} } @@ -127,6 +135,9 @@ func (c *Client[Req, Res]) CallServerStream( ctx context.Context, req *Request[Req], ) (*ServerStreamForClient[Res], error) { + if c.err != nil { + return nil, c.err + } sender, receiver := c.newStream(ctx, StreamTypeServer) mergeHeaders(sender.Header(), req.header) // Send always returns an io.EOF unless the error is from the client-side. @@ -145,6 +156,9 @@ func (c *Client[Req, Res]) CallServerStream( // CallBidiStream calls a bidirectional streaming procedure. func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] { + if c.err != nil { + return &BidiStreamForClient[Req, Res]{err: c.err} + } sender, receiver := c.newStream(ctx, StreamTypeBidi) return &BidiStreamForClient[Req, Res]{sender: sender, receiver: receiver} } diff --git a/client_example_test.go b/client_example_test.go index 45d918a6..3eb2f555 100644 --- a/client_example_test.go +++ b/client_example_test.go @@ -59,15 +59,11 @@ func Example_client() { // client that communicate over in-memory pipes. Don't do this in production! httpClient = examplePingServer.Client() - client, err := pingv1connect.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( httpClient, examplePingServer.URL(), connect.WithGRPC(), ) - if err != nil { - logger.Println("error:", err) - return - } res, err := client.Ping( context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 42}), diff --git a/client_stream.go b/client_stream.go index aca3f775..1f5ac2a8 100644 --- a/client_stream.go +++ b/client_stream.go @@ -27,6 +27,8 @@ import ( type ClientStreamForClient[Req, Res any] struct { sender Sender receiver Receiver + // Error from client construction. If non-nil, return for all calls. + err error } // RequestHeader returns the request headers. Headers are sent to the server with the @@ -42,12 +44,18 @@ func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header { // Clients should check for case using the standard library's errors.Is and // unmarshal the error using CloseAndReceive. func (c *ClientStreamForClient[Req, Res]) Send(msg *Req) error { + if c.err != nil { + return c.err + } return c.sender.Send(msg) } // CloseAndReceive closes the send side of the stream and waits for the // response. func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], error) { + if c.err != nil { + return nil, c.err + } if err := c.sender.Close(nil); err != nil { return nil, err } @@ -124,6 +132,8 @@ func (s *ServerStreamForClient[Res]) Close() error { type BidiStreamForClient[Req, Res any] struct { sender Sender receiver Receiver + // Error from client construction. If non-nil, return for all calls. + err error } // RequestHeader returns the request headers. Headers are sent with the first @@ -139,17 +149,26 @@ func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header { // Clients should check for case using the standard library's errors.Is and // unmarshal the error using Receive. func (b *BidiStreamForClient[Req, Res]) Send(msg *Req) error { + if b.err != nil { + return b.err + } return b.sender.Send(msg) } // CloseSend closes the send side of the stream. func (b *BidiStreamForClient[Req, Res]) CloseSend() error { + if b.err != nil { + return b.err + } return b.sender.Close(nil) } // Receive a message. When the server is done sending messages and no other // errors have occurred, Receive will return an error that wraps io.EOF. func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { + if b.err != nil { + return nil, b.err + } var res Res if err := b.receiver.Receive(&res); err != nil { return nil, err @@ -159,6 +178,9 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { // CloseReceive closes the receive side of the stream. func (b *BidiStreamForClient[Req, Res]) CloseReceive() error { + if b.err != nil { + return b.err + } return b.receiver.Close() } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 48b3656d..4df0907a 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -133,7 +133,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) { g.P() wrapComments(g, "This is a compile-time assertion to ensure that this generated file ", "and the connect package are compatible. If you get a compiler error that this constant ", - "isn't defined, this code was generated with a version of connect newer than the one ", + "is not defined, this code was generated with a version of connect newer than the one ", "compiled into your binary. You can fix the problem by either regenerating this code ", "with an older version of connect or updating the connect version compiled into your binary.") g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_0_1")) @@ -199,10 +199,11 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S deprecated(g) } g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"), - ", baseURL string, opts ...", clientOption, ") (", names.Client, ", error) {") + ", baseURL string, opts ...", clientOption, ") ", names.Client, " {") g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + g.P("return &", names.ClientImpl, "{") for _, method := range service.Methods { - g.P(unexport(method.GoName), "Client, err := ", + g.P(unexport(method.GoName), ": ", connectPackage.Ident("NewClient"), "[", method.Input.GoIdent, ", ", method.Output.GoIdent, "]", "(", @@ -210,16 +211,9 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S g.P("httpClient,") g.P(`baseURL + "`, procedureName(method), `",`) g.P("opts...,") - g.P(")") - g.P("if err != nil {") - g.P("return nil, err") - g.P("}") - } - g.P("return &", names.ClientImpl, "{") - for _, method := range service.Methods { - g.P(unexport(method.GoName), ": ", unexport(method.GoName), "Client,") + g.P("),") } - g.P("}, nil") + g.P("}") g.P("}") g.P() @@ -359,11 +353,11 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() { g.P("return ", connectPackage.Ident("NewError"), "(", connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"), - `("`, method.Desc.FullName(), ` isn't implemented"))`) + `("`, method.Desc.FullName(), ` is not implemented"))`) } else { g.P("return nil, ", connectPackage.Ident("NewError"), "(", connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"), - `("`, method.Desc.FullName(), ` isn't implemented"))`) + `("`, method.Desc.FullName(), ` is not implemented"))`) } g.P("}") g.P() diff --git a/connect_ext_test.go b/connect_ext_test.go index 3428b9ca..c2ffd6d9 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -277,8 +277,7 @@ func TestServer(t *testing.T) { testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { // nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) - assert.Nil(t, err) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -354,8 +353,7 @@ func TestHeaderBasic(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) - assert.Nil(t, err) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) req := connect.NewRequest(&pingv1.PingRequest{}) req.Header().Set(key, cval) res, err := client.Ping(context.Background(), req) @@ -378,10 +376,9 @@ func TestMarshalStatusError(t *testing.T) { assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) - assert.Nil(tb, err) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) req := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) - _, err = client.Fail(context.Background(), req) + _, err := client.Fail(context.Background(), req) tb.Log(err) assert.NotNil(t, err) var connectErr *connect.Error @@ -406,16 +403,15 @@ func TestBidiRequiresHTTP2(t *testing.T) { }) server := httptest.NewServer(handler) defer server.Close() - client, err := pingv1connect.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, connect.WithGRPC(), ) - assert.Nil(t, err) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(&pingv1.CumSumRequest{})) assert.Nil(t, stream.CloseSend()) - _, err = stream.Receive() + _, err := stream.Receive() assert.NotNil(t, err) var connectErr *connect.Error assert.True(t, errors.As(err, &connectErr)) diff --git a/interceptor_example_test.go b/interceptor_example_test.go index 98e10205..9dcaf7be 100644 --- a/interceptor_example_test.go +++ b/interceptor_example_test.go @@ -41,16 +41,12 @@ func ExampleInterceptor() { }) }, ) - client, err := pingv1connect.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( examplePingServer.Client(), examplePingServer.URL(), connect.WithGRPC(), connect.WithInterceptors(loggingInterceptor), ) - if err != nil { - logger.Println("error:", err) - return - } if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 42})); err != nil { logger.Println("error:", err) return @@ -84,16 +80,12 @@ func ExampleWithInterceptors() { }) }, ) - client, err := pingv1connect.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( examplePingServer.Client(), examplePingServer.URL(), connect.WithGRPC(), connect.WithInterceptors(outer, inner), ) - if err != nil { - logger.Println("error:", err) - return - } if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})); err != nil { logger.Println("error:", err) return diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index a8afd28c..8281ed12 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -29,8 +29,9 @@ import ( func TestClientStreamErrors(t *testing.T) { t.Parallel() - _, err := pingv1connect.NewPingServiceClient(http.DefaultClient, "INVALID_URL", connect.WithGRPC()) + _, err := pingv1connect.NewPingServiceClient(http.DefaultClient, "INVALID_URL", connect.WithGRPC()).Ping(context.Background(), nil) assert.NotNil(t, err) + assert.Match(t, err.Error(), "missing scheme") // We don't even get to calling methods on the client, so there's no question // of interceptors running. Once we're calling methods on the client, all // errors are visible to interceptors. @@ -180,17 +181,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - client, err := pingv1connect.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, connect.WithGRPC(), clientOnion, ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) assert.Nil(t, err) - - _, err = client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) - assert.Nil(t, err) - _, err = client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) assert.Nil(t, err) } diff --git a/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go index f396898b..3ca136c1 100644 --- a/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/connect/ping/v1/pingv1connect/ping.connect.go @@ -28,10 +28,10 @@ import ( ) // This is a compile-time assertion to ensure that this generated file and the connect package are -// compatible. If you get a compiler error that this constant isn't defined, this code was generated -// with a version of connect newer than the one compiled into your binary. You can fix the problem -// by either regenerating this code with an older version of connect or updating the connect version -// compiled into your binary. +// compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of connect newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of connect or updating the connect +// version compiled into your binary. const _ = connect_go.IsAtLeastVersion0_0_1 const ( @@ -60,55 +60,35 @@ type PingServiceClient interface { // // The URL supplied here should be the base URL for the gRPC server (for example, // http://api.acme.com or https://acme.com/grpc). -func NewPingServiceClient(httpClient connect_go.HTTPClient, baseURL string, opts ...connect_go.ClientOption) (PingServiceClient, error) { +func NewPingServiceClient(httpClient connect_go.HTTPClient, baseURL string, opts ...connect_go.ClientOption) PingServiceClient { baseURL = strings.TrimRight(baseURL, "/") - pingClient, err := connect_go.NewClient[v1.PingRequest, v1.PingResponse]( - httpClient, - baseURL+"/connect.ping.v1.PingService/Ping", - opts..., - ) - if err != nil { - return nil, err - } - failClient, err := connect_go.NewClient[v1.FailRequest, v1.FailResponse]( - httpClient, - baseURL+"/connect.ping.v1.PingService/Fail", - opts..., - ) - if err != nil { - return nil, err - } - sumClient, err := connect_go.NewClient[v1.SumRequest, v1.SumResponse]( - httpClient, - baseURL+"/connect.ping.v1.PingService/Sum", - opts..., - ) - if err != nil { - return nil, err - } - countUpClient, err := connect_go.NewClient[v1.CountUpRequest, v1.CountUpResponse]( - httpClient, - baseURL+"/connect.ping.v1.PingService/CountUp", - opts..., - ) - if err != nil { - return nil, err - } - cumSumClient, err := connect_go.NewClient[v1.CumSumRequest, v1.CumSumResponse]( - httpClient, - baseURL+"/connect.ping.v1.PingService/CumSum", - opts..., - ) - if err != nil { - return nil, err - } return &pingServiceClient{ - ping: pingClient, - fail: failClient, - sum: sumClient, - countUp: countUpClient, - cumSum: cumSumClient, - }, nil + ping: connect_go.NewClient[v1.PingRequest, v1.PingResponse]( + httpClient, + baseURL+"/connect.ping.v1.PingService/Ping", + opts..., + ), + fail: connect_go.NewClient[v1.FailRequest, v1.FailResponse]( + httpClient, + baseURL+"/connect.ping.v1.PingService/Fail", + opts..., + ), + sum: connect_go.NewClient[v1.SumRequest, v1.SumResponse]( + httpClient, + baseURL+"/connect.ping.v1.PingService/Sum", + opts..., + ), + countUp: connect_go.NewClient[v1.CountUpRequest, v1.CountUpResponse]( + httpClient, + baseURL+"/connect.ping.v1.PingService/CountUp", + opts..., + ), + cumSum: connect_go.NewClient[v1.CumSumRequest, v1.CumSumResponse]( + httpClient, + baseURL+"/connect.ping.v1.PingService/CumSum", + opts..., + ), + } } // pingServiceClient implements PingServiceClient. @@ -198,21 +178,21 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect_go.HandlerOpt type UnimplementedPingServiceHandler struct{} func (UnimplementedPingServiceHandler) Ping(context.Context, *connect_go.Request[v1.PingRequest]) (*connect_go.Response[v1.PingResponse], error) { - return nil, connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Ping isn't implemented")) + return nil, connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Ping is not implemented")) } func (UnimplementedPingServiceHandler) Fail(context.Context, *connect_go.Request[v1.FailRequest]) (*connect_go.Response[v1.FailResponse], error) { - return nil, connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail isn't implemented")) + return nil, connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail is not implemented")) } func (UnimplementedPingServiceHandler) Sum(context.Context, *connect_go.ClientStream[v1.SumRequest, v1.SumResponse]) error { - return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum isn't implemented")) + return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) } func (UnimplementedPingServiceHandler) CountUp(context.Context, *connect_go.Request[v1.CountUpRequest], *connect_go.ServerStream[v1.CountUpResponse]) error { - return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.CountUp isn't implemented")) + return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.CountUp is not implemented")) } func (UnimplementedPingServiceHandler) CumSum(context.Context, *connect_go.BidiStream[v1.CumSumRequest, v1.CumSumResponse]) error { - return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.CumSum isn't implemented")) + return connect_go.NewError(connect_go.CodeUnimplemented, errors.New("connect.ping.v1.PingService.CumSum is not implemented")) }