From 48185635ad7dd618d5c1d87076eea78e3ba9f0d6 Mon Sep 17 00:00:00 2001 From: Yamashou <1230124fw@gmail.com> Date: Sat, 30 Nov 2024 22:51:53 +0900 Subject: [PATCH] Add UnsafeChainInterceptor #241 --- clientgenv2/source_generator.go | 1 - clientv2/client.go | 45 ++++++++++- clientv2/client_test.go | 127 ++++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 3 deletions(-) diff --git a/clientgenv2/source_generator.go b/clientgenv2/source_generator.go index 37cca75..e2ade39 100644 --- a/clientgenv2/source_generator.go +++ b/clientgenv2/source_generator.go @@ -295,7 +295,6 @@ func (r *SourceGenerator) OperationArguments(variableDefinitions ast.VariableDef func (r *SourceGenerator) Type(typeName string) types.Type { goType, err := r.binder.FindTypeFromName(r.cfg.Models[typeName].Model[0]) if err != nil { - // 実装として正しいtypeNameを渡していれば必ず見つかるはずなのでpanic panic(fmt.Sprintf("%+v", err)) } diff --git a/clientv2/client.go b/clientv2/client.go index fb441d0..5be9655 100644 --- a/clientv2/client.go +++ b/clientv2/client.go @@ -58,6 +58,27 @@ func ChainInterceptor(interceptors ...RequestInterceptor) RequestInterceptor { } } +func UnsafeChainInterceptor(interceptors ...RequestInterceptor) RequestInterceptor { + n := len(interceptors) + + return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + chainer := func(currentInter RequestInterceptor, currentFunc RequestInterceptorFunc) RequestInterceptorFunc { + return func(currentCtx context.Context, currentReq *http.Request, currentGqlInfo *GQLRequestInfo, currentRes any) error { + return currentInter(currentCtx, currentReq, currentGqlInfo, currentRes, func(nextCtx context.Context, nextReq *http.Request, nextGqlInfo *GQLRequestInfo, nextRes any) error { + return currentFunc(nextCtx, nextReq, nextGqlInfo, nextRes) + }) + } + } + + chainedHandler := next + for i := n - 1; i >= 0; i-- { + chainedHandler = chainer(interceptors[i], chainedHandler) + } + + return chainedHandler(ctx, req, gqlInfo, res) + } +} + // Client is the http client wrapper type Client struct { Client HttpClient @@ -65,6 +86,7 @@ type Client struct { RequestInterceptor RequestInterceptor CustomDo RequestInterceptorFunc ParseDataWhenErrors bool + IsUnsafeRequestInterceptor bool } // Request represents an outgoing GraphQL request @@ -91,6 +113,23 @@ func NewClient(client HttpClient, baseURL string, options *Options, interceptors return c } +func NewClientWithUnsafeRequestInterceptor(client HttpClient, baseURL string, options *Options, interceptors ...RequestInterceptor) *Client { + c := &Client{ + Client: client, + BaseURL: baseURL, + RequestInterceptor: UnsafeChainInterceptor(append([]RequestInterceptor{func(ctx context.Context, requestSet *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + return next(ctx, requestSet, gqlInfo, res) + }}, interceptors...)...), + IsUnsafeRequestInterceptor: true, + } + + if options != nil { + c.ParseDataWhenErrors = options.ParseDataAlongWithErrors + } + + return c +} + // Options is a struct that holds some client-specific options that can be passed to NewClient. type Options struct { // ParseDataAlongWithErrors is a flag that indicates whether the client should try to parse and return the data along with error @@ -211,6 +250,9 @@ func (c *Client) Post(ctx context.Context, operationName, query string, respData } f := ChainInterceptor(append([]RequestInterceptor{c.RequestInterceptor}, interceptors...)...) + if c.IsUnsafeRequestInterceptor { + f = UnsafeChainInterceptor(append([]RequestInterceptor{c.RequestInterceptor}, interceptors...)...) + } // if custom do is set, use it instead of the default one if c.CustomDo != nil { @@ -437,8 +479,7 @@ func checkImplements[I any](v reflect.Value) bool { t := v.Type() interfaceType := reflect.TypeOf((*I)(nil)).Elem() - // Check if the type implements the interface directly or as a pointer. - return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PtrTo(t).Implements(interfaceType)) + return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PointerTo(t).Implements(interfaceType)) } // encode returns an appropriate encoder function for the provided value. diff --git a/clientv2/client_test.go b/clientv2/client_test.go index 0fb4e06..7ca68cf 100644 --- a/clientv2/client_test.go +++ b/clientv2/client_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "reflect" "strconv" "testing" "time" @@ -831,3 +832,129 @@ func TestMarshalJSON(t *testing.T) { }) } } + +func TestUnsafeChainInterceptor(t *testing.T) { + t.Run("should modify values through interceptors", func(t *testing.T) { + // Prepare test values + originalCtx := context.Background() + originalReq, _ := http.NewRequest("POST", "http://example.com", nil) + originalGqlInfo := &GQLRequestInfo{ + Request: &Request{Query: "original"}, + } + originalRes := "original" + + // First interceptor: Add value to context + interceptor1 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + ctx = context.WithValue(ctx, "key1", "value1") + return next(ctx, req, gqlInfo, res) + } + + // Second interceptor: Modify request header + interceptor2 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + req.Header.Set("X-Test", "test-value") + return next(ctx, req, gqlInfo, res) + } + + // Third interceptor: Modify GQLInfo and response + interceptor3 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + gqlInfo.Request.Query = "modified" + return next(ctx, req, gqlInfo, "modified") + } + + // Final handler: Verify modified values + finalHandler := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error { + // Verify context + if v := ctx.Value("key1"); v != "value1" { + t.Errorf("context value not propagated, got %v", v) + } + + // Verify request header + if v := req.Header.Get("X-Test"); v != "test-value" { + t.Errorf("request header not modified, got %v", v) + } + + // Verify GQLInfo + if gqlInfo.Request.Query != "modified" { + t.Errorf("GQLInfo not modified, got %v", gqlInfo.Request.Query) + } + + // Verify response + if res != "modified" { + t.Errorf("response not modified, got %v", res) + } + + return nil + } + + // Create interceptor chain + chain := UnsafeChainInterceptor(interceptor1, interceptor2, interceptor3) + + // Execute chain + err := chain(originalCtx, originalReq, originalGqlInfo, originalRes, finalHandler) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("should properly propagate errors", func(t *testing.T) { + expectedError := errors.New("test error") + + // Interceptor that returns an error + errorInterceptor := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + return expectedError + } + + // Create chain + chain := UnsafeChainInterceptor(errorInterceptor) + + // Execute chain + err := chain( + context.Background(), + &http.Request{}, + &GQLRequestInfo{}, + nil, + func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error { + return nil + }, + ) + + if err != expectedError { + t.Errorf("expected error %v, got %v", expectedError, err) + } + }) + + t.Run("should execute interceptors in correct order", func(t *testing.T) { + var order []int + + // Create interceptors that record execution order + makeInterceptor := func(id int) RequestInterceptor { + return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + order = append(order, id) + err := next(ctx, req, gqlInfo, res) + order = append(order, -id) // Record return order as well + return err + } + } + + // Create chain + chain := UnsafeChainInterceptor(makeInterceptor(1), makeInterceptor(2), makeInterceptor(3)) + + // Execute chain + _ = chain( + context.Background(), + &http.Request{}, + &GQLRequestInfo{}, + nil, + func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error { + order = append(order, 0) // Record execution of final handler + return nil + }, + ) + + // Expected execution order: 1 -> 2 -> 3 -> 0 -> -3 -> -2 -> -1 + expected := []int{1, 2, 3, 0, -3, -2, -1} + if !reflect.DeepEqual(order, expected) { + t.Errorf("unexpected execution order\nexpected: %v\ngot: %v", expected, order) + } + }) +}