Skip to content

Commit

Permalink
Cleanup (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
bufdev authored Apr 22, 2022
1 parent 5a13642 commit 01d0833
Show file tree
Hide file tree
Showing 13 changed files with 202 additions and 204 deletions.
40 changes: 20 additions & 20 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,38 +75,38 @@ type ping struct {
}

func BenchmarkREST(b *testing.B) {
handler := func(writer http.ResponseWriter, req *http.Request) {
defer req.Body.Close()
handler := func(writer http.ResponseWriter, request *http.Request) {
defer request.Body.Close()
defer func() {
_, err := io.Copy(io.Discard, req.Body)
_, err := io.Copy(io.Discard, request.Body)
assert.Nil(b, err)
}()
writer.Header().Set("Content-Type", "application/json")
var body io.Reader = req.Body
if req.Header.Get("Content-Encoding") == "gzip" {
gr, err := gzip.NewReader(body)
var body io.Reader = request.Body
if request.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(body)
if err != nil {
b.Fatalf("get gzip reader: %v", err)
}
defer gr.Close()
body = gr
defer gzipReader.Close()
body = gzipReader
}
var out io.Writer = writer
if strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
if strings.Contains(request.Header.Get("Accept-Encoding"), "gzip") {
writer.Header().Set("Content-Encoding", "gzip")
gw := gzip.NewWriter(writer)
defer gw.Close()
out = gw
gzipWriter := gzip.NewWriter(writer)
defer gzipWriter.Close()
out = gzipWriter
}
raw, err := io.ReadAll(body)
if err != nil {
b.Fatalf("read body: %v", err)
}
var pingReq ping
if err := json.Unmarshal(raw, &pingReq); err != nil {
var pingRequest ping
if err := json.Unmarshal(raw, &pingRequest); err != nil {
b.Fatalf("json unmarshal: %v", err)
}
bs, err := json.Marshal(&pingReq)
bs, err := json.Marshal(&pingRequest)
if err != nil {
b.Fatalf("json marshal: %v", err)
}
Expand Down Expand Up @@ -151,18 +151,18 @@ func unaryRESTIteration(b *testing.B, client *http.Client, url string, text stri
request.Header.Set("Content-Encoding", "gzip")
request.Header.Set("Accept-Encoding", "gzip")
request.Header.Set("Content-Type", "application/json")
res, err := client.Do(request)
response, err := client.Do(request)
if err != nil {
b.Fatalf("do request: %v", err)
}
defer func() {
_, err := io.Copy(io.Discard, res.Body)
_, err := io.Copy(io.Discard, response.Body)
assert.Nil(b, err)
}()
if res.StatusCode != http.StatusOK {
b.Fatalf("response status: %v", res.Status)
if response.StatusCode != http.StatusOK {
b.Fatalf("response status: %v", response.Status)
}
uncompressed, err := gzip.NewReader(res.Body)
uncompressed, err := gzip.NewReader(response.Body)
if err != nil {
b.Fatalf("uncompress response: %v", err)
}
Expand Down
76 changes: 34 additions & 42 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,33 @@ import (
// explicitly choose a protocol with either the WithGRPC or WithGRPCWeb
// options.
type Client[Req, Res any] struct {
config *clientConfiguration
config *clientConfig
callUnary func(context.Context, *Request[Req]) (*Response[Res], error)
protocolClient protocolClient
err error
}

// NewClient constructs a new Client.
func NewClient[Req, Res any](
httpClient HTTPClient,
url string,
options ...ClientOption,
) *Client[Req, Res] {
func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...ClientOption) *Client[Req, Res] {
client := &Client[Req, Res]{}
config, err := newClientConfiguration(url, options)
config, err := newClientConfig(url, options)
if err != nil {
client.err = err
return client
}
client.config = config
protocolClient, protocolErr := client.config.Protocol.NewClient(&protocolClientParams{
CompressionName: config.RequestCompressionName,
CompressionPools: newReadOnlyCompressionPools(config.CompressionPools),
Codec: config.Codec,
Protobuf: config.protobuf(),
CompressMinBytes: config.CompressMinBytes,
HTTPClient: httpClient,
URL: url,
BufferPool: config.BufferPool,
})
protocolClient, protocolErr := client.config.Protocol.NewClient(
&protocolClientParams{
CompressionName: config.RequestCompressionName,
CompressionPools: newReadOnlyCompressionPools(config.CompressionPools),
Codec: config.Codec,
Protobuf: config.protobuf(),
CompressMinBytes: config.CompressMinBytes,
HTTPClient: httpClient,
URL: url,
BufferPool: config.BufferPool,
},
)
if protocolErr != nil {
client.err = protocolErr
return client
Expand Down Expand Up @@ -89,8 +87,8 @@ func NewClient[Req, Res any](
}
return response, receiver.Close()
})
if ic := config.Interceptor; ic != nil {
unaryFunc = ic.WrapUnary(unaryFunc)
if interceptor := config.Interceptor; interceptor != nil {
unaryFunc = interceptor.WrapUnary(unaryFunc)
}
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
// To make the specification and RPC headers visible to the full interceptor
Expand All @@ -111,14 +109,11 @@ func NewClient[Req, Res any](
}

// CallUnary calls a request-response procedure.
func (c *Client[Req, Res]) CallUnary(
ctx context.Context,
req *Request[Req],
) (*Response[Res], error) {
func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
if c.err != nil {
return nil, c.err
}
return c.callUnary(ctx, req)
return c.callUnary(ctx, request)
}

// CallClientStream calls a client streaming procedure.
Expand All @@ -131,19 +126,16 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
}

// CallServerStream calls a server streaming procedure.
func (c *Client[Req, Res]) CallServerStream(
ctx context.Context,
req *Request[Req],
) (*ServerStreamForClient[Res], error) {
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
sender, receiver := c.newStream(ctx, StreamTypeServer)
mergeHeaders(sender.Header(), req.header)
mergeHeaders(sender.Header(), request.header)
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
if err := sender.Send(req.Msg); err != nil && !errors.Is(err, io.EOF) {
if err := sender.Send(request.Msg); err != nil && !errors.Is(err, io.EOF) {
_ = sender.Close(err)
_ = receiver.Close()
return nil, err
Expand All @@ -164,20 +156,20 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli
}

func (c *Client[Req, Res]) newStream(ctx context.Context, streamType StreamType) (Sender, Receiver) {
if ic := c.config.Interceptor; ic != nil {
ctx = ic.WrapStreamContext(ctx)
if interceptor := c.config.Interceptor; interceptor != nil {
ctx = interceptor.WrapStreamContext(ctx)
}
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(header)
sender, receiver := c.protocolClient.NewStream(ctx, c.config.newSpecification(streamType), header)
if ic := c.config.Interceptor; ic != nil {
sender = ic.WrapStreamSender(ctx, sender)
receiver = ic.WrapStreamReceiver(ctx, receiver)
if interceptor := c.config.Interceptor; interceptor != nil {
sender = interceptor.WrapStreamSender(ctx, sender)
receiver = interceptor.WrapStreamReceiver(ctx, receiver)
}
return sender, receiver
}

type clientConfiguration struct {
type clientConfig struct {
Protocol protocol
Procedure string
CompressMinBytes int
Expand All @@ -188,9 +180,9 @@ type clientConfiguration struct {
BufferPool *bufferPool
}

func newClientConfiguration(url string, options []ClientOption) (*clientConfiguration, *Error) {
protoPath := extractProtobufPath(url)
config := clientConfiguration{
func newClientConfig(url string, options []ClientOption) (*clientConfig, *Error) {
protoPath := extractProtoPath(url)
config := clientConfig{
Procedure: protoPath,
CompressionPools: make(map[string]*compressionPool),
BufferPool: newBufferPool(),
Expand All @@ -206,7 +198,7 @@ func newClientConfiguration(url string, options []ClientOption) (*clientConfigur
return &config, nil
}

func (c *clientConfiguration) validate() *Error {
func (c *clientConfig) validate() *Error {
if c.Codec == nil || c.Codec.Name() == "" {
return errorf(CodeUnknown, "no codec configured")
}
Expand All @@ -224,14 +216,14 @@ func (c *clientConfiguration) validate() *Error {
return nil
}

func (c *clientConfiguration) protobuf() Codec {
func (c *clientConfig) protobuf() Codec {
if c.Codec.Name() == codecNameProto {
return c.Codec
}
return &protoBinaryCodec{}
}

func (c *clientConfiguration) newSpecification(t StreamType) Specification {
func (c *clientConfig) newSpecification(t StreamType) Specification {
return Specification{
StreamType: t,
Procedure: c.Procedure,
Expand Down
6 changes: 3 additions & 3 deletions client_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ func Example_client() {
examplePingServer.URL(),
connect.WithGRPC(),
)
res, err := client.Ping(
response, err := client.Ping(
context.Background(),
connect.NewRequest(&pingv1.PingRequest{Number: 42}),
)
if err != nil {
logger.Println("error:", err)
return
}
logger.Println("response content-type:", res.Header().Get("Content-Type"))
logger.Println("response message:", res.Msg)
logger.Println("response content-type:", response.Header().Get("Content-Type"))
logger.Println("response message:", response.Msg)

// Output:
// response content-type: application/grpc+proto
Expand Down
14 changes: 7 additions & 7 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header {
// If the server returns an error, Send returns an error that wraps io.EOF.
// Clients should check for case using the standard library's errors.Is and
// unmarshal the error using CloseAndReceive.
func (c *ClientStreamForClient[Req, Res]) Send(msg *Req) error {
func (c *ClientStreamForClient[Req, Res]) Send(request *Req) error {
if c.err != nil {
return c.err
}
return c.sender.Send(msg)
return c.sender.Send(request)
}

// CloseAndReceive closes the send side of the stream and waits for the
Expand All @@ -59,15 +59,15 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err
if err := c.sender.Close(nil); err != nil {
return nil, err
}
res, err := receiveUnaryResponse[Res](c.receiver)
response, err := receiveUnaryResponse[Res](c.receiver)
if err != nil {
_ = c.receiver.Close()
return nil, err
}
if err := c.receiver.Close(); err != nil {
return nil, err
}
return res, nil
return response, nil
}

// ServerStreamForClient is the client's view of a server streaming RPC.
Expand Down Expand Up @@ -169,11 +169,11 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {
if b.err != nil {
return nil, b.err
}
var res Res
if err := b.receiver.Receive(&res); err != nil {
var msg Res
if err := b.receiver.Receive(&msg); err != nil {
return nil, err
}
return &res, nil
return &msg, nil
}

// CloseReceive closes the receive side of the stream.
Expand Down
18 changes: 9 additions & 9 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,31 +123,31 @@ type readOnlyCompressionPools interface {
CommaSeparatedNames() string
}

func newReadOnlyCompressionPools(pools map[string]*compressionPool) readOnlyCompressionPools {
known := make([]string, 0, len(pools))
for name := range pools {
known = append(known, name)
func newReadOnlyCompressionPools(nameToPool map[string]*compressionPool) readOnlyCompressionPools {
knownNames := make([]string, 0, len(nameToPool))
for name := range nameToPool {
knownNames = append(knownNames, name)
}
return &namedCompressionPools{
nameToPools: pools,
commaSeparatedNames: strings.Join(known, ","),
nameToPool: nameToPool,
commaSeparatedNames: strings.Join(knownNames, ","),
}
}

type namedCompressionPools struct {
nameToPools map[string]*compressionPool
nameToPool map[string]*compressionPool
commaSeparatedNames string
}

func (m *namedCompressionPools) Get(name string) *compressionPool {
if name == "" || name == compressionIdentity {
return nil
}
return m.nameToPools[name]
return m.nameToPool[name]
}

func (m *namedCompressionPools) Contains(name string) bool {
_, ok := m.nameToPools[name]
_, ok := m.nameToPool[name]
return ok
}

Expand Down
Loading

0 comments on commit 01d0833

Please sign in to comment.