Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple entities for one token #10394

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ba76878
Allow multiple entities for one token
raoulvm Nov 25, 2021
a2b5380
Formatted w/black
raoulvm Nov 25, 2021
1f923e2
Merge branch '2.8.x' into 2.8.14-test_multi_entities_only
raoulvm Nov 25, 2021
a528aee
line-length=88
raoulvm Nov 25, 2021
80430f0
Merge branch '2.8.14-test_multi_entities_only' of https://github.com/…
raoulvm Nov 25, 2021
8620ca2
black --line-length 88
raoulvm Nov 25, 2021
1b15ef8
add changelog
raoulvm Nov 25, 2021
4603677
Fix docstring linting in test.py:1091
raoulvm Nov 25, 2021
cde992a
Change test to allow None extractor
raoulvm Nov 26, 2021
29262b4
Spelling in changelog
raoulvm Nov 26, 2021
829195b
Revert accidental revert.
raoulvm Nov 26, 2021
60dcf2e
Some docstring text to force
raoulvm Nov 30, 2021
15d909e
Merge branch '2.8.x' into 2.8.14-test_multi_entities_only
raoulvm Dec 8, 2021
0b5aac3
Merge branch '2.8.x' into 2.8.14-test_multi_entities_only
raoulvm Jan 14, 2022
ca3bcb6
init version of multiple entity writer
raoulvm Jan 19, 2022
558de89
satisfy `make lint`
raoulvm Jan 19, 2022
b0c50d6
reduce complexity for mypy test
raoulvm Jan 19, 2022
1287d1c
assert starts AND ends are equal
raoulvm Jan 19, 2022
0ad61f3
Merge branch '2.8.x' into 2.8.14-test_multi_entities_only
raoulvm Jan 19, 2022
d296d49
activate the change :-)
raoulvm Jan 19, 2022
6b36dcf
Merge branch '2.8.14-test_multi_entities_only' of https://github.com/…
raoulvm Jan 19, 2022
1142e08
make generate_message more readable
raoulvm Jan 21, 2022
f27af50
Merge branch '2.8.x' into 2.8.14-test_multi_entities_only
Jan 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions changelog/10394.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Allow multiple entities to be annotated for the same word/tokens.
When using entity extractors that support generating multiple entities from a single expression, the test stories fail as there is no way to annotate multiple entity_types and entity_values.
Entity Extractors like DIET are not optimized for training with multiple entity extractions, so be sure to use only Regex or FlashText or similar extractors.
New annotation option is
```YAML
stories:
- story: Some story
steps:
- user: |
      I would like to cancel my contract for my [iphone][{"entity":"iphone", "value":"iphone"},{"entity":"smartphone", "value":"true"},{"entity":"mobile_service", "value":"true"}]
intent: cancel_contract
```
7 changes: 5 additions & 2 deletions rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,9 +1088,12 @@ def determine_entity_for_token(


def do_extractors_support_overlap(extractors: Optional[Set[Text]]) -> bool:
"""Checks if extractors support overlapping entities"""
"""Checks if extractors support overlapping entities.

If no extractor is given, assume support for overlapping entities.
"""
if extractors is None:
return False
return True

from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor

Expand Down
69 changes: 50 additions & 19 deletions rasa/shared/nlu/training_data/entities_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from json import JSONDecodeError
from typing import Text, List, Dict, Match, Optional, NamedTuple, Any
import logging

import rasa.shared.nlu.training_data.util
from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
Expand All @@ -13,18 +14,22 @@
)
from rasa.shared.nlu.training_data.message import Message


GROUP_ENTITY_VALUE = "value"
GROUP_ENTITY_TYPE = "entity"
GROUP_ENTITY_DICT = "entity_dict"
GROUP_ENTITY_TEXT = "entity_text"
GROUP_COMPLETE_MATCH = 0
GROUP_ENTITY_DICT_LIST = "list_entity_dicts"

# regex for: `[entity_text]((entity_type(:entity_synonym)?)|{entity_dict})`
# regex for: `[entity_text]((entity_type(:entity_synonym)?)|{entity_dict}|[list_entity_dicts])` # noqa: E501, W505
ENTITY_REGEX = re.compile(
r"\[(?P<entity_text>[^\]]+?)\](\((?P<entity>[^:)]+?)(?:\:(?P<value>[^)]+))?\)|\{(?P<entity_dict>[^}]+?)\})" # noqa: E501, W505
r"\[(?P<entity_text>[^\]]+?)\](\((?P<entity>[^:)]+?)(?:\:(?P<value>[^)]+))?\)|\{(?P<entity_dict>[^}]+?)\}|\[(?P<list_entity_dicts>.*?)\])" # noqa: E501, W505
)

SINGLE_ENTITY_DICT = re.compile(r"{(?P<entity_dict>[^}]+?)\}")

logger = logging.getLogger(__name__)


class EntityAttributes(NamedTuple):
"""Attributes of an entity defined in markdown data."""
Expand All @@ -50,22 +55,48 @@ def find_entities_in_training_example(example: Text) -> List[Dict[Text, Any]]:
offset = 0

for match in re.finditer(ENTITY_REGEX, example):
entity_attributes = extract_entity_attributes(match)

start_index = match.start() - offset
end_index = start_index + len(entity_attributes.text)
offset += len(match.group(0)) - len(entity_attributes.text)

entity = rasa.shared.nlu.training_data.util.build_entity(
start_index,
end_index,
entity_attributes.value,
entity_attributes.type,
entity_attributes.role,
entity_attributes.group,
)
entities.append(entity)

logger.debug(f"{match}")
if match.groupdict()[GROUP_ENTITY_DICT] or match.groupdict()[GROUP_ENTITY_TYPE]:
entity_attributes = extract_entity_attributes(match)

start_index = match.start() - offset
end_index = start_index + len(entity_attributes.text)
offset += len(match.group(0)) - len(entity_attributes.text)

entity = rasa.shared.nlu.training_data.util.build_entity(
start_index,
end_index,
entity_attributes.value,
entity_attributes.type,
entity_attributes.role,
entity_attributes.group,
)
entities.append(entity)
else:
entity_text = match.groupdict()[GROUP_ENTITY_TEXT]
# iterate over the list

start_index = match.start() - offset
end_index = start_index + len(entity_text)
offset += len(match.group(0)) - len(entity_text)

for match_inner in re.finditer(
SINGLE_ENTITY_DICT, match.groupdict()[GROUP_ENTITY_DICT_LIST]
):

entity_attributes = extract_entity_attributes_from_dict(
entity_text=entity_text, match=match_inner
)

entity = rasa.shared.nlu.training_data.util.build_entity(
start_index,
end_index,
entity_attributes.value,
entity_attributes.type,
entity_attributes.role,
entity_attributes.group,
)
entities.append(entity)
return entities


Expand Down
85 changes: 73 additions & 12 deletions rasa/shared/nlu/training_data/formats/readerwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import rasa.shared.utils.io
import typing
from typing import Text, Dict, Any, Union
from typing import List, Text, Dict, Any, Union

if typing.TYPE_CHECKING:
from rasa.shared.nlu.training_data.training_data import TrainingData
Expand Down Expand Up @@ -100,22 +100,44 @@ def generate_message(message: Dict[Text, Any]) -> Text:
entities_with_start_and_end, key=operator.itemgetter("start")
)

# aggregate entities with same start and end (multiple entities from
# same token group)
aggregated_entities = []
last_start = None
last_end = None
for entity in sorted_entities:
md += text[pos : entity["start"]]
if (
last_start is None
or last_end is None
or last_start != entity["start"]
):
last_start = entity["start"]
last_end = entity["end"]
aggregated_entities.append(entity)
else:
agg = aggregated_entities[-1]
if isinstance(agg, list):
agg.append(entity)
else:
agg = aggregated_entities.pop()
aggregated_entities.append([agg, entity])

for entity in aggregated_entities:
entity0 = entity[0] if isinstance(entity, list) else entity
md += text[pos : entity0["start"]]
md += TrainingDataWriter.generate_entity(text, entity)
pos = entity["end"]
pos = entity0["end"]

md += text[pos:]

return md

@staticmethod
def generate_entity(text: Text, entity: Dict[Text, Any]) -> Text:
"""Generates text for an entity object."""

entity_text = text[
entity[ENTITY_ATTRIBUTE_START] : entity[ENTITY_ATTRIBUTE_END]
]
def generate_entity_attributes(
text: Text, entity: Dict[Text, Any], short_allowed: bool = True
) -> Text:
"""Generates text for the entity attributes."""
entity_text = text
entity_type = entity.get(ENTITY_ATTRIBUTE_TYPE)
entity_value = entity.get(ENTITY_ATTRIBUTE_VALUE)
entity_role = entity.get(ENTITY_ATTRIBUTE_ROLE)
Expand All @@ -125,11 +147,14 @@ def generate_entity(text: Text, entity: Dict[Text, Any]) -> Text:
entity_value = None

use_short_syntax = (
entity_value is None and entity_role is None and entity_group is None
short_allowed
and entity_value is None
and entity_role is None
and entity_group is None
)

if use_short_syntax:
return f"[{entity_text}]({entity_type})"
return f"({entity_type})"
else:
entity_dict = OrderedDict(
[
Expand All @@ -143,10 +168,46 @@ def generate_entity(text: Text, entity: Dict[Text, Any]) -> Text:
[(k, v) for k, v in entity_dict.items() if v is not None]
)

return f"[{entity_text}]{json.dumps(entity_dict)}"
return f"{json.dumps(entity_dict)}"

    @staticmethod
    def generate_entity(
        text: Text, entity: Union[Dict[Text, Any], List[Dict[Text, Any]]]
    ) -> Text:
        """Generates the markdown annotation text for an entity object.

        Accepts either a single entity dict or a list of entity dicts that all
        annotate the same token span (multiple entities for one token).

        Args:
            text: The full message text; the annotated surface form is sliced
                from it using the entity's start/end offsets.
            entity: A single entity dict, or a list of entity dicts sharing the
                same start/end offsets.

        Returns:
            For a list: `[entity_text][{...},{...}]` — a JSON-dict list
            annotation (short syntax is disallowed so each dict stays explicit).
            For a single dict: `[entity_text]` followed by either the short
            `(type)` form or a `{...}` dict, as decided by
            `generate_entity_attributes`.
        """
        if isinstance(entity, list):
            # All entities in the list cover the same span, so the surface
            # text is taken from the first one.
            entity_text = text[
                entity[0][ENTITY_ATTRIBUTE_START] : entity[0][ENTITY_ATTRIBUTE_END]
            ]
            return (
                f"[{entity_text}]["
                + ",".join(
                    [
                        TrainingDataWriter.generate_entity_attributes(
                            text=entity_text, entity=e, short_allowed=False
                        )
                        for e in entity
                    ]
                )
                + "]"
            )
        else:
            entity_text = text[
                entity[ENTITY_ATTRIBUTE_START] : entity[ENTITY_ATTRIBUTE_END]
            ]
            return f"[{entity_text}]" + TrainingDataWriter.generate_entity_attributes(
                text=entity_text, entity=entity, short_allowed=True
            )


class JsonTrainingDataReader(TrainingDataReader):
"""Add a docstring here.

Lint complains:
rasa/shared/nlu/training_data/formats/readerwriter.py:206:1:
D101 Missing docstring in public class
"""

def reads(self, s: Text, **kwargs: Any) -> "TrainingData":
"""Transforms string into json object and passes it on."""
js = json.loads(s)
Expand Down
10 changes: 6 additions & 4 deletions tests/nlu/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,12 @@ def test_determine_token_labels_throws_error():


def test_determine_token_labels_no_extractors():
    """If no extractor is given, entities might overlap."""
    # With extractors=None the function no longer raises; overlapping
    # entities are tolerated and the first matching label is returned.
    assert "direction" == determine_token_labels(
        CH_correct_segmentation[0], [CH_correct_entity, CH_wrong_entity], None
    )


def test_determine_token_labels_no_extractors_no_overlap():
Expand Down