Skip to content

Commit

Permalink
Add UnsafeChainInterceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
Yamashou committed Nov 30, 2024
1 parent 4a0fc2f commit 4818563
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 3 deletions.
1 change: 0 additions & 1 deletion clientgenv2/source_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ func (r *SourceGenerator) OperationArguments(variableDefinitions ast.VariableDef
func (r *SourceGenerator) Type(typeName string) types.Type {
goType, err := r.binder.FindTypeFromName(r.cfg.Models[typeName].Model[0])
if err != nil {
// 実装として正しいtypeNameを渡していれば必ず見つかるはずなのでpanic
panic(fmt.Sprintf("%+v", err))
}

Expand Down
45 changes: 43 additions & 2 deletions clientv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,35 @@ func ChainInterceptor(interceptors ...RequestInterceptor) RequestInterceptor {
}
}

func UnsafeChainInterceptor(interceptors ...RequestInterceptor) RequestInterceptor {
n := len(interceptors)

return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
chainer := func(currentInter RequestInterceptor, currentFunc RequestInterceptorFunc) RequestInterceptorFunc {
return func(currentCtx context.Context, currentReq *http.Request, currentGqlInfo *GQLRequestInfo, currentRes any) error {
return currentInter(currentCtx, currentReq, currentGqlInfo, currentRes, func(nextCtx context.Context, nextReq *http.Request, nextGqlInfo *GQLRequestInfo, nextRes any) error {
return currentFunc(nextCtx, nextReq, nextGqlInfo, nextRes)
})
}
}

chainedHandler := next
for i := n - 1; i >= 0; i-- {
chainedHandler = chainer(interceptors[i], chainedHandler)
}

return chainedHandler(ctx, req, gqlInfo, res)
}
}

// Client is the http client wrapper
type Client struct {
Client HttpClient
BaseURL string
RequestInterceptor RequestInterceptor
CustomDo RequestInterceptorFunc
ParseDataWhenErrors bool
IsUnsafeRequestInterceptor bool
}

// Request represents an outgoing GraphQL request
Expand All @@ -91,6 +113,23 @@ func NewClient(client HttpClient, baseURL string, options *Options, interceptors
return c
}

func NewClientWithUnsafeRequestInterceptor(client HttpClient, baseURL string, options *Options, interceptors ...RequestInterceptor) *Client {
c := &Client{
Client: client,
BaseURL: baseURL,
RequestInterceptor: UnsafeChainInterceptor(append([]RequestInterceptor{func(ctx context.Context, requestSet *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
return next(ctx, requestSet, gqlInfo, res)
}}, interceptors...)...),
IsUnsafeRequestInterceptor: true,
}

if options != nil {
c.ParseDataWhenErrors = options.ParseDataAlongWithErrors
}

return c
}

// Options is a struct that holds some client-specific options that can be passed to NewClient.
type Options struct {
// ParseDataAlongWithErrors is a flag that indicates whether the client should try to parse and return the data along with error
Expand Down Expand Up @@ -211,6 +250,9 @@ func (c *Client) Post(ctx context.Context, operationName, query string, respData
}

f := ChainInterceptor(append([]RequestInterceptor{c.RequestInterceptor}, interceptors...)...)
if c.IsUnsafeRequestInterceptor {
f = UnsafeChainInterceptor(append([]RequestInterceptor{c.RequestInterceptor}, interceptors...)...)
}

// if custom do is set, use it instead of the default one
if c.CustomDo != nil {
Expand Down Expand Up @@ -437,8 +479,7 @@ func checkImplements[I any](v reflect.Value) bool {
t := v.Type()
interfaceType := reflect.TypeOf((*I)(nil)).Elem()

// Check if the type implements the interface directly or as a pointer.
return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PtrTo(t).Implements(interfaceType))
return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PointerTo(t).Implements(interfaceType))
}

// encode returns an appropriate encoder function for the provided value.
Expand Down
127 changes: 127 additions & 0 deletions clientv2/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -831,3 +832,129 @@ func TestMarshalJSON(t *testing.T) {
})
}
}

func TestUnsafeChainInterceptor(t *testing.T) {
t.Run("should modify values through interceptors", func(t *testing.T) {
// Prepare test values
originalCtx := context.Background()
originalReq, _ := http.NewRequest("POST", "http://example.com", nil)
originalGqlInfo := &GQLRequestInfo{
Request: &Request{Query: "original"},
}
originalRes := "original"

// First interceptor: Add value to context
interceptor1 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
ctx = context.WithValue(ctx, "key1", "value1")
return next(ctx, req, gqlInfo, res)
}

// Second interceptor: Modify request header
interceptor2 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
req.Header.Set("X-Test", "test-value")
return next(ctx, req, gqlInfo, res)
}

// Third interceptor: Modify GQLInfo and response
interceptor3 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
gqlInfo.Request.Query = "modified"
return next(ctx, req, gqlInfo, "modified")
}

// Final handler: Verify modified values
finalHandler := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error {
// Verify context
if v := ctx.Value("key1"); v != "value1" {
t.Errorf("context value not propagated, got %v", v)
}

// Verify request header
if v := req.Header.Get("X-Test"); v != "test-value" {
t.Errorf("request header not modified, got %v", v)
}

// Verify GQLInfo
if gqlInfo.Request.Query != "modified" {
t.Errorf("GQLInfo not modified, got %v", gqlInfo.Request.Query)
}

// Verify response
if res != "modified" {
t.Errorf("response not modified, got %v", res)
}

return nil
}

// Create interceptor chain
chain := UnsafeChainInterceptor(interceptor1, interceptor2, interceptor3)

// Execute chain
err := chain(originalCtx, originalReq, originalGqlInfo, originalRes, finalHandler)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})

t.Run("should properly propagate errors", func(t *testing.T) {
expectedError := errors.New("test error")

// Interceptor that returns an error
errorInterceptor := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
return expectedError
}

// Create chain
chain := UnsafeChainInterceptor(errorInterceptor)

// Execute chain
err := chain(
context.Background(),
&http.Request{},
&GQLRequestInfo{},
nil,
func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error {
return nil
},
)

if err != expectedError {
t.Errorf("expected error %v, got %v", expectedError, err)
}
})

t.Run("should execute interceptors in correct order", func(t *testing.T) {
var order []int

// Create interceptors that record execution order
makeInterceptor := func(id int) RequestInterceptor {
return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
order = append(order, id)
err := next(ctx, req, gqlInfo, res)
order = append(order, -id) // Record return order as well
return err
}
}

// Create chain
chain := UnsafeChainInterceptor(makeInterceptor(1), makeInterceptor(2), makeInterceptor(3))

// Execute chain
_ = chain(
context.Background(),
&http.Request{},
&GQLRequestInfo{},
nil,
func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error {
order = append(order, 0) // Record execution of final handler
return nil
},
)

// Expected execution order: 1 -> 2 -> 3 -> 0 -> -3 -> -2 -> -1
expected := []int{1, 2, 3, 0, -3, -2, -1}
if !reflect.DeepEqual(order, expected) {
t.Errorf("unexpected execution order\nexpected: %v\ngot: %v", expected, order)
}
})
}

0 comments on commit 4818563

Please sign in to comment.