From 00406a208fc460cc1be5b2a40435902c7c9d4ace Mon Sep 17 00:00:00 2001 From: Yamashou <1230124fw@gmail.com> Date: Tue, 16 Apr 2024 21:16:15 +0900 Subject: [PATCH 1/2] fix --- clientv2/client.go | 60 ++++++++++-- clientv2/client_test.go | 199 ++++++++++++++++++---------------------- 2 files changed, 140 insertions(+), 119 deletions(-) diff --git a/clientv2/client.go b/clientv2/client.go index 4db910c..478ae9a 100644 --- a/clientv2/client.go +++ b/clientv2/client.go @@ -395,12 +395,21 @@ func (c *Client) unmarshal(data []byte, res interface{}) error { } func MarshalJSON(v interface{}) ([]byte, error) { + if v == nil { + return []byte("null"), nil // Directly return "null" for nil interface{} + } + + val := reflect.ValueOf(v) + if !val.IsValid() || (val.Kind() == reflect.Ptr && val.IsNil()) { + return []byte("null"), nil // Return "null" for nil pointer or invalid reflect value + } + encoderFunc := getTypeEncoder(reflect.TypeOf(v)) return encoderFunc(v) } // getTypeEncoder returns an appropriate encoder function for the provided type. -func getTypeEncoder(t reflect.Type) func(interface{}) ([]byte, error) { +func getTypeEncoder(t reflect.Type) func(a any) ([]byte, error) { if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) { return gqlMarshalerEncoder } @@ -523,14 +532,51 @@ func prepareFields(t reflect.Type) []fieldInfo { } func checkMarshalerFields(t reflect.Type) bool { - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - if f.Type.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) { + switch t.Kind() { + case reflect.Ptr: + return checkMarshalerFields(t.Elem()) + + case reflect.Struct: + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if isMarshalerType(f.Type) { + return true + } + // Recursively check for nested structs + if checkMarshalerFields(f.Type) { + return true + } + } + + case reflect.Map: + // Check both key and value types for Marshaler implementation; usually, value type is what matters + keyType, valueType := t.Key(), t.Elem() + if isMarshalerType(valueType) || isMarshalerType(keyType) { return true } - if reflect.PtrTo(f.Type).Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) { + // Recursively check the map value type + if checkMarshalerFields(valueType) { return true } + + case reflect.Slice, reflect.Array: + // Recursively check the element type + return checkMarshalerFields(t.Elem()) + case reflect.Interface, reflect.Invalid, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return false + default: + return false + } + + return false +} + +func isMarshalerType(t reflect.Type) bool { + if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) { + return true + } + if reflect.PtrTo(t).Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) { + return true } return false } @@ -539,7 +585,7 @@ func newStructEncoder(t reflect.Type) func(interface{}) ([]byte, error) { fields := prepareFields(t) marshalerFieldExists := checkMarshalerFields(t) - return func(v interface{}) ([]byte, error) { + return func(v any) ([]byte, error) { // If no field implements the MarshalerGQL interface, use standard JSON marshaling if !marshalerFieldExists { return json.Marshal(v) @@ -591,7 +637,7 @@ func newMapEncoder(t reflect.Type) func(interface{}) ([]byte, error) { if err != nil { return nil, err } - result[keyStr] = json.RawMessage(encodedValue) // Use json.RawMessage to avoid double encoding + result[keyStr] = encodedValue } return json.Marshal(result) diff --git a/clientv2/client_test.go b/clientv2/client_test.go index 8f2c101..2a43925 100644 --- a/clientv2/client_test.go +++ b/clientv2/client_test.go @@ -50,7 +50,7 @@ func TestUnmarshal(t *testing.T) { Line: 6, Column: 4, }}, - Extensions: map[string]interface{}{ + Extensions: map[string]any{ "code": "undefinedField", "typeName": "RepositoryConnection", "fieldName": "nsodes", @@ -80,7 +80,7 @@ func TestUnmarshal(t *testing.T) { Line: 6, Column: 4, }}, - Extensions: map[string]interface{}{ + Extensions: map[string]any{ "code": "undefinedField", "typeName": "RepositoryConnection", "fieldName": "nsodes", @@ -93,7 +93,7 @@ func TestUnmarshal(t *testing.T) { Line: 1, Column: 1, }}, - Extensions: map[string]interface{}{ + Extensions: map[string]any{ "code": "variableNotUsed", "variableName": "languageFirst", }, @@ -105,7 +105,7 @@ func TestUnmarshal(t *testing.T) { Line: 18, Column: 1, }}, - Extensions: map[string]interface{}{ + Extensions: map[string]any{ "code": "useAndDefineFragment", "fragmentName": "LanguageFragment", }, @@ -130,7 +130,7 @@ func TestUnmarshal(t *testing.T) { Line: 6, Column: 4, }}, - Extensions: map[string]interface{}{ + Extensions: map[string]any{ "code": "undefinedField", "typeName": "RepositoryConnection", "fieldName": "nsodes", @@ -155,7 +155,7 @@ func TestUnmarshal(t *testing.T) { Line: 6, Column: 4, }}, - Extensions: map[string]interface{}{ + Extensions: map[string]any{ "code": "undefinedField", "typeName": "RepositoryConnection", "fieldName": "nsodes", @@ -288,7 +288,7 @@ func TestChainInterceptor(t *testing.T) { OperationName: "GQL", }) outputError := fmt.Errorf("some error") - requireContextValue := func(t *testing.T, ctx context.Context, key string, msg ...interface{}) { + requireContextValue := func(t *testing.T, ctx context.Context, key string, msg ...any) { t.Helper() val := ctx.Value(key) require.NotNil(t, val, msg...) @@ -298,7 +298,7 @@ func TestChainInterceptor(t *testing.T) { req, err := http.NewRequestWithContext(parentContext, http.MethodPost, "https://hogehoge/graphql", bytes.NewBufferString(requestMessage)) require.Nil(t, err) - first := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res interface{}, next RequestInterceptorFunc) error { + first := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { requireContextValue(t, ctx, "parent", "first must know the parent context value") wrappedCtx := context.WithValue(ctx, "first", someValue) @@ -306,7 +306,7 @@ func TestChainInterceptor(t *testing.T) { return next(wrappedCtx, req, gqlInfo, res) } - second := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res interface{}, next RequestInterceptorFunc) error { + second := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { requireContextValue(t, ctx, "parent", "second must know the parent context value") requireContextValue(t, ctx, "first", "second must know the first context value") @@ -315,7 +315,7 @@ func TestChainInterceptor(t *testing.T) { return next(wrappedCtx, req, gqlInfo, res) } - invoker := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res interface{}) error { + invoker := func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any) error { requireContextValue(t, ctx, "parent", "invoker must know the parent context value") requireContextValue(t, ctx, "first", "invoker must know the first context value") requireContextValue(t, ctx, "second", "invoker must know the second context value") @@ -372,7 +372,7 @@ func Test_parseMultipartFiles(t *testing.T) { t.Run("no files in vars", func(t *testing.T) { t.Parallel() - vars := map[string]interface{}{ + vars := map[string]any{ "field": "val", "field2": "val2", } @@ -388,7 +388,7 @@ func Test_parseMultipartFiles(t *testing.T) { t.Run("has file in vars", func(t *testing.T) { t.Parallel() - vars := map[string]interface{}{ + vars := map[string]any{ "field": "val", "fieldFile": graphql.Upload{ Filename: "file.txt", @@ -416,7 +416,7 @@ func Test_parseMultipartFiles(t *testing.T) { t.Run("has few files in vars", func(t *testing.T) { t.Parallel() - vars := map[string]interface{}{ + vars := map[string]any{ "field": "val", "fieldFiles": []*graphql.Upload{ { @@ -455,7 +455,7 @@ const ( NumberTwo Number = 2 ) -func (n *Number) UnmarshalGQL(v interface{}) error { +func (n *Number) UnmarshalGQL(v any) error { str, ok := v.(string) if !ok { return fmt.Errorf("enums must be strings") @@ -485,18 +485,11 @@ func (n Number) MarshalGQL(w io.Writer) { fmt.Fprint(w, strconv.Quote(str)) } -func TestMarshalJSON(t *testing.T) { - type Example1 struct { - Name string `json:"name"` - Age int `json:"age"` - } - type Example2 struct { - Name string `json:"name"` - Number Number `json:"number"` - } +func TestMarshalJSONValueType(t *testing.T) { + t.Parallel() testDate := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) type args struct { - v interface{} + v any } tests := []struct { name string @@ -546,7 +539,7 @@ func TestMarshalJSON(t *testing.T) { NumberOne: "ONE", }, }, - want: []byte(`{"1":"ONE"}`), + want: []byte(`{"ONE":"ONE"}`), }, { name: "marshal slice", @@ -556,95 +549,50 @@ func TestMarshalJSON(t *testing.T) { want: []byte(`["ONE","TWO"]`), }, { - name: "marshal normal struct", - args: args{ - v: Example1{ - Name: "John", - Age: 20, - }, - }, - want: []byte(`{"age":20,"name":"John"}`), - }, - { - name: "marshal pointer struct", - args: args{ - v: &Example1{ - Name: "John", - Age: 20, - }, - }, - want: []byte(`{"age":20,"name":"John"}`), - }, - { - name: "marshal nested struct", - args: args{ - v: struct { - Outer struct { - Inner Example1 `json:"inner"` - } `json:"outer"` - }{ - Outer: struct { - Inner Example1 `json:"inner"` - }{ - Inner: Example1{ - Name: "John", - Age: 22, - }, - }, - }, - }, - want: []byte(`{"outer":{"inner":{"age":22,"name":"John"}}}`), - }, - { - name: "marshal nested map", - args: args{ - v: map[string]any{ - "outer": map[string]any{ - "inner": map[string]int{"value": 5}, - }, - }, - }, - want: []byte(`{"outer":{"inner":{"value":5}}}`), - }, - { - name: "marshal slice of slices", - args: args{ - v: [][]int{{1, 2}, {3, 4}}, - }, - want: []byte(`[[1,2],[3,4]]`), - }, - { - name: "error handling on custom marshaler", - args: args{ - v: struct{ Test Number }{Test: Number(999)}, // Assuming 999 is not handled by Number's MarshalGQL - }, - wantErr: true, - }, - { - name: "marshal array", - args: args{ - v: [2]Number{NumberOne, NumberTwo}, - }, - want: []byte(`["ONE","TWO"]`), - }, - { - name: "marshal pointer array", - args: args{ - v: &[2]Number{NumberOne, NumberTwo}, - }, - want: []byte(`["ONE","TWO"]`), - }, - { - name: "marshal nil pointer", + name: "marshal time.Time", args: args{ - v: (*Example1)(nil), + v: testDate, }, - want: []byte("null"), + want: []byte(`"2021-01-01T00:00:00Z"`), }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalJSON(tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if !cmp.Equal(tt.want, got) { + t.Errorf("MarshalJSON() = %v, want %v", got, tt.want) + } + + }) + } +} + +func TestMarshalJSON(t *testing.T) { + t.Parallel() + type Example struct { + Name string `json:"name"` + Number Number `json:"number"` + } + testDate := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + type args struct { + v any + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ { name: "marshal a struct with custom marshaler", args: args{ - v: Example2{ + v: Example{ Name: "John", Number: NumberOne, }, @@ -656,7 +604,7 @@ func TestMarshalJSON(t *testing.T) { args: args{ v: map[string]any{ "number": NumberOne, - "example2": &Example2{ + "example2": &Example{ Name: "John", Number: NumberOne, }, @@ -675,18 +623,45 @@ func TestMarshalJSON(t *testing.T) { }, want: []byte(`{"time":"2021-01-01T00:00:00Z"}`), }, + { + name: "marshal time.Time", + args: args{ + v: struct { + T struct { + Time time.Time `json:"time"` + } + }{ + T: struct { + Time time.Time `json:"time"` + }{ + Time: testDate, + }, + }, + }, + want: []byte(`{"T":{"time":"2021-01-01T00:00:00Z"}}`), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() got, err := MarshalJSON(tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Errorf("MarshalJSON() mismatch (-want +got):\n%s", diff) + + var gotMap, wantMap map[string]any + if err := json.Unmarshal(got, &gotMap); err != nil { + t.Errorf("Failed to unmarshal 'got': %s", string(got)) + return + } + if err := json.Unmarshal(tt.want, &wantMap); err != nil { + t.Errorf("Failed to unmarshal 'want': %s", tt.want) + return + } + + if !cmp.Equal(gotMap, wantMap) { + t.Errorf("MarshalJSON() got = %v, want %v", gotMap, wantMap) } }) } From 37317a15a61c75038534e9c74fd9611a8ef2a089 Mon Sep 17 00:00:00 2001 From: Yamashou <1230124fw@gmail.com> Date: Wed, 17 Apr 2024 14:47:36 +0900 Subject: [PATCH 2/2] fix version --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 6a1f841..db0f228 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "github.com/urfave/cli/v2" ) -const version = "0.20.2" +const version = "0.20.0" var versionCmd = &cli.Command{ Name: "version",