Skip to content

Commit

Permalink
Add util.filter_spans helper (#3686)
Browse files Browse the repository at this point in the history
  • Loading branch information
ines authored and honnibal committed May 8, 2019
1 parent dd1e6b0 commit 505c9e0
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
19 changes: 19 additions & 0 deletions spacy/tests/doc/test_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.errors import ModelsWarning
from spacy.util import filter_spans

from ..util import get_doc

Expand Down Expand Up @@ -219,3 +220,21 @@ def test_span_ents_property(doc):
assert sentences[2].ents[0].label_ == "PRODUCT"
assert sentences[2].ents[0].start == 11
assert sentences[2].ents[0].end == 14


def test_filter_spans(doc):
# Test filtering duplicates
spans = [doc[1:4], doc[6:8], doc[1:4], doc[10:14]]
filtered = filter_spans(spans)
assert len(filtered) == 3
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 6 and filtered[1].end == 8
assert filtered[2].start == 10 and filtered[2].end == 14
# Test filtering overlaps with longest preference
spans = [doc[1:4], doc[1:3], doc[5:10], doc[7:9], doc[1:4]]
filtered = filter_spans(spans)
assert len(filtered) == 2
assert len(filtered[0]) == 3
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10
22 changes: 22 additions & 0 deletions spacy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,28 @@ def itershuffle(iterable, bufsize=1000):
raise StopIteration


def filter_spans(spans):
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for
creating named entities (where one token can only be part of one entity) or
when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
longest span is preferred over shorter spans.
spans (iterable): The spans to filter.
RETURNS (list): The filtered spans.
"""
get_sort_key = lambda span: (span.end - span.start, span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
seen_tokens = set()
for span in sorted_spans:
# Check for end - 1 here because boundaries are inclusive
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
result.append(span)
seen_tokens.update(range(span.start, span.end))
result = sorted(result, key=lambda span: span.start)
return result


def to_bytes(getters, exclude):
serialized = OrderedDict()
for key, getter in getters.items():
Expand Down
21 changes: 21 additions & 0 deletions website/docs/api/top-level.md
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,27 @@ for batching. Larger `buffsize` means less bias.
| `buffsize` | int | Items to hold back. |
| **YIELDS** | iterable | The shuffled iterator. |
### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
Filter a sequence of [`Span`](/api/span) objects and remove duplicates or
overlaps. Useful for creating named entities (where one token can only be part
of one entity) or when merging spans with
[`Retokenizer.merge`](/api/doc#retokenizer.merge). When spans overlap, the
(first) longest span is preferred over shorter spans.
> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> spans = [doc[0:2], doc[0:2], doc[0:4]]
> filtered = filter_spans(spans)
> ```
| Name | Type | Description |
| ----------- | -------- | -------------------- |
| `spans` | iterable | The spans to filter. |
| **RETURNS** | list | The filtered spans. |
## Compatibility functions {#compat source="spacy/compaty.py"}
All Python code is written in an **intersection of Python 2 and Python 3**. This
Expand Down

0 comments on commit 505c9e0

Please sign in to comment.