Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow propagation of errors from Subscriptions channels into Request.… #317

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ import (
"fmt"
)

type SubscriptionError interface {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this interface?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's used below with streaming responses.. so the object you stream back can actually transform the response into a proper error (GraphQL-style).. otherwise, it was impossible to return an object with an error.. you were forced to make the data contain some error field or whatnot.

SubscriptionError() error
}

type QueryError struct {
Message string `json:"message"`
Locations []Location `json:"locations,omitempty"`
Expand Down
35 changes: 27 additions & 8 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"time"

"github.com/graph-gophers/graphql-go/errors"
"github.com/graph-gophers/graphql-go/internal/common"
Expand Down Expand Up @@ -41,7 +42,7 @@ func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) (
return nil, err
}

r, err := resolvable.ApplyResolver(s.schema, resolver)
r, err := resolvable.ApplyResolver(s.schema, resolver, s.prefixRootFunctions)
if err != nil {
return nil, err
}
Expand All @@ -64,13 +65,15 @@ type Schema struct {
schema *schema.Schema
res *resolvable.Schema

maxDepth int
maxParallelism int
tracer trace.Tracer
validationTracer trace.ValidationTracer
logger log.Logger
useStringDescriptions bool
disableIntrospection bool
maxDepth int
maxParallelism int
tracer trace.Tracer
validationTracer trace.ValidationTracer
logger log.Logger
useStringDescriptions bool
disableIntrospection bool
prefixRootFunctions bool
subscribeResolverTimeout time.Duration
}

// SchemaOpt is an option to pass to ParseSchema or MustParseSchema.
Expand Down Expand Up @@ -100,6 +103,13 @@ func MaxDepth(n int) SchemaOpt {
}
}

// Add the Query, Subscription and Mutation prefixes to the root resolver function when doing reflection from schema to Go code.
func PrefixRootFunctions() SchemaOpt {
return func(s *Schema) {
s.prefixRootFunctions = true
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change needed in order to propagate subscription errors?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?


// MaxParallelism specifies the maximum number of resolvers per request allowed to run in parallel. The default is 10.
func MaxParallelism(n int) SchemaOpt {
return func(s *Schema) {
Expand Down Expand Up @@ -135,6 +145,15 @@ func DisableIntrospection() SchemaOpt {
}
}

// SubscribeResolverTimeout is an option to control the amount of time
// we allow for a single subscribe message resolver to complete it's job
// before it times out and returns an error to the subscriber.
func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt {
return func(s *Schema) {
s.subscribeResolverTimeout = timeout
}
}

// Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or
// it may be further processed to a custom response type, for example to include custom error data.
// Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107
Expand Down
8 changes: 5 additions & 3 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"reflect"
"sync"
"time"

"github.com/graph-gophers/graphql-go/errors"
"github.com/graph-gophers/graphql-go/internal/common"
Expand All @@ -20,9 +21,10 @@ import (

type Request struct {
selected.Request
Limiter chan struct{}
Tracer trace.Tracer
Logger log.Logger
Limiter chan struct{}
Tracer trace.Tracer
Logger log.Logger
SubscribeResolverTimeout time.Duration
}

func (r *Request) handlePanic(ctx context.Context) {
Expand Down
3 changes: 2 additions & 1 deletion internal/exec/packer/packer.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ func (p *StructPacker) Pack(value interface{}) (reflect.Value, error) {
for _, f := range p.fields {
if value, ok := values[f.field.Name.Name]; ok {
packed, err := f.fieldPacker.Pack(value)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary empty line, Please, remove it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

of course :) I will..

if err != nil {
return reflect.Value{}, err
return reflect.Value{}, fmt.Errorf("field [%s]: %s", f.field.Name.Name, err)
}
v.Elem().FieldByIndex(f.fieldIndex).Set(packed)
}
Expand Down
36 changes: 22 additions & 14 deletions internal/exec/resolvable/resolvable.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (*Object) isResolvable() {}
func (*List) isResolvable() {}
func (*Scalar) isResolvable() {}

func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) {
func ApplyResolver(s *schema.Schema, resolver interface{}, prefixRootFuncs bool) (*Schema, error) {
if resolver == nil {
return &Schema{Meta: newMeta(s), Schema: *s}, nil
}
Expand All @@ -71,19 +71,19 @@ func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) {
var query, mutation, subscription Resolvable

if t, ok := s.EntryPoints["query"]; ok {
if err := b.assignExec(&query, t, reflect.TypeOf(resolver)); err != nil {
if err := b.assignExec(&query, t, reflect.TypeOf(resolver), prefixRootFuncs); err != nil {
return nil, err
}
}

if t, ok := s.EntryPoints["mutation"]; ok {
if err := b.assignExec(&mutation, t, reflect.TypeOf(resolver)); err != nil {
if err := b.assignExec(&mutation, t, reflect.TypeOf(resolver), prefixRootFuncs); err != nil {
return nil, err
}
}

if t, ok := s.EntryPoints["subscription"]; ok {
if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver)); err != nil {
if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver), prefixRootFuncs); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -136,14 +136,14 @@ func (b *execBuilder) finish() error {
return b.packerBuilder.Finish()
}

func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type) error {
func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type, prefixFuncs bool) error {
k := typePair{t, resolverType}
ref, ok := b.resMap[k]
if !ok {
ref = &resMapEntry{}
b.resMap[k] = ref
var err error
ref.exec, err = b.makeExec(t, resolverType)
ref.exec, err = b.makeExec(t, resolverType, prefixFuncs)
if err != nil {
return err
}
Expand All @@ -152,13 +152,13 @@ func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType
return nil
}

func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolvable, error) {
func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type, prefixFuncs bool) (Resolvable, error) {
var nonNull bool
t, nonNull = unwrapNonNull(t)

switch t := t.(type) {
case *schema.Object:
return b.makeObjectExec(t.Name, t.Fields, nil, nonNull, resolverType)
return b.makeObjectExecWithPrefix(t.Name, t.Fields, nil, nonNull, resolverType, prefixFuncs)

case *schema.Interface:
return b.makeObjectExec(t.Name, t.Fields, t.PossibleTypes, nonNull, resolverType)
Expand Down Expand Up @@ -186,7 +186,7 @@ func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolv
return nil, fmt.Errorf("%s is not a slice", resolverType)
}
e := &List{}
if err := b.assignExec(&e.Elem, t.OfType, resolverType.Elem()); err != nil {
if err := b.assignExec(&e.Elem, t.OfType, resolverType.Elem(), false); err != nil {
return nil, err
}
return e, nil
Expand Down Expand Up @@ -218,6 +218,9 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er

func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object,
nonNull bool, resolverType reflect.Type) (*Object, error) {
return b.makeObjectExecWithPrefix(typeName, fields, possibleTypes, nonNull, resolverType, false)
}
func (b *execBuilder) makeObjectExecWithPrefix(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, nonNull bool, resolverType reflect.Type, prefixFuncs bool) (*Object, error) {
if !nonNull {
if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface {
return nil, fmt.Errorf("%s is not a pointer or interface", resolverType)
Expand All @@ -230,8 +233,13 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p
rt := unwrapPtr(resolverType)
fieldsCount := fieldCount(rt, map[string]int{})
for _, f := range fields {
methodName := f.Name
if prefixFuncs {
methodName = typeName + f.Name
}

var fieldIndex []int
methodIndex := findMethod(resolverType, f.Name)
methodIndex := findMethod(resolverType, methodName)
if b.schema.UseFieldResolvers && methodIndex == -1 {
if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 {
return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name)
Expand All @@ -240,10 +248,10 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p
}
if methodIndex == -1 && len(fieldIndex) == 0 {
hint := ""
if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 {
if findMethod(reflect.PtrTo(resolverType), methodName) != -1 {
hint = " (hint: the method exists on the pointer type)"
}
return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint)
return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, methodName, hint)
}

var m reflect.Method
Expand Down Expand Up @@ -276,7 +284,7 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p
a := &TypeAssertion{
MethodIndex: methodIndex,
}
if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil {
if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0), false); err != nil {
return nil, err
}
typeAssertions[impl.Name] = a
Expand Down Expand Up @@ -369,7 +377,7 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.
} else {
out = sf.Type
}
if err := b.assignExec(&fe.ValueExec, f.Type, out); err != nil {
if err := b.assignExec(&fe.ValueExec, f.Type, out, false); err != nil {
return nil, err
}

Expand Down
50 changes: 39 additions & 11 deletions internal/exec/subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,22 @@ type Response struct {
func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query.Operation) <-chan *Response {
var result reflect.Value
var f *fieldToExec
var err *errors.QueryError
var errs []*errors.QueryError
func() {
defer r.handlePanic(ctx)

sels := selected.ApplyOperation(&r.Request, s, op)
var fields []*fieldToExec
collectFieldsToResolve(sels, s, s.Resolver, &fields, make(map[string]*fieldToExec))

if len(r.Errs) > 0 {
errs = r.Errs
return
}

// TODO: move this check into validation.Validate
if len(fields) != 1 {
err = errors.Errorf("%s", "can subscribe to at most one subscription at a time")
errs = []*errors.QueryError{errors.Errorf("%s", "can subscribe to at most one subscription at a time")}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under what circumstances can we have more than one error?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was tweaked to accomodate line 34 up here.. it can hold multiple errors, so it was easier to pass down a list of errors.. instead of assuming there would only be one in there.

return
}
f = fields[0]
Expand All @@ -49,21 +54,29 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query
result = callOut[0]

if f.field.HasError && !callOut[1].IsNil() {
resolverErr := callOut[1].Interface().(error)
err = errors.Errorf("%s", resolverErr)
err.ResolverError = resolverErr
errIface := callOut[1].Interface()
switch resolverErr := errIface.(type) {
case *errors.QueryError:
errs = []*errors.QueryError{resolverErr}
case error:
err := errors.Errorf("%s", resolverErr)
err.ResolverError = resolverErr
errs = []*errors.QueryError{err}
default:
panic("dead code path")
}
}
}()

if f == nil {
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
return sendAndReturnClosed(&Response{Errors: errs})
}

if err != nil {
if len(errs) > 0 {
if _, nonNullChild := f.field.Type.(*common.NonNull); nonNullChild {
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
return sendAndReturnClosed(&Response{Errors: errs})
}
return sendAndReturnClosed(&Response{Data: []byte(fmt.Sprintf(`{"%s":null}`, f.field.Alias)), Errors: []*errors.QueryError{err}})
return sendAndReturnClosed(&Response{Data: []byte(fmt.Sprintf(`{"%s":null}`, f.field.Alias)), Errors: errs})
}

if ctxErr := ctx.Err(); ctxErr != nil {
Expand Down Expand Up @@ -103,6 +116,17 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query
return
}

if subErr, ok := resp.Interface().(errors.SubscriptionError); ok {
if err := subErr.SubscriptionError(); err != nil {
if gqlError, ok := err.(*errors.QueryError); ok {
c <- &Response{Errors: []*errors.QueryError{gqlError}}
} else {
c <- &Response{Errors: []*errors.QueryError{errors.Errorf("%s", err)}}
}
return
}
}

subR := &Request{
Request: selected.Request{
Doc: r.Request.Doc,
Expand All @@ -115,8 +139,12 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query
}
var out bytes.Buffer
func() {
// TODO: configurable timeout
subCtx, cancel := context.WithTimeout(ctx, time.Second)
timeout := r.SubscribeResolverTimeout
if timeout == 0 {
timeout = time.Second
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this one and I'd be happy to merge it. Not sure how I've missed the hard-coded time.Second. I'll cherry pick it in a separate PR, though.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted into it's own PR #418


subCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// resolve response
Expand Down
14 changes: 7 additions & 7 deletions subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ func TestSchemaSubscribe(t *testing.T) {
helloSaidResolver: &helloSaidResolver{
upstream: closedUpstream(
&helloSaidEventResolver{msg: "Hello world!"},
&helloSaidEventResolver{err: resolverErr},
&helloSaidEventResolver{msg: "Hello again!"},
&helloSaidEventResolver{err: resolverErr},
),
},
}),
Expand All @@ -147,12 +147,6 @@ func TestSchemaSubscribe(t *testing.T) {
}
`),
},
{
Data: json.RawMessage(`
null
`),
Errors: []*qerrors.QueryError{qerrors.Errorf("%s", resolverErr)},
},
{
Data: json.RawMessage(`
{
Expand All @@ -162,6 +156,12 @@ func TestSchemaSubscribe(t *testing.T) {
}
`),
},
{
Data: json.RawMessage(`
null
`),
Errors: []*qerrors.QueryError{qerrors.Errorf("%s", resolverErr)},
},
},
},
{
Expand Down
10 changes: 7 additions & 3 deletions subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ func (s *Schema) subscribe(ctx context.Context, queryString string, operationNam
Vars: variables,
Schema: s.schema,
},
Limiter: make(chan struct{}, s.maxParallelism),
Tracer: s.tracer,
Logger: s.logger,
Limiter: make(chan struct{}, s.maxParallelism),
Tracer: s.tracer,
Logger: s.logger,
SubscribeResolverTimeout: s.subscribeResolverTimeout,
}
varTypes := make(map[string]*introspection.Type)
for _, v := range op.Vars {
Expand All @@ -80,6 +81,9 @@ func (s *Schema) subscribe(ctx context.Context, queryString string, operationNam
Data: resp.Data,
Errors: resp.Errors,
}
if len(resp.Errors) > 0 {
break
}
}
close(c)
}()
Expand Down