Skip to content

Commit

Permalink
ensure HasOperationContext checks for nil (#2776)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscerk authored Aug 29, 2023
1 parent a1ca220 commit cc4e0ba
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
4 changes: 2 additions & 2 deletions graphql/context_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func WithOperationContext(ctx context.Context, rc *OperationContext) context.Con
//
// Some errors can happen outside of an operation, eg json unmarshal errors.
func HasOperationContext(ctx context.Context) bool {
_, ok := ctx.Value(operationCtx).(*OperationContext)
return ok
val, ok := ctx.Value(operationCtx).(*OperationContext)
return ok && val != nil
}

// This is just a convenient wrapper method for CollectFields
Expand Down
31 changes: 31 additions & 0 deletions graphql/context_operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,33 @@ package graphql
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/v2/ast"
)

// implement context.Context interface
type testGraphRequestContext struct {
opContext *OperationContext
}

func (t *testGraphRequestContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

func (t *testGraphRequestContext) Done() <-chan struct{} {
return nil
}

func (t *testGraphRequestContext) Err() error {
return nil
}

func (t *testGraphRequestContext) Value(key interface{}) interface{} {
return t.opContext
}

func TestGetOperationContext(t *testing.T) {
rc := &OperationContext{}

Expand All @@ -26,6 +48,15 @@ func TestGetOperationContext(t *testing.T) {
GetOperationContext(ctx)
})
})

t.Run("with nil operation context", func(t *testing.T) {
ctx := &testGraphRequestContext{opContext: nil}

require.False(t, HasOperationContext(ctx))
require.Panics(t, func() {
GetOperationContext(ctx)
})
})
}

func TestCollectAllFields(t *testing.T) {
Expand Down

0 comments on commit cc4e0ba

Please sign in to comment.