diff --git a/grpcweb/client.go b/grpcweb/client.go index abc1e263..0a3eae7b 100644 --- a/grpcweb/client.go +++ b/grpcweb/client.go @@ -66,10 +66,11 @@ func (c Client) RPCCall(ctx context.Context, method string, req []byte, opts ... errChan <- io.EOF // Success! } } - err := invoke(ctx, c.host, c.service, method, req, onMsg, onEnd, opts...) + cancel, err := invoke(ctx, c.host, c.service, method, req, onMsg, onEnd, opts...) if err != nil { return nil, err } + defer cancel() select { case err := <-errChan: diff --git a/grpcweb/invoke.go b/grpcweb/invoke.go index a6995640..f962416c 100644 --- a/grpcweb/invoke.go +++ b/grpcweb/invoke.go @@ -12,8 +12,18 @@ import ( ) // Invoke populates the necessary JS structures and performs the gRPC-web call. -// It attempts to catch any JS errors thrown. -func invoke(ctx context.Context, host, service, method string, req []byte, onMsg onMessageFunc, onEnd onEndFunc, opts ...CallOption) (err error) { +// It attempts to catch any JS errors thrown. It returns a function that can +// be used to cancel the request. +func invoke( + ctx context.Context, + host, + service, + method string, + req []byte, + onMsg onMessageFunc, + onEnd onEndFunc, + opts ...CallOption, +) (cancel context.CancelFunc, err error) { methodDesc := newMethodDescriptor(newService(service), method, newResponseType()) c := &callInfo{} @@ -56,11 +66,16 @@ func invoke(ctx context.Context, host, service, method string, req []byte, onMsg // Perform CallOptions required before call for _, o := range opts { if err := o.before(c); err != nil { - return status.FromError(err) + return nil, status.FromError(err) } } - js.Global.Get("grpc").Call("invoke", methodDesc, props) + request := js.Global.Get("grpc").Call("invoke", methodDesc, props) - return nil + cancelFunc := func() { + // https://github.com/improbable-eng/grpc-web/blob/0ab7201b53447db59d63ff3a95173e565baae10a/ts/src/grpc.ts#L310 + request.Call("abort") + } + + return cancelFunc, nil } diff --git a/grpcweb/stream.go b/grpcweb/stream.go index db137adb..47765807 100644 --- a/grpcweb/stream.go +++ b/grpcweb/stream.go @@ -33,6 +33,7 @@ import ( // reader of messages received on the stream. type streamClient struct { ctx context.Context + cancel context.CancelFunc messages chan []byte errors chan error } @@ -61,11 +62,13 @@ func (c Client) NewServerStream(ctx context.Context, method string, req []byte, srv.errors <- io.EOF } } - err := invoke(ctx, c.host, c.service, method, req, onMsg, onEnd, opts...) + cancel, err := invoke(ctx, c.host, c.service, method, req, onMsg, onEnd, opts...) if err != nil { return nil, err } + srv.cancel = cancel + return srv, nil } @@ -77,6 +80,7 @@ func (s streamClient) RecvMsg() ([]byte, error) { case err := <-s.errors: return nil, err case <-s.ctx.Done(): + s.cancel() return nil, s.ctx.Err() } }