Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miscellaneous Minor SpanGroups/DocBin Improvements #10250

Merged
merged 13 commits into from
Feb 21, 2022
13 changes: 13 additions & 0 deletions spacy/tests/doc/test_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,16 @@ def user_hook(doc):
def test_span_sents_not_parsed(doc_not_parsed):
    """Materializing `Span.sents` for the `doc_not_parsed` fixture must raise ValueError."""
    with pytest.raises(ValueError):
        # Both the span construction and the iteration stay inside the
        # guarded block, so a ValueError from either step satisfies the test.
        span = Span(doc_not_parsed, 0, 3)
        list(span.sents)


def test_span_group_copy(doc):
    """Check that `Doc.copy()` gives the copy its own span groups.

    A span group with two spans is stored on the original doc, the doc is
    copied, and a third span is then appended to the original's group. The
    copy's group must still hold exactly two spans afterwards.
    """
    key = "test"
    doc.spans[key] = [doc[0:1], doc[2:4]]
    assert len(doc.spans[key]) == 2

    doc_copy = doc.copy()
    # The copy starts out with the same number of spans as the original.
    assert len(doc_copy.spans[key]) == 2

    # Grow the group on the original doc only.
    doc.spans[key].append(doc[3:4])
    assert len(doc.spans[key]) == 3

    # The copy must be unaffected by the change to the original.
    assert len(doc_copy.spans[key]) == 2
7 changes: 6 additions & 1 deletion spacy/tokens/_dict_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .span_group import SpanGroup
from ..errors import Errors


if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .doc import Doc # noqa: F401
Expand All @@ -19,6 +20,8 @@
class SpanGroups(UserDict):
"""A dict-like proxy held by the Doc, to control access to span groups."""

_EMPTY_BYTES = srsly.msgpack_dumps([])

def __init__(
self, doc: "Doc", items: Iterable[Tuple[str, SpanGroup]] = tuple()
) -> None:
Expand All @@ -43,11 +46,13 @@ def copy(self, doc: Optional["Doc"] = None) -> "SpanGroups":
def to_bytes(self) -> bytes:
    """Serialize all span groups held by this proxy to msgpack bytes.

    RETURNS (bytes): A msgpack-encoded list containing the `to_bytes()`
        output of every group, or the precomputed `_EMPTY_BYTES` constant
        when the proxy holds no groups.
    """
    if len(self) > 0:
        # A plain list is enough (no dict needed), because the groups
        # know their names.
        return srsly.msgpack_dumps([group.to_bytes() for group in self.values()])
    # Nothing to encode: reuse the class-level constant instead of dumping
    # an empty list again.
    return self._EMPTY_BYTES

def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
msg = srsly.msgpack_loads(bytes_data)
msg = [] if bytes_data == self._EMPTY_BYTES else srsly.msgpack_loads(bytes_data)
self.clear()
doc = self._ensure_doc()
for value_bytes in msg:
Expand Down
3 changes: 2 additions & 1 deletion spacy/tokens/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..attrs import SPACY, ORTH, intify_attr, IDS
from ..errors import Errors
from ..util import ensure_path, SimpleFrozenList
from ._dict_proxies import SpanGroups

# fmt: off
ALL_ATTRS = ("ORTH", "NORM", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "ENT_ID", "LEMMA", "MORPH", "POS", "SENT_START")
Expand Down Expand Up @@ -146,7 +147,7 @@ def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces) # type: ignore
doc = doc.from_array(self.attrs, tokens) # type: ignore
doc.cats = self.cats[i]
if self.span_groups[i]:
if self.span_groups[i] != SpanGroups._EMPTY_BYTES:
doc.spans.from_bytes(self.span_groups[i])
else:
doc.spans.clear()
Expand Down