Skip to content

Commit

Permalink
Do not return error from Client constructor (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
bufdev authored Apr 20, 2022
1 parent 2503844 commit 5a13642
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 113 deletions.
3 changes: 1 addition & 2 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ func BenchmarkConnect(b *testing.B) {
assert.True(b, ok)
httpTransport.DisableCompression = true

client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
httpClient,
server.URL,
connect.WithGRPC(),
connect.WithGzip(),
connect.WithGzipRequests(),
)
assert.Nil(b, err)
twoMiB := strings.Repeat("a", 2*1024*1024)
b.ResetTimer()

Expand Down
34 changes: 24 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,23 @@ type Client[Req, Res any] struct {
config *clientConfiguration
callUnary func(context.Context, *Request[Req]) (*Response[Res], error)
protocolClient protocolClient
err error
}

// NewClient constructs a new Client.
func NewClient[Req, Res any](
httpClient HTTPClient,
url string,
options ...ClientOption,
) (*Client[Req, Res], error) {
) *Client[Req, Res] {
client := &Client[Req, Res]{}
config, err := newClientConfiguration(url, options)
if err != nil {
return nil, err
client.err = err
return client
}
protocolClient, protocolErr := config.Protocol.NewClient(&protocolClientParams{
client.config = config
protocolClient, protocolErr := client.config.Protocol.NewClient(&protocolClientParams{
CompressionName: config.RequestCompressionName,
CompressionPools: newReadOnlyCompressionPools(config.CompressionPools),
Codec: config.Codec,
Expand All @@ -57,8 +61,10 @@ func NewClient[Req, Res any](
BufferPool: config.BufferPool,
})
if protocolErr != nil {
return nil, protocolErr
client.err = protocolErr
return client
}
client.protocolClient = protocolClient
// Rather than applying unary interceptors along the hot path, we can do it
// once at client creation.
unarySpec := config.newSpecification(StreamTypeUnary)
Expand Down Expand Up @@ -86,7 +92,7 @@ func NewClient[Req, Res any](
if ic := config.Interceptor; ic != nil {
unaryFunc = ic.WrapUnary(unaryFunc)
}
callUnary := func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
// To make the specification and RPC headers visible to the full interceptor
// chain (as though they were supplied by the caller), we'll add them here.
request.spec = unarySpec
Expand All @@ -101,23 +107,25 @@ func NewClient[Req, Res any](
}
return typed, nil
}
return &Client[Req, Res]{
config: config,
callUnary: callUnary,
protocolClient: protocolClient,
}, nil
return client
}

// CallUnary calls a request-response procedure.
func (c *Client[Req, Res]) CallUnary(
ctx context.Context,
req *Request[Req],
) (*Response[Res], error) {
if c.err != nil {
return nil, c.err
}
return c.callUnary(ctx, req)
}

// CallClientStream calls a client streaming procedure.
func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] {
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
sender, receiver := c.newStream(ctx, StreamTypeClient)
return &ClientStreamForClient[Req, Res]{sender: sender, receiver: receiver}
}
Expand All @@ -127,6 +135,9 @@ func (c *Client[Req, Res]) CallServerStream(
ctx context.Context,
req *Request[Req],
) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
sender, receiver := c.newStream(ctx, StreamTypeServer)
mergeHeaders(sender.Header(), req.header)
// Send always returns an io.EOF unless the error is from the client-side.
Expand All @@ -145,6 +156,9 @@ func (c *Client[Req, Res]) CallServerStream(

// CallBidiStream calls a bidirectional streaming procedure.
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
sender, receiver := c.newStream(ctx, StreamTypeBidi)
return &BidiStreamForClient[Req, Res]{sender: sender, receiver: receiver}
}
Expand Down
6 changes: 1 addition & 5 deletions client_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,11 @@ func Example_client() {
// client that communicate over in-memory pipes. Don't do this in production!
httpClient = examplePingServer.Client()

client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
httpClient,
examplePingServer.URL(),
connect.WithGRPC(),
)
if err != nil {
logger.Println("error:", err)
return
}
res, err := client.Ping(
context.Background(),
connect.NewRequest(&pingv1.PingRequest{Number: 42}),
Expand Down
22 changes: 22 additions & 0 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
type ClientStreamForClient[Req, Res any] struct {
sender Sender
receiver Receiver
// Error from client construction. If non-nil, return for all calls.
err error
}

// RequestHeader returns the request headers. Headers are sent to the server with the
Expand All @@ -42,12 +44,18 @@ func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header {
// Clients should check for case using the standard library's errors.Is and
// unmarshal the error using CloseAndReceive.
func (c *ClientStreamForClient[Req, Res]) Send(msg *Req) error {
if c.err != nil {
return c.err
}
return c.sender.Send(msg)
}

// CloseAndReceive closes the send side of the stream and waits for the
// response.
func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], error) {
if c.err != nil {
return nil, c.err
}
if err := c.sender.Close(nil); err != nil {
return nil, err
}
Expand Down Expand Up @@ -124,6 +132,8 @@ func (s *ServerStreamForClient[Res]) Close() error {
type BidiStreamForClient[Req, Res any] struct {
sender Sender
receiver Receiver
// Error from client construction. If non-nil, return for all calls.
err error
}

// RequestHeader returns the request headers. Headers are sent with the first
Expand All @@ -139,17 +149,26 @@ func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header {
// Clients should check for case using the standard library's errors.Is and
// unmarshal the error using Receive.
func (b *BidiStreamForClient[Req, Res]) Send(msg *Req) error {
if b.err != nil {
return b.err
}
return b.sender.Send(msg)
}

// CloseSend closes the send side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseSend() error {
if b.err != nil {
return b.err
}
return b.sender.Close(nil)
}

// Receive a message. When the server is done sending messages and no other
// errors have occurred, Receive will return an error that wraps io.EOF.
func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {
if b.err != nil {
return nil, b.err
}
var res Res
if err := b.receiver.Receive(&res); err != nil {
return nil, err
Expand All @@ -159,6 +178,9 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {

// CloseReceive closes the receive side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseReceive() error {
if b.err != nil {
return b.err
}
return b.receiver.Close()
}

Expand Down
22 changes: 8 additions & 14 deletions cmd/protoc-gen-connect-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) {
g.P()
wrapComments(g, "This is a compile-time assertion to ensure that this generated file ",
"and the connect package are compatible. If you get a compiler error that this constant ",
"isn't defined, this code was generated with a version of connect newer than the one ",
"is not defined, this code was generated with a version of connect newer than the one ",
"compiled into your binary. You can fix the problem by either regenerating this code ",
"with an older version of connect or updating the connect version compiled into your binary.")
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_0_1"))
Expand Down Expand Up @@ -199,27 +199,21 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S
deprecated(g)
}
g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"),
", baseURL string, opts ...", clientOption, ") (", names.Client, ", error) {")
", baseURL string, opts ...", clientOption, ") ", names.Client, " {")
g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`)
g.P("return &", names.ClientImpl, "{")
for _, method := range service.Methods {
g.P(unexport(method.GoName), "Client, err := ",
g.P(unexport(method.GoName), ": ",
connectPackage.Ident("NewClient"),
"[", method.Input.GoIdent, ", ", method.Output.GoIdent, "]",
"(",
)
g.P("httpClient,")
g.P(`baseURL + "`, procedureName(method), `",`)
g.P("opts...,")
g.P(")")
g.P("if err != nil {")
g.P("return nil, err")
g.P("}")
}
g.P("return &", names.ClientImpl, "{")
for _, method := range service.Methods {
g.P(unexport(method.GoName), ": ", unexport(method.GoName), "Client,")
g.P("),")
}
g.P("}, nil")
g.P("}")
g.P("}")
g.P()

Expand Down Expand Up @@ -359,11 +353,11 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic
if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() {
g.P("return ", connectPackage.Ident("NewError"), "(",
connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"),
`("`, method.Desc.FullName(), ` isn't implemented"))`)
`("`, method.Desc.FullName(), ` is not implemented"))`)
} else {
g.P("return nil, ", connectPackage.Ident("NewError"), "(",
connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"),
`("`, method.Desc.FullName(), ` isn't implemented"))`)
`("`, method.Desc.FullName(), ` is not implemented"))`)
}
g.P("}")
g.P()
Expand Down
16 changes: 6 additions & 10 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ func TestServer(t *testing.T) {
testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { // nolint:thelper
run := func(t *testing.T, opts ...connect.ClientOption) {
t.Helper()
client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
assert.Nil(t, err)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
testPing(t, client)
testSum(t, client)
testCountUp(t, client)
Expand Down Expand Up @@ -354,8 +353,7 @@ func TestHeaderBasic(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()

client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
assert.Nil(t, err)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
req := connect.NewRequest(&pingv1.PingRequest{})
req.Header().Set(key, cval)
res, err := client.Ping(context.Background(), req)
Expand All @@ -378,10 +376,9 @@ func TestMarshalStatusError(t *testing.T) {

assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) {
tb.Helper()
client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
assert.Nil(tb, err)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
req := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)})
_, err = client.Fail(context.Background(), req)
_, err := client.Fail(context.Background(), req)
tb.Log(err)
assert.NotNil(t, err)
var connectErr *connect.Error
Expand All @@ -406,16 +403,15 @@ func TestBidiRequiresHTTP2(t *testing.T) {
})
server := httptest.NewServer(handler)
defer server.Close()
client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithGRPC(),
)
assert.Nil(t, err)
stream := client.CumSum(context.Background())
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{}))
assert.Nil(t, stream.CloseSend())
_, err = stream.Receive()
_, err := stream.Receive()
assert.NotNil(t, err)
var connectErr *connect.Error
assert.True(t, errors.As(err, &connectErr))
Expand Down
12 changes: 2 additions & 10 deletions interceptor_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,12 @@ func ExampleInterceptor() {
})
},
)
client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
examplePingServer.Client(),
examplePingServer.URL(),
connect.WithGRPC(),
connect.WithInterceptors(loggingInterceptor),
)
if err != nil {
logger.Println("error:", err)
return
}
if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 42})); err != nil {
logger.Println("error:", err)
return
Expand Down Expand Up @@ -84,16 +80,12 @@ func ExampleWithInterceptors() {
})
},
)
client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
examplePingServer.Client(),
examplePingServer.URL(),
connect.WithGRPC(),
connect.WithInterceptors(outer, inner),
)
if err != nil {
logger.Println("error:", err)
return
}
if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})); err != nil {
logger.Println("error:", err)
return
Expand Down
10 changes: 4 additions & 6 deletions interceptor_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (

func TestClientStreamErrors(t *testing.T) {
t.Parallel()
_, err := pingv1connect.NewPingServiceClient(http.DefaultClient, "INVALID_URL", connect.WithGRPC())
_, err := pingv1connect.NewPingServiceClient(http.DefaultClient, "INVALID_URL", connect.WithGRPC()).Ping(context.Background(), nil)
assert.NotNil(t, err)
assert.Match(t, err.Error(), "missing scheme")
// We don't even get to calling methods on the client, so there's no question
// of interceptors running. Once we're calling methods on the client, all
// errors are visible to interceptors.
Expand Down Expand Up @@ -180,17 +181,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()

client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithGRPC(),
clientOnion,
)
_, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10}))
assert.Nil(t, err)

_, err = client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10}))
assert.Nil(t, err)

_, err = client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10}))
assert.Nil(t, err)
}
Expand Down
Loading

0 comments on commit 5a13642

Please sign in to comment.