Skip to content

Commit

Permalink
Fix displacy span stacking (#13068)
Browse files Browse the repository at this point in the history
* Fix displacy span stacking.

* Format. Remove counter.

* Remove test files.

* Add unit test. Refactor to allow for unit test.

* Fix off-by-one error in tests.
  • Loading branch information
rmitsch authored Nov 2, 2023
1 parent a804b83 commit c4e2daf
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
39 changes: 30 additions & 9 deletions spacy/displacy/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,25 @@ def render_spans(
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
title (str / None): Document title set in Doc.user_data['title'].
"""
per_token_info = []
per_token_info = self._assemble_per_token_info(tokens, spans)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup

@staticmethod
def _assemble_per_token_info(
tokens: List[str], spans: List[Dict[str, Any]]
) -> List[Dict[str, List[Dict[str, Any]]]]:
"""Assembles token info used to generate markup in render_spans().
tokens (List[str]): Tokens in text.
spans (List[Dict[str, Any]]): Spans in text.
RETURNS (List[Dict[str, List[Dict[str, Any]]]]): Per token info needed to render HTML markup for given tokens
and spans.
"""
per_token_info: List[Dict[str, List[Dict[str, Any]]]] = []

# we must sort so that we can correctly describe when spans need to "stack"
# which is determined by their start token, then span length (longer spans on top),
# then break any remaining ties with the span label
Expand All @@ -154,29 +172,35 @@ def render_spans(
s["label"],
),
)

for s in spans:
# this is the vertical 'slot' that the span will be rendered in
# vertical_position = span_label_offset + (offset_step * (slot - 1))
s["render_slot"] = 0

for idx, token in enumerate(tokens):
# Identify if a token belongs to a Span (and which) and if it's a
# start token of said Span. We'll use this for the final HTML render
token_markup: Dict[str, Any] = {}
token_markup["text"] = token
concurrent_spans = 0
intersecting_spans: List[Dict[str, Any]] = []
entities = []
for span in spans:
ent = {}
if span["start_token"] <= idx < span["end_token"]:
concurrent_spans += 1
span_start = idx == span["start_token"]
ent["label"] = span["label"]
ent["is_start"] = span_start
if span_start:
# When the span starts, we need to know how many other
# spans are on the 'span stack' and will be rendered.
# This value becomes the vertical render slot for this entire span
span["render_slot"] = concurrent_spans
span["render_slot"] = (
intersecting_spans[-1]["render_slot"]
if len(intersecting_spans)
else 0
) + 1
intersecting_spans.append(span)
ent["render_slot"] = span["render_slot"]
kb_id = span.get("kb_id", "")
kb_url = span.get("kb_url", "#")
Expand All @@ -193,11 +217,8 @@ def render_spans(
span["render_slot"] = 0
token_markup["entities"] = entities
per_token_info.append(token_markup)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup

return per_token_info

def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
"""Render the markup from per-token information"""
Expand Down
22 changes: 21 additions & 1 deletion spacy/tests/test_displacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from spacy import displacy
from spacy.displacy.render import DependencyRenderer, EntityRenderer
from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer
from spacy.lang.en import English
from spacy.lang.fa import Persian
from spacy.tokens import Doc, Span
Expand Down Expand Up @@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None:
# Verify that the HTML tag is still escaped
html = displacy.render(doc, style="span")
assert "&lt;TEST&gt;" in html


@pytest.mark.issue(13056)
def test_displacy_span_stacking():
    """Check the render slots assigned when several spans overlap each other."""
    words = ["Welcome", "to", "the", "Bank", "of", "China", "."]
    overlapping = [
        {"start_token": 2, "end_token": 5, "label": "SkillNC"},
        {"start_token": 0, "end_token": 2, "label": "Skill"},
        {"start_token": 1, "end_token": 3, "label": "Skill"},
    ]
    info = SpanRenderer._assemble_per_token_info(tokens=words, spans=overlapping)

    # One info record is produced per input token.
    assert len(info) == len(words)

    # Tokens 0, 3 and 4 are covered by a single span; tokens 1 and 2 by two.
    for position in (0, 3, 4):
        assert len(info[position]["entities"]) == 1
    for position in (1, 2):
        assert len(info[position]["entities"]) == 2

    # Render slots of the two entities recorded on each doubly-covered token,
    # in the order the entities are listed.
    expected_slots = {1: [1, 2], 2: [2, 3]}
    for position, slots in expected_slots.items():
        found = [entity["render_slot"] for entity in info[position]["entities"]]
        assert found == slots

0 comments on commit c4e2daf

Please sign in to comment.