diff --git a/graphql/rules.lua b/graphql/rules.lua index 83a4213..956ced8 100644 --- a/graphql/rules.lua +++ b/graphql/rules.lua @@ -31,13 +31,16 @@ end function rules.fieldsDefinedOnType(node, context) if context.objects[#context.objects] == false then local parent = context.objects[#context.objects - 1] + if(parent.__type == 'List') then + parent = parent.ofType + end error('Field "' .. node.name.value .. '" is not defined on type "' .. parent.name .. '"') end end function rules.argumentsDefinedOnType(node, context) if node.arguments then - local parentField = context.objects[#context.objects - 1].fields[node.name.value] + local parentField = util.getParentField(context, node.name.value) for _, argument in pairs(node.arguments) do local name = argument.name.value if not parentField.arguments[name] then @@ -175,7 +178,7 @@ end function rules.argumentsOfCorrectType(node, context) if node.arguments then - local parentField = context.objects[#context.objects - 1].fields[node.name.value] + local parentField = util.getParentField(context, node.name.value) for _, argument in pairs(node.arguments) do local name = argument.name.value local argumentType = parentField.arguments[name] @@ -186,13 +189,7 @@ end function rules.requiredArgumentsPresent(node, context) local arguments = node.arguments or {} - local parentField - if context.objects[#context.objects - 1].__type == 'List' then - parentField = context.objects[#context.objects - 2].fields[node.name.value] - else - parentField = context.objects[#context.objects - 1].fields[node.name.value] - end - + local parentField = util.getParentField(context, node.name.value) for name, argument in pairs(parentField.arguments) do if argument.__type == 'NonNull' then local present = util.find(arguments, function(argument) @@ -451,7 +448,7 @@ function rules.variableUsageAllowed(node, context) if not arguments then return end for field in pairs(arguments) do - local parentField = context.objects[#context.objects - 1].fields[field] + local parentField = util.getParentField(context, field) for i = 1, #arguments[field] do local argument = arguments[field][i] if argument.value.kind == 'variable' then diff --git a/graphql/util.lua b/graphql/util.lua index e66e5fa..72cfb2d 100644 --- a/graphql/util.lua +++ b/graphql/util.lua @@ -23,6 +23,16 @@ function util.bind1(func, x) end end +function util.getParentField(context, name, count) + count = count == nil and 1 or count + local obj = context.objects[#context.objects - count] + if obj.__type == 'List' then + return obj.ofType.fields[name] + else + return obj.fields[name] + end +end + function util.coerceValue(node, schemaType, variables) variables = variables or {} @@ -58,7 +68,7 @@ function util.coerceValue(node, schemaType, variables) error('Unknown input object field "' .. field.name .. '"') end - return util.coerceValue(schemaType.fields[field.name].kind, field.value, variables) + return util.coerceValue(field.value, schemaType.fields[field.name].kind, variables) end) end diff --git a/graphql/validate.lua b/graphql/validate.lua index 08eaf02..d373e7e 100644 --- a/graphql/validate.lua +++ b/graphql/validate.lua @@ -1,5 +1,6 @@ local path = (...):gsub('%.[^%.]+$', '') local rules = require(path .. '.rules') +local util = require(path .. '.util') local visitors = { document = { @@ -58,16 +59,10 @@ local visitors = { field = { enter = function(node, context) - local parentField - if context.objects[#context.objects].__type == 'List' then - parentField = context.objects[#context.objects - 1].fields[node.name.value] - else - parentField = context.objects[#context.objects].fields[node.name.value] - end + local parentField = util.getParentField(context, node.name.value, 0) -- false is a special value indicating that the field was not present in the type definition. local field = parentField and parentField.kind or false - table.insert(context.objects, field) end,