diff --git a/rdflib/graph.py b/rdflib/graph.py index 7f6dc3ef4..e6f6e25b1 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -1544,13 +1544,19 @@ def query( initBindings = initBindings or {} # noqa: N806 initNs = initNs or dict(self.namespaces()) # noqa: N806 + if self.default_union: + query_graph = "__UNION__" + elif isinstance(self, ConjunctiveGraph): + query_graph = self.default_context.identifier + else: + query_graph = self.identifier if hasattr(self.store, "query") and use_store_provided: try: return self.store.query( query_object, initNs, initBindings, - self.default_union and "__UNION__" or self.identifier, + query_graph, **kwargs, ) except NotImplementedError: @@ -1592,13 +1598,20 @@ def update( initBindings = initBindings or {} # noqa: N806 initNs = initNs or dict(self.namespaces()) # noqa: N806 + if self.default_union: + query_graph = "__UNION__" + elif isinstance(self, ConjunctiveGraph): + query_graph = self.default_context.identifier + else: + query_graph = self.identifier + if hasattr(self.store, "update") and use_store_provided: try: return self.store.update( update_object, initNs, initBindings, - self.default_union and "__UNION__" or self.identifier, + query_graph, **kwargs, ) except NotImplementedError: diff --git a/test/test_graph/test_graph_store.py b/test/test_graph/test_graph_store.py index 144434cc6..300d9c85e 100644 --- a/test/test_graph/test_graph_store.py +++ b/test/test_graph/test_graph_store.py @@ -4,27 +4,32 @@ import itertools import logging +from test.data import SIMPLE_TRIPLE_GRAPH from typing import ( TYPE_CHECKING, Any, Callable, Dict, Iterable, + Mapping, Optional, Sequence, Tuple, Type, Union, ) +from unittest.mock import patch import pytest import rdflib.namespace -from rdflib.graph import Graph +from rdflib.graph import ConjunctiveGraph, Dataset, Graph from rdflib.namespace import Namespace +from rdflib.plugins.sparql.sparql import Query from rdflib.plugins.stores.memory import Memory +from rdflib.query import Result from rdflib.store import Store -from rdflib.term import URIRef +from rdflib.term import Identifier, URIRef, Variable if TYPE_CHECKING: from _pytest.mark.structures import ParameterSet @@ -69,7 +74,7 @@ def bind(self, prefix, namespace, override=True, replace=False) -> None: EGNS_V2 = EGNS["v2"] -def make_test_graph_store_bind_cases( +def make_graph_store_bind_cases( store_type: Type[Store] = Memory, graph_type: Type[Graph] = Graph, ) -> Iterable[Union[Tuple[Any, ...], "ParameterSet"]]: @@ -194,9 +199,9 @@ def _p( @pytest.mark.parametrize( ["graph_factory", "ops", "expected_bindings"], itertools.chain( - make_test_graph_store_bind_cases(), - make_test_graph_store_bind_cases(store_type=MemoryWithoutBindOverride), - make_test_graph_store_bind_cases(graph_type=GraphWithoutBindOverrideFix), + make_graph_store_bind_cases(), + make_graph_store_bind_cases(store_type=MemoryWithoutBindOverride), + make_graph_store_bind_cases(graph_type=GraphWithoutBindOverrideFix), ), ) def test_graph_store_bind( @@ -205,9 +210,111 @@ def test_graph_store_bind( expected_bindings: NamespaceBindings, ) -> None: """ - The expected sequence of graph operations results in the expected namespace bindings. + The expected sequence of graph operations results in the expected + namespace bindings. """ graph = graph_factory() for op in ops: op(graph) check_ns(graph, expected_bindings) + + +@pytest.mark.parametrize( + ("graph_factory", "query_graph"), + [ + (Graph, lambda graph: graph.identifier), + (ConjunctiveGraph, "__UNION__"), + (Dataset, lambda graph: graph.default_context.identifier), + (lambda store: Dataset(store=store, default_union=True), "__UNION__"), + ], +) +def test_query_query_graph( + graph_factory: Callable[[Store], Graph], + query_graph: Union[str, Callable[[Graph], str]], +) -> None: + """ + The `Graph.query` method passes the correct ``queryGraph`` argument + to stores that have implemented a `Store.query` method. + """ + + mock_result = Result("SELECT") + mock_result.vars = [Variable("s"), Variable("p"), Variable("o")] + mock_result.bindings = [ + { + Variable("s"): URIRef("http://example.org/subject"), + Variable("p"): URIRef("http://example.org/predicate"), + Variable("o"): URIRef("http://example.org/object"), + }, + ] + + query_string = r"FAKE QUERY, NOT USED" + store = Memory() + graph = graph_factory(store) + + if callable(query_graph): + query_graph = query_graph(graph) + + def mock_query( + query: Union[Query, str], + initNs: Mapping[str, Any], # noqa: N803 + initBindings: Mapping[str, Identifier], # noqa: N803 + queryGraph: str, + **kwargs, + ) -> Result: + assert query_string == query + assert dict(store.namespaces()) == initNs + assert {} == initBindings + assert query_graph == queryGraph + assert {} == kwargs + return mock_result + + with patch.object(store, "query", wraps=mock_query) as wrapped_query: + actual_result = graph.query(query_string) + assert actual_result.type == "SELECT" + assert list(actual_result) == list( + SIMPLE_TRIPLE_GRAPH.triples((None, None, None)) + ) + assert wrapped_query.call_count == 1 + + +@pytest.mark.parametrize( + ("graph_factory", "query_graph"), + [ + (Graph, lambda graph: graph.identifier), + (ConjunctiveGraph, "__UNION__"), + (Dataset, lambda graph: graph.default_context.identifier), + (lambda store: Dataset(store=store, default_union=True), "__UNION__"), + ], +) +def test_update_query_graph( + graph_factory: Callable[[Store], Graph], + query_graph: Union[str, Callable[[Graph], str]], +) -> None: + """ + The `Graph.update` method passes the correct ``queryGraph`` argument + to stores that have implemented a `Store.update` method. + """ + + update_string = r"FAKE UPDATE, NOT USED" + store = Memory() + graph = graph_factory(store) + + if callable(query_graph): + query_graph = query_graph(graph) + + def mock_update( + query: Union[Query, str], + initNs: Mapping[str, Any], # noqa: N803 + initBindings: Mapping[str, Identifier], # noqa: N803 + queryGraph: str, + **kwargs, + ) -> None: + assert update_string == query + assert dict(store.namespaces()) == initNs + assert {} == initBindings + assert query_graph == queryGraph + assert {} == kwargs + + with patch.object(store, "update", wraps=mock_update) as wrapped_update: + graph.update(update_string) + assert wrapped_update.call_count == 1