Skip to content

Commit

Permalink
adds ability to add list_fields to directives
Browse files Browse the repository at this point in the history
  • Loading branch information
jhnnsrs committed Nov 15, 2024
1 parent 4094cd1 commit 8d9d5aa
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "turms"
version = "0.7.0"
version = "0.8.0"
description = "graphql-codegen powered by pydantic"
authors = ["jhnnsrs <jhnnsrs@gmail.com>"]
license = "MIT"
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def nested_input_schema():
def union_schema():
return build_schema_from_schema_type(build_relative_glob("/schemas/union.graphql"))

@pytest.fixture(scope="session")
def directive_schema():
return build_schema_from_schema_type(build_relative_glob("/schemas/list_field_directive.graphql"))


@pytest.fixture(scope="session")
def schema_directive_schema():
Expand Down
3 changes: 3 additions & 0 deletions tests/documents/directives/list_field_directive.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
query X {
x
}
Empty file added tests/plugins/__init__.py
Empty file.
Empty file.
36 changes: 36 additions & 0 deletions tests/plugins/strawberry/test_list_directive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

from ...utils import build_relative_glob, unit_test_with
from turms.config import GeneratorConfig
from turms.run import generate_ast
from turms.plugins.enums import EnumsPlugin
from turms.plugins.inputs import InputsPlugin
from turms.plugins.fragments import FragmentsPlugin
from turms.plugins.operations import OperationsPlugin
from turms.plugins.funcs import (
FunctionDefinition,
FuncsPlugin,
FuncsPluginConfig,
)
from turms.plugins.strawberry import StrawberryPlugin
from turms.stylers.snake_case import SnakeCaseStyler
from turms.stylers.capitalize import CapitalizeStyler
from turms.run import generate_ast


def test_list_directive_funcs(directive_schema):
config = GeneratorConfig(
documents=build_relative_glob("/documents/directives/*.graphql"),
)
generated_ast = generate_ast(
config,
directive_schema,
stylers=[CapitalizeStyler(), SnakeCaseStyler()],
plugins=[
StrawberryPlugin(),
],
)

unit_test_with(
generated_ast,
""
)
22 changes: 22 additions & 0 deletions tests/schemas/list_field_directive.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
The directive is responsible for authorization check.
"""
directive @auth(
"""
Permissions which are required for field access.
"""
permissions: [String!]

"""
The list of roles that an authorized user should have to get the access.
"""
roles: [String!] = []
) on FIELD_DEFINITION

type X {
name: String! @auth(permissions: ["read"])
}

type Query {
x: [X!]! @auth(permissions: ["read"])
}
2 changes: 1 addition & 1 deletion turms/plugins/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class FuncsPluginConfig(PluginConfig):
definitions: List[FunctionDefinition] = []
extract_documentation: bool = True
argument_key_is_styled: bool = False
expand_input_types: List[str] = ["input"]
expand_input_types: List[str] = []


def camel_to_snake(name):
Expand Down
192 changes: 180 additions & 12 deletions turms/plugins/strawberry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from graphql import (
BooleanValueNode,
ConstListValueNode,
ConstObjectValueNode,
ConstValueNode,
EnumValueNode,
FloatValueNode,
GraphQLField,
GraphQLInputObjectType,
GraphQLInterfaceType,
Expand All @@ -8,6 +14,11 @@
GraphQLScalarType,
GraphQLType,
GraphQLUnionType,
IntValueNode,
ListValueNode,
NullValueNode,
ObjectValueNode,
StringValueNode,
Undefined,
GraphQLArgument,
ObjectTypeDefinitionNode,
Expand Down Expand Up @@ -43,6 +54,133 @@ def __call__(
) -> List[ast.AST]: ... # pragma: no cover


def build_directive_type_annotation(value: GraphQLType, registry: ClassRegistry, is_optional=True):

if isinstance(value, GraphQLScalarType):
if is_optional:
registry.register_import("typing.Optional")
return ast.Subscript(
value=ast.Name("Optional", ctx=ast.Load()),
slice=registry.reference_scalar(value.name),
ctx=ast.Load(),
)

return registry.reference_scalar(value.name)
if isinstance(value, GraphQLObjectType):
raise NotImplementedError("Object types cannot be used as arguments")
if isinstance(value, GraphQLInterfaceType):
raise NotImplementedError("Interface types cannot be used as arguments")
if isinstance(value, GraphQLUnionType):
raise NotImplementedError("Union types cannot be used as arguments")
if isinstance(value, GraphQLEnumType):
if is_optional:
registry.register_import("typing.Optional")
return ast.Subscript(
value=ast.Name("Optional", ctx=ast.Load()),
slice=registry.reference_enum(value.name),
ctx=ast.Load(),
)

return registry.reference_enum(value.name)
if isinstance(value, GraphQLNonNull):
return build_directive_type_annotation(value.of_type, registry, is_optional=False)
if isinstance(value, GraphQLList):
registry.register_import("typing.List")

if is_optional:
registry.register_import("typing.Optional")

return ast.Subscript(
value=ast.Name("Optional", ctx=ast.Load()),
slice=ast.Subscript(
value=ast.Name("List", ctx=ast.Load()),
slice=build_directive_type_annotation(value.of_type, registry, is_optional=True),
ctx=ast.Load(),
),
ctx=ast.Load(),
)

return ast.Subscript(
value=ast.Name("List", ctx=ast.Load()),
slice=build_directive_type_annotation(value.of_type, registry, is_optional=True),
ctx=ast.Load(),
)
if isinstance(value, GraphQLInputObjectType):
raise NotImplementedError("Input types cannot be used as arguments")

raise NotImplementedError(f"Unknown type {repr(value)}")



def convert_valuenode_to_ast(value: ConstValueNode):
if isinstance(value, NullValueNode):
return ast.Constant(value=None)
if isinstance(value, StringValueNode):
return ast.Constant(value=value.value)
if isinstance(value, IntValueNode):
return ast.Constant(value=value.value)
if isinstance(value, FloatValueNode):
return ast.Constant(value=value.value)
if isinstance(value, BooleanValueNode):
return ast.Constant(value=value.value)

if isinstance(value, EnumValueNode):
return ast.Constant(value=value)
if isinstance(value, ListValueNode):
return ast.List(elts=[convert_valuenode_to_ast(x) for x in value.values], ctx=ast.Load())
if isinstance(value, ObjectValueNode):

keys = []
values = []

for field in value.fields:
keys.append(field.name.value)
values.append(convert_valuenode_to_ast(field.value))

return ast.Dict(
keys=keys,
values=values,
)

raise NotImplementedError(f"Unknown default value {repr(value)}")



def convert_default_value_to_ast(value):
if value is Undefined:
return None
if value is None:
return ast.Constant(value=None)
if isinstance(value, str):
return ast.Constant(value=value)
if isinstance(value, int):
return ast.Constant(value=value)
if isinstance(value, float):
return ast.Constant(value=value)
if isinstance(value, bool):
return ast.Constant(value=value)
if isinstance(value, list):
return ast.List(elts=[convert_default_value_to_ast(x) for x in value], ctx=ast.Load())
if isinstance(value, dict):
keys = []
values = []

for key, value in value.items():
keys.append(key)
values.append(convert_default_value_to_ast(value))

return ast.Dict(
keys=keys,
values=values,
)
raise NotImplementedError(f"Unknown default value {repr(value)}")







def default_generate_directives(
client_schema: GraphQLSchema,
config: GeneratorConfig,
Expand Down Expand Up @@ -96,23 +234,53 @@ def default_generate_directives(

type = value.type

if isinstance(value.type, GraphQLNonNull):
type = value.type.of_type

assert isinstance(
type, GraphQLScalarType
), "Only scalar (or nonnull version of this) are supported"

if value.default_value:
default = ast.Constant(value=value.default_value)
if value.default_value is not None:
default = convert_default_value_to_ast(value.default_value)
else:
default = None

needs_factory = False
if isinstance(default, ast.List):
needs_factory = True
if isinstance(default, ast.Dict):
needs_factory = True


field_value = None

if default:
if needs_factory:
field_value = ast.Call(
func=ast.Name(id="strawberry.field", ctx=ast.Load()),
keywords=[
ast.keyword(
arg="default_factory",
value=ast.Lambda(
args=[], body=default
),
),
],
args=[],
)
else:
field_value = ast.Call(
func=ast.Name(id="strawberry.field", ctx=ast.Load()),
keywords=[
ast.keyword(
arg="default",
value=default,
),
],
args=[],
)


assign = ast.AnnAssign(
target=ast.Name(
id=registry.generate_node_name(value_key), ctx=ast.Store()
),
annotation=registry.reference_scalar(type.name),
value=default,
annotation=build_directive_type_annotation(type, registry),
value=field_value,
simple=1,
)

Expand Down Expand Up @@ -560,7 +728,7 @@ def generate_directive_keywords(
ctx=ast.Load(),
),
keywords=[
ast.keyword(arg=arg.name.value, value=ast.Constant(arg.value.value))
ast.keyword(arg=arg.name.value, value=convert_valuenode_to_ast(arg.value))
for arg in directive.arguments
],
args=[],
Expand Down

0 comments on commit 8d9d5aa

Please sign in to comment.