Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: queryGraph selection for query and update #2546

Merged
merged 2 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions rdflib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,13 +1544,19 @@ def query(
initBindings = initBindings or {} # noqa: N806
initNs = initNs or dict(self.namespaces()) # noqa: N806

if self.default_union:
query_graph = "__UNION__"
elif isinstance(self, ConjunctiveGraph):
query_graph = self.default_context.identifier
else:
query_graph = self.identifier
if hasattr(self.store, "query") and use_store_provided:
try:
return self.store.query(
query_object,
initNs,
initBindings,
self.default_union and "__UNION__" or self.identifier,
query_graph,
**kwargs,
)
except NotImplementedError:
Expand Down Expand Up @@ -1592,13 +1598,20 @@ def update(
initBindings = initBindings or {} # noqa: N806
initNs = initNs or dict(self.namespaces()) # noqa: N806

if self.default_union:
query_graph = "__UNION__"
elif isinstance(self, ConjunctiveGraph):
query_graph = self.default_context.identifier
else:
query_graph = self.identifier

if hasattr(self.store, "update") and use_store_provided:
try:
return self.store.update(
update_object,
initNs,
initBindings,
self.default_union and "__UNION__" or self.identifier,
query_graph,
**kwargs,
)
except NotImplementedError:
Expand Down
121 changes: 114 additions & 7 deletions test/test_graph/test_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,32 @@

import itertools
import logging
from test.data import SIMPLE_TRIPLE_GRAPH
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from unittest.mock import patch

import pytest

import rdflib.namespace
from rdflib.graph import Graph
from rdflib.graph import ConjunctiveGraph, Dataset, Graph
from rdflib.namespace import Namespace
from rdflib.plugins.sparql.sparql import Query
from rdflib.plugins.stores.memory import Memory
from rdflib.query import Result
from rdflib.store import Store
from rdflib.term import URIRef
from rdflib.term import Identifier, URIRef, Variable

if TYPE_CHECKING:
from _pytest.mark.structures import ParameterSet
Expand Down Expand Up @@ -69,7 +74,7 @@ def bind(self, prefix, namespace, override=True, replace=False) -> None:
EGNS_V2 = EGNS["v2"]


def make_test_graph_store_bind_cases(
def make_graph_store_bind_cases(
store_type: Type[Store] = Memory,
graph_type: Type[Graph] = Graph,
) -> Iterable[Union[Tuple[Any, ...], "ParameterSet"]]:
Expand Down Expand Up @@ -194,9 +199,9 @@ def _p(
@pytest.mark.parametrize(
["graph_factory", "ops", "expected_bindings"],
itertools.chain(
make_test_graph_store_bind_cases(),
make_test_graph_store_bind_cases(store_type=MemoryWithoutBindOverride),
make_test_graph_store_bind_cases(graph_type=GraphWithoutBindOverrideFix),
make_graph_store_bind_cases(),
make_graph_store_bind_cases(store_type=MemoryWithoutBindOverride),
make_graph_store_bind_cases(graph_type=GraphWithoutBindOverrideFix),
),
)
def test_graph_store_bind(
Expand All @@ -205,9 +210,111 @@ def test_graph_store_bind(
expected_bindings: NamespaceBindings,
) -> None:
"""
The expected sequence of graph operations results in the expected namespace bindings.
The expected sequence of graph operations results in the expected
namespace bindings.
"""
graph = graph_factory()
for op in ops:
op(graph)
check_ns(graph, expected_bindings)


@pytest.mark.parametrize(
("graph_factory", "query_graph"),
[
(Graph, lambda graph: graph.identifier),
(ConjunctiveGraph, "__UNION__"),
(Dataset, lambda graph: graph.default_context.identifier),
(lambda store: Dataset(store=store, default_union=True), "__UNION__"),
],
)
def test_query_query_graph(
graph_factory: Callable[[Store], Graph],
query_graph: Union[str, Callable[[Graph], str]],
) -> None:
"""
The `Graph.query` method passes the correct ``queryGraph`` argument
to stores that have implemented a `Store.query` method.
"""

mock_result = Result("SELECT")
mock_result.vars = [Variable("s"), Variable("p"), Variable("o")]
mock_result.bindings = [
{
Variable("s"): URIRef("http://example.org/subject"),
Variable("p"): URIRef("http://example.org/predicate"),
Variable("o"): URIRef("http://example.org/object"),
},
]

query_string = r"FAKE QUERY, NOT USED"
store = Memory()
graph = graph_factory(store)

if callable(query_graph):
query_graph = query_graph(graph)

def mock_query(
query: Union[Query, str],
initNs: Mapping[str, Any], # noqa: N803
initBindings: Mapping[str, Identifier], # noqa: N803
queryGraph: str,
**kwargs,
) -> Result:
assert query_string == query
assert dict(store.namespaces()) == initNs
assert {} == initBindings
assert query_graph == queryGraph
assert {} == kwargs
return mock_result

with patch.object(store, "query", wraps=mock_query) as wrapped_query:
actual_result = graph.query(query_string)
assert actual_result.type == "SELECT"
assert list(actual_result) == list(
SIMPLE_TRIPLE_GRAPH.triples((None, None, None))
)
assert wrapped_query.call_count == 1


@pytest.mark.parametrize(
("graph_factory", "query_graph"),
[
(Graph, lambda graph: graph.identifier),
(ConjunctiveGraph, "__UNION__"),
(Dataset, lambda graph: graph.default_context.identifier),
(lambda store: Dataset(store=store, default_union=True), "__UNION__"),
],
)
def test_update_query_graph(
graph_factory: Callable[[Store], Graph],
query_graph: Union[str, Callable[[Graph], str]],
) -> None:
"""
The `Graph.update` method passes the correct ``queryGraph`` argument
to stores that have implemented a `Store.update` method.
"""

update_string = r"FAKE UPDATE, NOT USED"
store = Memory()
graph = graph_factory(store)

if callable(query_graph):
query_graph = query_graph(graph)

def mock_update(
query: Union[Query, str],
initNs: Mapping[str, Any], # noqa: N803
initBindings: Mapping[str, Identifier], # noqa: N803
queryGraph: str,
**kwargs,
) -> None:
assert update_string == query
assert dict(store.namespaces()) == initNs
assert {} == initBindings
assert query_graph == queryGraph
assert {} == kwargs

with patch.object(store, "update", wraps=mock_update) as wrapped_update:
graph.update(update_string)
assert wrapped_update.call_count == 1