diff --git a/internal/common/lexer.go b/internal/common/lexer.go index 65d049274..20dc3c62e 100644 --- a/internal/common/lexer.go +++ b/internal/common/lexer.go @@ -20,7 +20,7 @@ type Ident struct { Loc errors.Location } -func New(sc *scanner.Scanner) *Lexer { +func NewLexer(sc *scanner.Scanner) *Lexer { l := &Lexer{sc: sc} l.Consume() return l diff --git a/internal/query/query.go b/internal/query/query.go index a5721380a..5df25af20 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -100,7 +100,7 @@ func Parse(queryString string) (*Document, *errors.QueryError) { } sc.Init(strings.NewReader(queryString)) - l := common.New(sc) + l := common.NewLexer(sc) var doc *Document err := l.CatchSyntaxError(func() { doc = parseDocument(l) diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 1f3c3bfde..1b547729e 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -154,13 +154,14 @@ func New() *Schema { return s } +// Parse the schema string. func (s *Schema) Parse(schemaString string) error { sc := &scanner.Scanner{ Mode: scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings, } sc.Init(strings.NewReader(schemaString)) - l := common.New(sc) + l := common.NewLexer(sc) err := l.CatchSyntaxError(func() { parseSchema(s, l) }) @@ -316,7 +317,7 @@ func parseSchema(s *Schema, l *common.Lexer) { } l.ConsumeToken('}') case "type": - obj := parseObjectDecl(l) + obj := parseObjectDeclaration(l) obj.Desc = desc s.Types[obj.Name] = obj s.objects = append(s.objects, obj) @@ -351,22 +352,26 @@ func parseSchema(s *Schema, l *common.Lexer) { } } -func parseObjectDecl(l *common.Lexer) *Object { - o := &Object{} - o.Name = l.ConsumeIdent() +func parseObjectDeclaration(l *common.Lexer) *Object { + object := &Object{Name: l.ConsumeIdent()} + if l.Peek() == scanner.Ident { l.ConsumeKeyword("implements") - for { - o.interfaceNames = append(o.interfaceNames, l.ConsumeIdent()) - if l.Peek() == '{' { - break + + for l.Peek() != '{' { + object.interfaceNames = append(object.interfaceNames, l.ConsumeIdent()) + + if l.Peek() == '&' { + l.ConsumeToken('&') } } } + l.ConsumeToken('{') - o.Fields = parseFields(l) + object.Fields = parseFields(l) l.ConsumeToken('}') - return o + + return object } func parseInterfaceDecl(l *common.Lexer) *Interface { diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go new file mode 100644 index 000000000..79a8f776e --- /dev/null +++ b/internal/schema/schema_test.go @@ -0,0 +1,71 @@ +package schema + +import ( + "strings" + "testing" + "text/scanner" + + "github.com/graph-gophers/graphql-go/internal/common" +) + +type testCase struct { + description string + declaration string + expected *Object +} + +func TestParseObjectDeclaration(t *testing.T) { + tests := []testCase{ + { + "allows '&' separator", + "Alien implements Being & Intelligent { name: String, iq: Int }", + &Object{ + Name: "Alien", + interfaceNames: []string{"Being", "Intelligent"}, + }, + }, + { + "allows legacy ',' separator", + "Alien implements Being, Intelligent { name: String, iq: Int }", + &Object{ + Name: "Alien", + interfaceNames: []string{"Being", "Intelligent"}, + }, + }, + } + + setup := func(schema string) *common.Lexer { + sc := &scanner.Scanner{ + Mode: scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings, + } + sc.Init(strings.NewReader(schema)) + return common.NewLexer(sc) + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + lex := setup(test.declaration) + var actual *Object + + parse := func() { actual = parseObjectDeclaration(lex) } + if err := lex.CatchSyntaxError(parse); err != nil { + t.Fatal(err) + } + + if test.expected.Name != actual.Name { + t.Errorf("wrong object name: want %q, got %q", test.expected.Name, actual.Name) + } + + if len(test.expected.interfaceNames) != len(actual.interfaceNames) { + t.Fatalf("wrong number of interface names: want %s, got %s", test.expected.interfaceNames, actual.interfaceNames) + } + + for i, expectedName := range test.expected.interfaceNames { + actualName := actual.interfaceNames[i] + if expectedName != actualName { + t.Errorf("wrong interface name: want %q, got %q", expectedName, actualName) + } + } + }) + } +}