From 776edd3ba1f9c75bd87e7131639d1815df428aa1 Mon Sep 17 00:00:00 2001 From: Weiran Fang <8175562+WeiranFang@users.noreply.github.com> Date: Mon, 15 Apr 2019 11:13:34 -0700 Subject: [PATCH] interceptor: new APIs for chaining client interceptors. (#2696) --- call_test.go | 204 ++++++++++++++++++++++++++++++++++++++++++++++++- clientconn.go | 65 ++++++++++++++++ dialoptions.go | 30 +++++++- 3 files changed, 296 insertions(+), 3 deletions(-) diff --git a/call_test.go b/call_test.go index a51108d50ebe..78760ba5297a 100644 --- a/call_test.go +++ b/call_test.go @@ -123,6 +123,8 @@ type server struct { conns map[transport.ServerTransport]bool } +type ctxKey string + func newTestServer() *server { return &server{startedErr: make(chan error, 1)} } @@ -202,17 +204,217 @@ func (s *server) stop() { } func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) { + return setUpWithOptions(t, port, maxStreams) +} + +func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) { server := newTestServer() go server.start(t, port, maxStreams) server.wait(t, 2*time.Second) addr := "localhost:" + server.port - cc, err := Dial(addr, WithBlock(), WithInsecure(), WithCodec(testCodec{})) + dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial(addr, dopts...) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } return server, cc } +func (s) TestUnaryClientInterceptor(t *testing.T) { + parentKey := ctxKey("parentKey") + + interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + if ctx.Value(parentKey) == nil { + t.Fatalf("interceptor should have %v in context", parentKey) + } + return invoker(ctx, method, req, reply, cc, opts...) + } + + server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor)) + defer func() { + cc.Close() + server.stop() + }() + + var reply string + ctx := context.Background() + parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) + if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) + } +} + +func (s) TestChainUnaryClientInterceptor(t *testing.T) { + var ( + parentKey = ctxKey("parentKey") + firstIntKey = ctxKey("firstIntKey") + secondIntKey = ctxKey("secondIntKey") + ) + + firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + if ctx.Value(parentKey) == nil { + t.Fatalf("first interceptor should have %v in context", parentKey) + } + if ctx.Value(firstIntKey) != nil { + t.Fatalf("first interceptor should not have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) != nil { + t.Fatalf("first interceptor should not have %v in context", secondIntKey) + } + firstCtx := context.WithValue(ctx, firstIntKey, 1) + err := invoker(firstCtx, method, req, reply, cc, opts...) + *(reply.(*string)) += "1" + return err + } + + secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + if ctx.Value(parentKey) == nil { + t.Fatalf("second interceptor should have %v in context", parentKey) + } + if ctx.Value(firstIntKey) == nil { + t.Fatalf("second interceptor should have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) != nil { + t.Fatalf("second interceptor should not have %v in context", secondIntKey) + } + secondCtx := context.WithValue(ctx, secondIntKey, 2) + err := invoker(secondCtx, method, req, reply, cc, opts...) + *(reply.(*string)) += "2" + return err + } + + lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + if ctx.Value(parentKey) == nil { + t.Fatalf("last interceptor should have %v in context", parentKey) + } + if ctx.Value(firstIntKey) == nil { + t.Fatalf("last interceptor should have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) == nil { + t.Fatalf("last interceptor should have %v in context", secondIntKey) + } + err := invoker(ctx, method, req, reply, cc, opts...) + *(reply.(*string)) += "3" + return err + } + + server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt)) + defer func() { + cc.Close() + server.stop() + }() + + var reply string + ctx := context.Background() + parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) + if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) + } +} + +func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) { + var ( + parentKey = ctxKey("parentKey") + baseIntKey = ctxKey("baseIntKey") + ) + + baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + if ctx.Value(parentKey) == nil { + t.Fatalf("base interceptor should have %v in context", parentKey) + } + if ctx.Value(baseIntKey) != nil { + t.Fatalf("base interceptor should not have %v in context", baseIntKey) + } + baseCtx := context.WithValue(ctx, baseIntKey, 1) + return invoker(baseCtx, method, req, reply, cc, opts...) + } + + chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + if ctx.Value(parentKey) == nil { + t.Fatalf("chain interceptor should have %v in context", parentKey) + } + if ctx.Value(baseIntKey) == nil { + t.Fatalf("chain interceptor should have %v in context", baseIntKey) + } + return invoker(ctx, method, req, reply, cc, opts...) + } + + server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt)) + defer func() { + cc.Close() + server.stop() + }() + + var reply string + ctx := context.Background() + parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) + if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) + } +} + +func (s) TestChainStreamClientInterceptor(t *testing.T) { + var ( + parentKey = ctxKey("parentKey") + firstIntKey = ctxKey("firstIntKey") + secondIntKey = ctxKey("secondIntKey") + ) + + firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { + if ctx.Value(parentKey) == nil { + t.Fatalf("first interceptor should have %v in context", parentKey) + } + if ctx.Value(firstIntKey) != nil { + t.Fatalf("first interceptor should not have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) != nil { + t.Fatalf("first interceptor should not have %v in context", secondIntKey) + } + firstCtx := context.WithValue(ctx, firstIntKey, 1) + return streamer(firstCtx, desc, cc, method, opts...) + } + + secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { + if ctx.Value(parentKey) == nil { + t.Fatalf("second interceptor should have %v in context", parentKey) + } + if ctx.Value(firstIntKey) == nil { + t.Fatalf("second interceptor should have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) != nil { + t.Fatalf("second interceptor should not have %v in context", secondIntKey) + } + secondCtx := context.WithValue(ctx, secondIntKey, 2) + return streamer(secondCtx, desc, cc, method, opts...) + } + + lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { + if ctx.Value(parentKey) == nil { + t.Fatalf("last interceptor should have %v in context", parentKey) + } + if ctx.Value(firstIntKey) == nil { + t.Fatalf("last interceptor should have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) == nil { + t.Fatalf("last interceptor should have %v in context", secondIntKey) + } + return streamer(ctx, desc, cc, method, opts...) + } + + server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt)) + defer func() { + cc.Close() + server.stop() + }() + + ctx := context.Background() + parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) + _, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar") + if err != nil { + t.Fatalf("grpc.NewStream(_, _, _) = %v, want ", err) + } +} + func (s) TestInvoke(t *testing.T) { server, cc := setUp(t, 0, math.MaxUint32) var reply string diff --git a/clientconn.go b/clientconn.go index bd2d2b317798..8255caef929b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -137,6 +137,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * opt.apply(&cc.dopts) } + chainUnaryClientInterceptors(cc) + chainStreamClientInterceptors(cc) + defer func() { if err != nil { cc.Close() @@ -327,6 +330,68 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * return cc, nil } +// chainUnaryClientInterceptors chains all unary client interceptors into one. +func chainUnaryClientInterceptors(cc *ClientConn) { + interceptors := cc.dopts.chainUnaryInts + // Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will + // be executed before any other chained interceptors. + if cc.dopts.unaryInt != nil { + interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...) + } + var chainedInt UnaryClientInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { + return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...) + } + } + cc.dopts.unaryInt = chainedInt +} + +// getChainUnaryInvoker recursively generate the chained unary invoker. +func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker { + if curr == len(interceptors)-1 { + return finalInvoker + } + return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error { + return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...) + } +} + +// chainStreamClientInterceptors chains all stream client interceptors into one. +func chainStreamClientInterceptors(cc *ClientConn) { + interceptors := cc.dopts.chainStreamInts + // Prepend dopts.streamInt to the chaining interceptors if it exists, since streamInt will + // be executed before any other chained interceptors. + if cc.dopts.streamInt != nil { + interceptors = append([]StreamClientInterceptor{cc.dopts.streamInt}, interceptors...) + } + var chainedInt StreamClientInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { + return interceptors[0](ctx, desc, cc, method, getChainStreamer(interceptors, 0, streamer), opts...) + } + } + cc.dopts.streamInt = chainedInt +} + +// getChainStreamer recursively generate the chained client stream constructor. +func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStreamer Streamer) Streamer { + if curr == len(interceptors)-1 { + return finalStreamer + } + return func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { + return interceptors[curr+1](ctx, desc, cc, method, getChainStreamer(interceptors, curr+1, finalStreamer), opts...) + } +} + // connectivityStateManager keeps the connectivity.State of ClientConn. // This struct will eventually be exported so the balancers can access it. type connectivityStateManager struct { diff --git a/dialoptions.go b/dialoptions.go index e114fecbb7b4..1fdc619d1730 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -39,8 +39,12 @@ import ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - unaryInt UnaryClientInterceptor - streamInt StreamClientInterceptor + unaryInt UnaryClientInterceptor + streamInt StreamClientInterceptor + + chainUnaryInts []UnaryClientInterceptor + chainStreamInts []StreamClientInterceptor + cp Compressor dc Decompressor bs backoff.Strategy @@ -414,6 +418,17 @@ func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { }) } +// WithChainUnaryInterceptor returns a DialOption that specifies the chained +// interceptor for unary RPCs. The first interceptor will be the outer most, +// while the last interceptor will be the inner most wrapper around the real call. +// All interceptors added by this method will be chained, and the interceptor +// defined by WithUnaryInterceptor will always be prepended to the chain. +func WithChainUnaryInterceptor(interceptors ...UnaryClientInterceptor) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.chainUnaryInts = append(o.chainUnaryInts, interceptors...) + }) +} + // WithStreamInterceptor returns a DialOption that specifies the interceptor for // streaming RPCs. func WithStreamInterceptor(f StreamClientInterceptor) DialOption { @@ -422,6 +437,17 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption { }) } +// WithChainStreamInterceptor returns a DialOption that specifies the chained +// interceptor for unary RPCs. The first interceptor will be the outer most, +// while the last interceptor will be the inner most wrapper around the real call. +// All interceptors added by this method will be chained, and the interceptor +// defined by WithStreamInterceptor will always be prepended to the chain. +func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.chainStreamInts = append(o.chainStreamInts, interceptors...) + }) +} + // WithAuthority returns a DialOption that specifies the value to be used as the // :authority pseudo-header. This value only works with WithInsecure and has no // effect if TransportCredentials are present.