Skip to content

Commit

Permalink
interceptor: new APIs for chaining client interceptors. (#2696)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiranFang authored and menghanl committed Apr 15, 2019
1 parent a9de79b commit 776edd3
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 3 deletions.
204 changes: 203 additions & 1 deletion call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ type server struct {
conns map[transport.ServerTransport]bool
}

type ctxKey string

func newTestServer() *server {
return &server{startedErr: make(chan error, 1)}
}
Expand Down Expand Up @@ -202,17 +204,217 @@ func (s *server) stop() {
}

func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
return setUpWithOptions(t, port, maxStreams)
}

func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
server := newTestServer()
go server.start(t, port, maxStreams)
server.wait(t, 2*time.Second)
addr := "localhost:" + server.port
cc, err := Dial(addr, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
cc, err := Dial(addr, dopts...)
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
return server, cc
}

func (s) TestUnaryClientInterceptor(t *testing.T) {
parentKey := ctxKey("parentKey")

interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
if ctx.Value(parentKey) == nil {
t.Fatalf("interceptor should have %v in context", parentKey)
}
return invoker(ctx, method, req, reply, cc, opts...)
}

server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
defer func() {
cc.Close()
server.stop()
}()

var reply string
ctx := context.Background()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}

func (s) TestChainUnaryClientInterceptor(t *testing.T) {
var (
parentKey = ctxKey("parentKey")
firstIntKey = ctxKey("firstIntKey")
secondIntKey = ctxKey("secondIntKey")
)

firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
if ctx.Value(parentKey) == nil {
t.Fatalf("first interceptor should have %v in context", parentKey)
}
if ctx.Value(firstIntKey) != nil {
t.Fatalf("first interceptor should not have %v in context", firstIntKey)
}
if ctx.Value(secondIntKey) != nil {
t.Fatalf("first interceptor should not have %v in context", secondIntKey)
}
firstCtx := context.WithValue(ctx, firstIntKey, 1)
err := invoker(firstCtx, method, req, reply, cc, opts...)
*(reply.(*string)) += "1"
return err
}

secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
if ctx.Value(parentKey) == nil {
t.Fatalf("second interceptor should have %v in context", parentKey)
}
if ctx.Value(firstIntKey) == nil {
t.Fatalf("second interceptor should have %v in context", firstIntKey)
}
if ctx.Value(secondIntKey) != nil {
t.Fatalf("second interceptor should not have %v in context", secondIntKey)
}
secondCtx := context.WithValue(ctx, secondIntKey, 2)
err := invoker(secondCtx, method, req, reply, cc, opts...)
*(reply.(*string)) += "2"
return err
}

lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
if ctx.Value(parentKey) == nil {
t.Fatalf("last interceptor should have %v in context", parentKey)
}
if ctx.Value(firstIntKey) == nil {
t.Fatalf("last interceptor should have %v in context", firstIntKey)
}
if ctx.Value(secondIntKey) == nil {
t.Fatalf("last interceptor should have %v in context", secondIntKey)
}
err := invoker(ctx, method, req, reply, cc, opts...)
*(reply.(*string)) += "3"
return err
}

server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
defer func() {
cc.Close()
server.stop()
}()

var reply string
ctx := context.Background()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}

func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
var (
parentKey = ctxKey("parentKey")
baseIntKey = ctxKey("baseIntKey")
)

baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
if ctx.Value(parentKey) == nil {
t.Fatalf("base interceptor should have %v in context", parentKey)
}
if ctx.Value(baseIntKey) != nil {
t.Fatalf("base interceptor should not have %v in context", baseIntKey)
}
baseCtx := context.WithValue(ctx, baseIntKey, 1)
return invoker(baseCtx, method, req, reply, cc, opts...)
}

chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
if ctx.Value(parentKey) == nil {
t.Fatalf("chain interceptor should have %v in context", parentKey)
}
if ctx.Value(baseIntKey) == nil {
t.Fatalf("chain interceptor should have %v in context", baseIntKey)
}
return invoker(ctx, method, req, reply, cc, opts...)
}

server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
defer func() {
cc.Close()
server.stop()
}()

var reply string
ctx := context.Background()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}

func (s) TestChainStreamClientInterceptor(t *testing.T) {
var (
parentKey = ctxKey("parentKey")
firstIntKey = ctxKey("firstIntKey")
secondIntKey = ctxKey("secondIntKey")
)

firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
if ctx.Value(parentKey) == nil {
t.Fatalf("first interceptor should have %v in context", parentKey)
}
if ctx.Value(firstIntKey) != nil {
t.Fatalf("first interceptor should not have %v in context", firstIntKey)
}
if ctx.Value(secondIntKey) != nil {
t.Fatalf("first interceptor should not have %v in context", secondIntKey)
}
firstCtx := context.WithValue(ctx, firstIntKey, 1)
return streamer(firstCtx, desc, cc, method, opts...)
}

secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
if ctx.Value(parentKey) == nil {
t.Fatalf("second interceptor should have %v in context", parentKey)
}
if ctx.Value(firstIntKey) == nil {
t.Fatalf("second interceptor should have %v in context", firstIntKey)
}
if ctx.Value(secondIntKey) != nil {
t.Fatalf("second interceptor should not have %v in context", secondIntKey)
}
secondCtx := context.WithValue(ctx, secondIntKey, 2)
return streamer(secondCtx, desc, cc, method, opts...)
}

lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
if ctx.Value(parentKey) == nil {
t.Fatalf("last interceptor should have %v in context", parentKey)
}
if ctx.Value(firstIntKey) == nil {
t.Fatalf("last interceptor should have %v in context", firstIntKey)
}
if ctx.Value(secondIntKey) == nil {
t.Fatalf("last interceptor should have %v in context", secondIntKey)
}
return streamer(ctx, desc, cc, method, opts...)
}

server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
defer func() {
cc.Close()
server.stop()
}()

ctx := context.Background()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
if err != nil {
t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
}
}

func (s) TestInvoke(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
Expand Down
65 changes: 65 additions & 0 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
opt.apply(&cc.dopts)
}

chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc)

defer func() {
if err != nil {
cc.Close()
Expand Down Expand Up @@ -327,6 +330,68 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
return cc, nil
}

// chainUnaryClientInterceptors chains all unary client interceptors into one.
func chainUnaryClientInterceptors(cc *ClientConn) {
interceptors := cc.dopts.chainUnaryInts
// Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
// be executed before any other chained interceptors.
if cc.dopts.unaryInt != nil {
interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
}
var chainedInt UnaryClientInterceptor
if len(interceptors) == 0 {
chainedInt = nil
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
}
}
cc.dopts.unaryInt = chainedInt
}

// getChainUnaryInvoker recursively generate the chained unary invoker.
func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker {
if curr == len(interceptors)-1 {
return finalInvoker
}
return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...)
}
}

// chainStreamClientInterceptors chains all stream client interceptors into one.
func chainStreamClientInterceptors(cc *ClientConn) {
interceptors := cc.dopts.chainStreamInts
// Prepend dopts.streamInt to the chaining interceptors if it exists, since streamInt will
// be executed before any other chained interceptors.
if cc.dopts.streamInt != nil {
interceptors = append([]StreamClientInterceptor{cc.dopts.streamInt}, interceptors...)
}
var chainedInt StreamClientInterceptor
if len(interceptors) == 0 {
chainedInt = nil
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
return interceptors[0](ctx, desc, cc, method, getChainStreamer(interceptors, 0, streamer), opts...)
}
}
cc.dopts.streamInt = chainedInt
}

// getChainStreamer recursively generate the chained client stream constructor.
func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStreamer Streamer) Streamer {
if curr == len(interceptors)-1 {
return finalStreamer
}
return func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
return interceptors[curr+1](ctx, desc, cc, method, getChainStreamer(interceptors, curr+1, finalStreamer), opts...)
}
}

// connectivityStateManager keeps the connectivity.State of ClientConn.
// This struct will eventually be exported so the balancers can access it.
type connectivityStateManager struct {
Expand Down
30 changes: 28 additions & 2 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ import (
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor

chainUnaryInts []UnaryClientInterceptor
chainStreamInts []StreamClientInterceptor

cp Compressor
dc Decompressor
bs backoff.Strategy
Expand Down Expand Up @@ -414,6 +418,17 @@ func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
})
}

// WithChainUnaryInterceptor returns a DialOption that specifies the chained
// interceptor for unary RPCs. The first interceptor will be the outer most,
// while the last interceptor will be the inner most wrapper around the real call.
// All interceptors added by this method will be chained, and the interceptor
// defined by WithUnaryInterceptor will always be prepended to the chain.
func WithChainUnaryInterceptor(interceptors ...UnaryClientInterceptor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
})
}

// WithStreamInterceptor returns a DialOption that specifies the interceptor for
// streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
Expand All @@ -422,6 +437,17 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
})
}

// WithChainStreamInterceptor returns a DialOption that specifies the chained
// interceptor for unary RPCs. The first interceptor will be the outer most,
// while the last interceptor will be the inner most wrapper around the real call.
// All interceptors added by this method will be chained, and the interceptor
// defined by WithStreamInterceptor will always be prepended to the chain.
func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.chainStreamInts = append(o.chainStreamInts, interceptors...)
})
}

// WithAuthority returns a DialOption that specifies the value to be used as the
// :authority pseudo-header. This value only works with WithInsecure and has no
// effect if TransportCredentials are present.
Expand Down

0 comments on commit 776edd3

Please sign in to comment.