Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not return error from Client constructor #204

Merged
merged 1 commit into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ func BenchmarkConnect(b *testing.B) {
assert.True(b, ok)
httpTransport.DisableCompression = true

client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
httpClient,
server.URL,
connect.WithGRPC(),
connect.WithGzip(),
connect.WithGzipRequests(),
)
assert.Nil(b, err)
twoMiB := strings.Repeat("a", 2*1024*1024)
b.ResetTimer()

Expand Down
34 changes: 24 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,23 @@ type Client[Req, Res any] struct {
config *clientConfiguration
callUnary func(context.Context, *Request[Req]) (*Response[Res], error)
protocolClient protocolClient
err error
}

// NewClient constructs a new Client.
func NewClient[Req, Res any](
httpClient HTTPClient,
url string,
options ...ClientOption,
) (*Client[Req, Res], error) {
) *Client[Req, Res] {
client := &Client[Req, Res]{}
config, err := newClientConfiguration(url, options)
if err != nil {
return nil, err
client.err = err
return client
}
protocolClient, protocolErr := config.Protocol.NewClient(&protocolClientParams{
client.config = config
protocolClient, protocolErr := client.config.Protocol.NewClient(&protocolClientParams{
CompressionName: config.RequestCompressionName,
CompressionPools: newReadOnlyCompressionPools(config.CompressionPools),
Codec: config.Codec,
Expand All @@ -57,8 +61,10 @@ func NewClient[Req, Res any](
BufferPool: config.BufferPool,
})
if protocolErr != nil {
return nil, protocolErr
client.err = protocolErr
return client
}
client.protocolClient = protocolClient
// Rather than applying unary interceptors along the hot path, we can do it
// once at client creation.
unarySpec := config.newSpecification(StreamTypeUnary)
Expand Down Expand Up @@ -86,7 +92,7 @@ func NewClient[Req, Res any](
if ic := config.Interceptor; ic != nil {
unaryFunc = ic.WrapUnary(unaryFunc)
}
callUnary := func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
// To make the specification and RPC headers visible to the full interceptor
// chain (as though they were supplied by the caller), we'll add them here.
request.spec = unarySpec
Expand All @@ -101,23 +107,25 @@ func NewClient[Req, Res any](
}
return typed, nil
}
return &Client[Req, Res]{
config: config,
callUnary: callUnary,
protocolClient: protocolClient,
}, nil
return client
}

// CallUnary calls a request-response procedure.
func (c *Client[Req, Res]) CallUnary(
ctx context.Context,
req *Request[Req],
) (*Response[Res], error) {
if c.err != nil {
return nil, c.err
}
return c.callUnary(ctx, req)
}

// CallClientStream calls a client streaming procedure.
func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] {
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
sender, receiver := c.newStream(ctx, StreamTypeClient)
return &ClientStreamForClient[Req, Res]{sender: sender, receiver: receiver}
}
Expand All @@ -127,6 +135,9 @@ func (c *Client[Req, Res]) CallServerStream(
ctx context.Context,
req *Request[Req],
) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
sender, receiver := c.newStream(ctx, StreamTypeServer)
mergeHeaders(sender.Header(), req.header)
// Send always returns an io.EOF unless the error is from the client-side.
Expand All @@ -145,6 +156,9 @@ func (c *Client[Req, Res]) CallServerStream(

// CallBidiStream calls a bidirectional streaming procedure.
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
sender, receiver := c.newStream(ctx, StreamTypeBidi)
return &BidiStreamForClient[Req, Res]{sender: sender, receiver: receiver}
}
Expand Down
6 changes: 1 addition & 5 deletions client_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,11 @@ func Example_client() {
// client that communicate over in-memory pipes. Don't do this in production!
httpClient = examplePingServer.Client()

client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
httpClient,
examplePingServer.URL(),
connect.WithGRPC(),
)
if err != nil {
logger.Println("error:", err)
return
}
res, err := client.Ping(
context.Background(),
connect.NewRequest(&pingv1.PingRequest{Number: 42}),
Expand Down
22 changes: 22 additions & 0 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
type ClientStreamForClient[Req, Res any] struct {
sender Sender
receiver Receiver
// Error from client construction. If non-nil, return for all calls.
err error
}

// RequestHeader returns the request headers. Headers are sent to the server with the
Expand All @@ -42,12 +44,18 @@ func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header {
// Clients should check for case using the standard library's errors.Is and
// unmarshal the error using CloseAndReceive.
func (c *ClientStreamForClient[Req, Res]) Send(msg *Req) error {
if c.err != nil {
return c.err
}
return c.sender.Send(msg)
}

// CloseAndReceive closes the send side of the stream and waits for the
// response.
func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], error) {
if c.err != nil {
return nil, c.err
}
if err := c.sender.Close(nil); err != nil {
return nil, err
}
Expand Down Expand Up @@ -124,6 +132,8 @@ func (s *ServerStreamForClient[Res]) Close() error {
type BidiStreamForClient[Req, Res any] struct {
sender Sender
receiver Receiver
// Error from client construction. If non-nil, return for all calls.
err error
}

// RequestHeader returns the request headers. Headers are sent with the first
Expand All @@ -139,17 +149,26 @@ func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header {
// Clients should check for case using the standard library's errors.Is and
// unmarshal the error using Receive.
func (b *BidiStreamForClient[Req, Res]) Send(msg *Req) error {
if b.err != nil {
return b.err
}
return b.sender.Send(msg)
}

// CloseSend closes the send side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseSend() error {
if b.err != nil {
return b.err
}
return b.sender.Close(nil)
}

// Receive a message. When the server is done sending messages and no other
// errors have occurred, Receive will return an error that wraps io.EOF.
func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {
if b.err != nil {
return nil, b.err
}
var res Res
if err := b.receiver.Receive(&res); err != nil {
return nil, err
Expand All @@ -159,6 +178,9 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {

// CloseReceive closes the receive side of the stream.
func (b *BidiStreamForClient[Req, Res]) CloseReceive() error {
if b.err != nil {
return b.err
}
return b.receiver.Close()
}

Expand Down
22 changes: 8 additions & 14 deletions cmd/protoc-gen-connect-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) {
g.P()
wrapComments(g, "This is a compile-time assertion to ensure that this generated file ",
"and the connect package are compatible. If you get a compiler error that this constant ",
"isn't defined, this code was generated with a version of connect newer than the one ",
"is not defined, this code was generated with a version of connect newer than the one ",
"compiled into your binary. You can fix the problem by either regenerating this code ",
"with an older version of connect or updating the connect version compiled into your binary.")
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_0_1"))
Expand Down Expand Up @@ -199,27 +199,21 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S
deprecated(g)
}
g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"),
", baseURL string, opts ...", clientOption, ") (", names.Client, ", error) {")
", baseURL string, opts ...", clientOption, ") ", names.Client, " {")
g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`)
g.P("return &", names.ClientImpl, "{")
for _, method := range service.Methods {
g.P(unexport(method.GoName), "Client, err := ",
g.P(unexport(method.GoName), ": ",
connectPackage.Ident("NewClient"),
"[", method.Input.GoIdent, ", ", method.Output.GoIdent, "]",
"(",
)
g.P("httpClient,")
g.P(`baseURL + "`, procedureName(method), `",`)
g.P("opts...,")
g.P(")")
g.P("if err != nil {")
g.P("return nil, err")
g.P("}")
}
g.P("return &", names.ClientImpl, "{")
for _, method := range service.Methods {
g.P(unexport(method.GoName), ": ", unexport(method.GoName), "Client,")
g.P("),")
}
g.P("}, nil")
g.P("}")
g.P("}")
g.P()

Expand Down Expand Up @@ -359,11 +353,11 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic
if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() {
g.P("return ", connectPackage.Ident("NewError"), "(",
connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"),
`("`, method.Desc.FullName(), ` isn't implemented"))`)
`("`, method.Desc.FullName(), ` is not implemented"))`)
} else {
g.P("return nil, ", connectPackage.Ident("NewError"), "(",
connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"),
`("`, method.Desc.FullName(), ` isn't implemented"))`)
`("`, method.Desc.FullName(), ` is not implemented"))`)
}
g.P("}")
g.P()
Expand Down
16 changes: 6 additions & 10 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ func TestServer(t *testing.T) {
testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { // nolint:thelper
run := func(t *testing.T, opts ...connect.ClientOption) {
t.Helper()
client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
assert.Nil(t, err)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
testPing(t, client)
testSum(t, client)
testCountUp(t, client)
Expand Down Expand Up @@ -354,8 +353,7 @@ func TestHeaderBasic(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()

client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
assert.Nil(t, err)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC())
req := connect.NewRequest(&pingv1.PingRequest{})
req.Header().Set(key, cval)
res, err := client.Ping(context.Background(), req)
Expand All @@ -378,10 +376,9 @@ func TestMarshalStatusError(t *testing.T) {

assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) {
tb.Helper()
client, err := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
assert.Nil(tb, err)
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...)
req := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)})
_, err = client.Fail(context.Background(), req)
_, err := client.Fail(context.Background(), req)
tb.Log(err)
assert.NotNil(t, err)
var connectErr *connect.Error
Expand All @@ -406,16 +403,15 @@ func TestBidiRequiresHTTP2(t *testing.T) {
})
server := httptest.NewServer(handler)
defer server.Close()
client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithGRPC(),
)
assert.Nil(t, err)
stream := client.CumSum(context.Background())
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{}))
assert.Nil(t, stream.CloseSend())
_, err = stream.Receive()
_, err := stream.Receive()
assert.NotNil(t, err)
var connectErr *connect.Error
assert.True(t, errors.As(err, &connectErr))
Expand Down
12 changes: 2 additions & 10 deletions interceptor_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,12 @@ func ExampleInterceptor() {
})
},
)
client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
examplePingServer.Client(),
examplePingServer.URL(),
connect.WithGRPC(),
connect.WithInterceptors(loggingInterceptor),
)
if err != nil {
logger.Println("error:", err)
return
}
if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 42})); err != nil {
logger.Println("error:", err)
return
Expand Down Expand Up @@ -84,16 +80,12 @@ func ExampleWithInterceptors() {
})
},
)
client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
examplePingServer.Client(),
examplePingServer.URL(),
connect.WithGRPC(),
connect.WithInterceptors(outer, inner),
)
if err != nil {
logger.Println("error:", err)
return
}
if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})); err != nil {
logger.Println("error:", err)
return
Expand Down
10 changes: 4 additions & 6 deletions interceptor_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (

func TestClientStreamErrors(t *testing.T) {
t.Parallel()
_, err := pingv1connect.NewPingServiceClient(http.DefaultClient, "INVALID_URL", connect.WithGRPC())
_, err := pingv1connect.NewPingServiceClient(http.DefaultClient, "INVALID_URL", connect.WithGRPC()).Ping(context.Background(), nil)
assert.NotNil(t, err)
assert.Match(t, err.Error(), "missing scheme")
// We don't even get to calling methods on the client, so there's no question
// of interceptors running. Once we're calling methods on the client, all
// errors are visible to interceptors.
Expand Down Expand Up @@ -180,17 +181,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()

client, err := pingv1connect.NewPingServiceClient(
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithGRPC(),
clientOnion,
)
_, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10}))
assert.Nil(t, err)

_, err = client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10}))
assert.Nil(t, err)

_, err = client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10}))
assert.Nil(t, err)
}
Expand Down
Loading