Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnsafeChainInterceptor #244

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion clientgenv2/source_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ func (r *SourceGenerator) OperationArguments(variableDefinitions ast.VariableDef
func (r *SourceGenerator) Type(typeName string) types.Type {
goType, err := r.binder.FindTypeFromName(r.cfg.Models[typeName].Model[0])
if err != nil {
// 実装として正しいtypeNameを渡していれば必ず見つかるはずなのでpanic
panic(fmt.Sprintf("%+v", err))
}

Expand Down
45 changes: 43 additions & 2 deletions clientv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,35 @@ func ChainInterceptor(interceptors ...RequestInterceptor) RequestInterceptor {
}
}

func UnsafeChainInterceptor(interceptors ...RequestInterceptor) RequestInterceptor {
n := len(interceptors)

return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
chainer := func(currentInter RequestInterceptor, currentFunc RequestInterceptorFunc) RequestInterceptorFunc {
return func(currentCtx context.Context, currentReq *http.Request, currentGqlInfo *GQLRequestInfo, currentRes any) error {
return currentInter(currentCtx, currentReq, currentGqlInfo, currentRes, func(nextCtx context.Context, nextReq *http.Request, nextGqlInfo *GQLRequestInfo, nextRes any) error {
return currentFunc(nextCtx, nextReq, nextGqlInfo, nextRes)
})
}
}

chainedHandler := next
for i := n - 1; i >= 0; i-- {
chainedHandler = chainer(interceptors[i], chainedHandler)
}

return chainedHandler(ctx, req, gqlInfo, res)
}
}

// Client is the http client wrapper
type Client struct {
Client HttpClient
BaseURL string
RequestInterceptor RequestInterceptor
CustomDo RequestInterceptorFunc
ParseDataWhenErrors bool
IsUnsafeRequestInterceptor bool
}

// Request represents an outgoing GraphQL request
Expand All @@ -91,6 +113,23 @@ func NewClient(client HttpClient, baseURL string, options *Options, interceptors
return c
}

func NewClientWithUnsafeRequestInterceptor(client HttpClient, baseURL string, options *Options, interceptors ...RequestInterceptor) *Client {
c := &Client{
Client: client,
BaseURL: baseURL,
RequestInterceptor: UnsafeChainInterceptor(append([]RequestInterceptor{func(ctx context.Context, requestSet *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
return next(ctx, requestSet, gqlInfo, res)
}}, interceptors...)...),
IsUnsafeRequestInterceptor: true,
}

if options != nil {
c.ParseDataWhenErrors = options.ParseDataAlongWithErrors
}

return c
}

// Options is a struct that holds some client-specific options that can be passed to NewClient.
type Options struct {
// ParseDataAlongWithErrors is a flag that indicates whether the client should try to parse and return the data along with error
Expand Down Expand Up @@ -211,6 +250,9 @@ func (c *Client) Post(ctx context.Context, operationName, query string, respData
}

f := ChainInterceptor(append([]RequestInterceptor{c.RequestInterceptor}, interceptors...)...)
if c.IsUnsafeRequestInterceptor {
f = UnsafeChainInterceptor(append([]RequestInterceptor{c.RequestInterceptor}, interceptors...)...)
}

// if custom do is set, use it instead of the default one
if c.CustomDo != nil {
Expand Down Expand Up @@ -437,8 +479,7 @@ func checkImplements[I any](v reflect.Value) bool {
t := v.Type()
interfaceType := reflect.TypeOf((*I)(nil)).Elem()

// Check if the type implements the interface directly or as a pointer.
return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PtrTo(t).Implements(interfaceType))
return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PointerTo(t).Implements(interfaceType))
}

// encode returns an appropriate encoder function for the provided value.
Expand Down
127 changes: 127 additions & 0 deletions clientv2/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -831,3 +832,129 @@ func TestMarshalJSON(t *testing.T) {
})
}
}

func TestUnsafeChainInterceptor(t *testing.T) {
t.Run("should modify values through interceptors", func(t *testing.T) {
// Prepare test values
originalCtx := context.Background()
originalReq, _ := http.NewRequest("POST", "http://example.com", nil)
originalGqlInfo := &GQLRequestInfo{
Request: &Request{Query: "original"},
}
originalRes := "original"

// First interceptor: Add value to context
interceptor1 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
ctx = context.WithValue(ctx, "key1", "value1")
return next(ctx, req, gqlInfo, res)
}

// Second interceptor: Modify request header
interceptor2 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
req.Header.Set("X-Test", "test-value")
return next(ctx, req, gqlInfo, res)
}

// Third interceptor: Modify GQLInfo and response
interceptor3 := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
gqlInfo.Request.Query = "modified"
return next(ctx, req, gqlInfo, "modified")
}

// Final handler: Verify modified values
finalHandler := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error {
// Verify context
if v := ctx.Value("key1"); v != "value1" {
t.Errorf("context value not propagated, got %v", v)
}

// Verify request header
if v := req.Header.Get("X-Test"); v != "test-value" {
t.Errorf("request header not modified, got %v", v)
}

// Verify GQLInfo
if gqlInfo.Request.Query != "modified" {
t.Errorf("GQLInfo not modified, got %v", gqlInfo.Request.Query)
}

// Verify response
if res != "modified" {
t.Errorf("response not modified, got %v", res)
}

return nil
}

// Create interceptor chain
chain := UnsafeChainInterceptor(interceptor1, interceptor2, interceptor3)

// Execute chain
err := chain(originalCtx, originalReq, originalGqlInfo, originalRes, finalHandler)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})

t.Run("should properly propagate errors", func(t *testing.T) {
expectedError := errors.New("test error")

// Interceptor that returns an error
errorInterceptor := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
return expectedError
}

// Create chain
chain := UnsafeChainInterceptor(errorInterceptor)

// Execute chain
err := chain(
context.Background(),
&http.Request{},
&GQLRequestInfo{},
nil,
func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error {
return nil
},
)

if err != expectedError {
t.Errorf("expected error %v, got %v", expectedError, err)
}
})

t.Run("should execute interceptors in correct order", func(t *testing.T) {
var order []int

// Create interceptors that record execution order
makeInterceptor := func(id int) RequestInterceptor {
return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error {
order = append(order, id)
err := next(ctx, req, gqlInfo, res)
order = append(order, -id) // Record return order as well
return err
}
}

// Create chain
chain := UnsafeChainInterceptor(makeInterceptor(1), makeInterceptor(2), makeInterceptor(3))

// Execute chain
_ = chain(
context.Background(),
&http.Request{},
&GQLRequestInfo{},
nil,
func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error {
order = append(order, 0) // Record execution of final handler
return nil
},
)

// Expected execution order: 1 -> 2 -> 3 -> 0 -> -3 -> -2 -> -1
expected := []int{1, 2, 3, 0, -3, -2, -1}
if !reflect.DeepEqual(order, expected) {
t.Errorf("unexpected execution order\nexpected: %v\ngot: %v", expected, order)
}
})
}