
Commit 978b220

Ported from 2.8.x
As the original didn't change, the changes from RasaHQ#10394 are applied without change.

1 parent 65ab974

File tree

3 files changed: +50 -22 lines

rasa/nlu/test.py (+1 -1)

@@ -1096,7 +1096,7 @@ def determine_entity_for_token(
 def do_extractors_support_overlap(extractors: Optional[Set[Text]]) -> bool:
     """Checks if extractors support overlapping entities"""
     if extractors is None:
-        return False
+        return True

     from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
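
For quick reference, a minimal standalone sketch of the behaviour this one-line change flips. It mirrors only the hunk shown above; the branch for a non-None extractor set is not part of this diff and is left out:

from typing import Optional, Set, Text


def do_extractors_support_overlap(extractors: Optional[Set[Text]]) -> bool:
    """Checks if extractors support overlapping entities"""
    if extractors is None:
        return True  # before this commit: return False
    # The handling of a concrete extractor set is unchanged by this commit
    # and therefore omitted from this sketch.
    raise NotImplementedError


# With no extractor set given, overlap is now treated as supported.
assert do_extractors_support_overlap(None) is True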

rasa/shared/nlu/training_data/entities_parser.py (+46 -17)

@@ -19,12 +19,15 @@
 GROUP_ENTITY_DICT = "entity_dict"
 GROUP_ENTITY_TEXT = "entity_text"
 GROUP_COMPLETE_MATCH = 0
+GROUP_ENTITY_DICT_LIST = "list_entity_dicts"

-# regex for: `[entity_text]((entity_type(:entity_synonym)?)|{entity_dict})`
+# regex for: `[entity_text]((entity_type(:entity_synonym)?)|{entity_dict}|[list_entity_dicts])`  # noqa: E501, W505
 ENTITY_REGEX = re.compile(
-    r"\[(?P<entity_text>[^\]]+?)\](\((?P<entity>[^:)]+?)(?:\:(?P<value>[^)]+))?\)|\{(?P<entity_dict>[^}]+?)\})"  # noqa: E501, W505
+    r"\[(?P<entity_text>[^\]]+?)\](\((?P<entity>[^:)]+?)(?:\:(?P<value>[^)]+))?\)|\{(?P<entity_dict>[^}]+?)\}|\[(?P<list_entity_dicts>.*?)\])"  # noqa: E501, W505
 )

+SINGLE_ENTITY_DICT = re.compile(r"{(?P<entity_dict>[^}]+?)\}")
+

 class EntityAttributes(NamedTuple):
     """Attributes of an entity defined in markdown data."""

@@ -50,21 +53,47 @@ def find_entities_in_training_example(example: Text) -> List[Dict[Text, Any]]:
     offset = 0

     for match in re.finditer(ENTITY_REGEX, example):
-        entity_attributes = extract_entity_attributes(match)
-
-        start_index = match.start() - offset
-        end_index = start_index + len(entity_attributes.text)
-        offset += len(match.group(0)) - len(entity_attributes.text)
-
-        entity = rasa.shared.nlu.training_data.util.build_entity(
-            start_index,
-            end_index,
-            entity_attributes.value,
-            entity_attributes.type,
-            entity_attributes.role,
-            entity_attributes.group,
-        )
-        entities.append(entity)
+        if match.groupdict()[GROUP_ENTITY_DICT] or match.groupdict()[GROUP_ENTITY_TYPE]:
+            entity_attributes = extract_entity_attributes(match)
+
+            start_index = match.start() - offset
+            end_index = start_index + len(entity_attributes.text)
+            offset += len(match.group(0)) - len(entity_attributes.text)
+
+            entity = rasa.shared.nlu.training_data.util.build_entity(
+                start_index,
+                end_index,
+                entity_attributes.value,
+                entity_attributes.type,
+                entity_attributes.role,
+                entity_attributes.group,
+            )
+            entities.append(entity)
+        else:
+            entity_text = match.groupdict()[GROUP_ENTITY_TEXT]
+            # iterate over the list
+
+            start_index = match.start() - offset
+            end_index = start_index + len(entity_text)
+            offset += len(match.group(0)) - len(entity_text)
+
+            for match_inner in re.finditer(
+                SINGLE_ENTITY_DICT, match.groupdict()[GROUP_ENTITY_DICT_LIST]
+            ):
+
+                entity_attributes = extract_entity_attributes_from_dict(
+                    entity_text=entity_text, match=match_inner
+                )
+
+                entity = rasa.shared.nlu.training_data.util.build_entity(
+                    start_index,
+                    end_index,
+                    entity_attributes.value,
+                    entity_attributes.type,
+                    entity_attributes.role,
+                    entity_attributes.group,
+                )
+                entities.append(entity)

     return entities
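
To illustrate the new `[entity_text][list_entity_dicts]` branch, here is a small self-contained sketch that runs the two patterns added above against a hypothetical annotation. The example sentence and entity names are made up for illustration; only the shape `[text][{...}, {...}]` is inferred from the regex and the parser loop:

import re

# Patterns as added in this commit.
ENTITY_REGEX = re.compile(
    r"\[(?P<entity_text>[^\]]+?)\](\((?P<entity>[^:)]+?)(?:\:(?P<value>[^)]+))?\)|\{(?P<entity_dict>[^}]+?)\}|\[(?P<list_entity_dicts>.*?)\])"
)
SINGLE_ENTITY_DICT = re.compile(r"{(?P<entity_dict>[^}]+?)\}")

# Hypothetical annotation: one text span carrying a list of two entity dicts.
example = 'I am flying to [Berlin][{"entity": "city"}, {"entity": "location"}]'

match = ENTITY_REGEX.search(example)
assert match is not None
print(match.group("entity_text"))        # Berlin
print(match.group("list_entity_dicts"))  # {"entity": "city"}, {"entity": "location"}

# The outer match is then split into one inner match per dict, mirroring the
# new else-branch of find_entities_in_training_example.
for inner in SINGLE_ENTITY_DICT.finditer(match.group("list_entity_dicts")):
    print(inner.group("entity_dict"))    # "entity": "city"  |  "entity": "location"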

tests/nlu/test_evaluation.py (+3 -4)

@@ -205,10 +205,9 @@ def test_determine_token_labels_throws_error():


 def test_determine_token_labels_no_extractors():
-    with pytest.raises(ValueError):
-        determine_token_labels(
-            CH_correct_segmentation[0], [CH_correct_entity, CH_wrong_entity], None
-        )
+    assert "direction" == determine_token_labels(
+        CH_correct_segmentation[0], [CH_correct_entity, CH_wrong_entity], None
+    )


 def test_determine_token_labels_no_extractors_no_overlap():
