Skip to content

Commit

Permalink
feat!: add a default global ID encoder, add a id_field on the query type
Browse files Browse the repository at this point in the history
BREAKING CHANGE: the global ID function takes only a str with the GID, the query type uses the `id_field` arg to determine which kwarg to pass to the decoder
  • Loading branch information
pkucmus committed Feb 4, 2025
1 parent 000ca3d commit 2c8f064
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 18 deletions.
3 changes: 3 additions & 0 deletions ariadne/contrib/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RelayQueryType,
)
from ariadne.contrib.relay.types import ConnectionResolver, GlobalIDTuple
from ariadne.contrib.relay.utils import decode_global_id, encode_global_id

__all__ = [
"ConnectionArguments",
Expand All @@ -17,4 +18,6 @@
"RelayQueryType",
"ConnectionResolver",
"GlobalIDTuple",
"decode_global_id",
"encode_global_id",
]
16 changes: 10 additions & 6 deletions ariadne/contrib/relay/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Sequence

from typing_extensions import Any
from typing import Sequence, Any

from ariadne.contrib.relay.arguments import ConnectionArgumentsUnion

Expand All @@ -18,8 +16,11 @@ def __init__(
self.has_next_page = has_next_page
self.has_previous_page = has_previous_page

def get_cursor(self, node):
return node["id"]
def get_cursor(self, obj):
return obj["id"]

def get_node(self, obj):
return obj

def get_page_info(
self, connection_arguments: ConnectionArgumentsUnion
Expand All @@ -32,4 +33,7 @@ def get_page_info(
}

def get_edges(self):
return [{"node": node, "cursor": self.get_cursor(node)} for node in self.edges]
return [
{"node": self.get_node(obj), "cursor": self.get_cursor(obj)}
for obj in self.edges
]
12 changes: 5 additions & 7 deletions ariadne/contrib/relay/objects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from base64 import b64decode
from inspect import iscoroutinefunction
from typing import Optional, Tuple

Expand All @@ -13,15 +12,11 @@
from ariadne.contrib.relay.types import (
ConnectionResolver,
GlobalIDDecoder,
GlobalIDTuple,
)
from ariadne.contrib.relay.utils import decode_global_id
from ariadne.types import Resolver


def decode_global_id(kwargs) -> GlobalIDTuple:
return GlobalIDTuple(*b64decode(kwargs["id"]).decode().split(":"))


class RelayObjectType(ObjectType):
_node_resolver: Optional[Resolver] = None

Expand Down Expand Up @@ -89,6 +84,7 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None:


class RelayNodeInterfaceType(InterfaceType):

def __init__(
self,
type_resolver: Optional[Resolver] = None,
Expand All @@ -101,13 +97,15 @@ def __init__(
self,
node: Optional[RelayNodeInterfaceType] = None,
global_id_decoder: GlobalIDDecoder = decode_global_id,
id_field: str = "id",
) -> None:
super().__init__("Query")
if node is None:
node = RelayNodeInterfaceType()
self.node = node
self.set_field("node", self.resolve_node)
self.global_id_decoder = global_id_decoder
self.id_field = id_field

@property
def bindables(self) -> Tuple["RelayQueryType", "RelayNodeInterfaceType"]:
Expand All @@ -121,7 +119,7 @@ def get_node_resolver(self, type_name, schema: GraphQLSchema) -> Resolver:
raise ValueError(f"No node resolver for type {type_name}") from exc

def resolve_node(self, obj, info, *args, **kwargs):
type_name, _ = self.global_id_decoder(kwargs)
type_name, _ = self.global_id_decoder(kwargs[self.id_field])

resolver = self.get_node_resolver(type_name, info.schema)

Expand Down
5 changes: 3 additions & 2 deletions ariadne/contrib/relay/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections import namedtuple
from typing import Any, Callable, Dict
from typing import Callable

from typing_extensions import TypeVar

from ariadne.contrib.relay.connection import RelayConnection

ConnectionResolver = TypeVar("ConnectionResolver", bound=Callable[..., RelayConnection])
GlobalIDTuple = namedtuple("GlobalIDTuple", ["type", "id"])
GlobalIDDecoder = Callable[[Dict[str, Any]], GlobalIDTuple]
GlobalIDDecoder = Callable[[str], GlobalIDTuple]
GlobalIDEncoder = Callable[[str, str], str]
13 changes: 13 additions & 0 deletions ariadne/contrib/relay/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from base64 import b64decode, b64encode

from ariadne.contrib.relay.types import (
GlobalIDTuple,
)


def decode_global_id(gid: str) -> GlobalIDTuple:
return GlobalIDTuple(*b64decode(gid).decode().split(":"))


def encode_global_id(type_name: str, _id: str) -> str:
return b64encode(f"{type_name}:{_id}".encode()).decode()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "ariadne"
version = "0.25.1"
version = "0.25.2"
description = "Ariadne is a Python library for implementing GraphQL servers."
authors = [{ name = "Mirumee Software", email = "hello@mirumee.com" }]
readme = "README.md"
Expand Down
3 changes: 2 additions & 1 deletion tests/relay/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def relay_type_defs():

@pytest.fixture
def global_id_decoder():
return lambda kwargs: GlobalIDTuple(*b64decode(kwargs["bid"]).decode().split(":"))
return lambda gid: GlobalIDTuple(*b64decode(gid).decode().split(":"))


@pytest.fixture
Expand All @@ -71,6 +71,7 @@ def relay_query(factions, relay_node_interface, global_id_decoder):
query = RelayQueryType(
node=relay_node_interface,
global_id_decoder=global_id_decoder,
id_field="bid",
)
query.set_field("rebels", lambda *_: factions[0])
query.set_field("empire", lambda *_: factions[1])
Expand Down
2 changes: 1 addition & 1 deletion tests/relay/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def friends_connection():


def test_decode_global_id():
assert decode_global_id({"id": "VXNlcjox"}) == GlobalIDTuple("User", "1")
assert decode_global_id("VXNlcjox") == GlobalIDTuple("User", "1")


def test_default_id_decoder():
Expand Down

0 comments on commit 2c8f064

Please sign in to comment.