Skip to content

Commit

Permalink
Merge pull request #7210 from RasaHQ/slot_mapping_validation
Browse files Browse the repository at this point in the history
add validation for slot mappings
  • Loading branch information
wochinge authored Nov 12, 2020
2 parents d81b5fd + b9e39f0 commit 8f248f1
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 29 deletions.
2 changes: 2 additions & 0 deletions changelog/7122.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add validations for [slot mappings](forms.mdx#slot-mappings).
If a slot mapping is not valid, an `InvalidDomain` error is raised.
21 changes: 9 additions & 12 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from enum import Enum
from typing import Text, List, Optional, Union, Any, Dict, Tuple, Set
import logging
import json

from rasa.core.actions import action
from rasa.core.actions.loops import LoopAction
from rasa.core.channels import OutputChannel
from rasa.shared.core.domain import Domain, InvalidDomain
from rasa.shared.core.domain import Domain, InvalidDomain, SlotMapping

from rasa.core.actions.action import ActionExecutionRejection, RemoteAction
from rasa.shared.core.constants import (
Expand All @@ -29,20 +28,18 @@
logger = logging.getLogger(__name__)


class SlotMapping(Enum):
FROM_ENTITY = 0
FROM_INTENT = 1
FROM_TRIGGER_INTENT = 2
FROM_TEXT = 3

def __str__(self) -> Text:
return self.name.lower()


class FormAction(LoopAction):
"""Action which implements and executes the form logic."""

def __init__(
self, form_name: Text, action_endpoint: Optional[EndpointConfig]
) -> None:
"""Creates a `FormAction`.
Args:
form_name: Name of the form.
action_endpoint: Endpoint to execute custom actions.
"""
self._form_name = form_name
self.action_endpoint = action_endpoint
# creating it requires domain, which we don't have in init
Expand Down
101 changes: 101 additions & 0 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
from enum import Enum
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -168,6 +169,9 @@ def from_dict(cls, data: Dict) -> "Domain":
additional_arguments = data.get("config", {})
session_config = cls._get_session_config(data.get(SESSION_CONFIG_KEY, {}))
intents = data.get(KEY_INTENTS, {})
forms = data.get(KEY_FORMS, {})

_validate_slot_mappings(forms)

return cls(
intents,
Expand Down Expand Up @@ -1458,3 +1462,100 @@ def slot_mapping_for_form(self, form_name: Text) -> Dict[Text, Any]:
The slot mapping or an empty dictionary in case no mapping was found.
"""
return self.forms.get(form_name, {})


class SlotMapping(Enum):
"""Defines the available slot mappings."""

FROM_ENTITY = 0
FROM_INTENT = 1
FROM_TRIGGER_INTENT = 2
FROM_TEXT = 3

def __str__(self) -> Text:
"""Returns a string representation of the object."""
return self.name.lower()

@staticmethod
def validate(mapping: Dict[Text, Any], form_name: Text, slot_name: Text) -> None:
"""Validates a slot mapping.
Args:
mapping: The mapping which is validated.
form_name: The name of the form which uses this slot mapping.
slot_name: The name of the slot which is mapped by this mapping.
Raises:
InvalidDomain: In case the slot mapping is not valid.
"""
if not isinstance(mapping, dict):
raise InvalidDomain(
f"Please make sure that the slot mappings for slot '{slot_name}' in "
f"your form '{form_name}' are valid dictionaries. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)

validations = {
str(SlotMapping.FROM_ENTITY): ["entity"],
str(SlotMapping.FROM_INTENT): ["value"],
str(SlotMapping.FROM_TRIGGER_INTENT): ["value"],
str(SlotMapping.FROM_TEXT): [],
}

mapping_type = mapping.get("type")
required_keys = validations.get(mapping_type)

if required_keys is None:
raise InvalidDomain(
f"Your form '{form_name}' uses an invalid slot mapping of type "
f"'{mapping_type}' for slot '{slot_name}'. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)

for required_key in required_keys:
if mapping.get(required_key) is None:
raise InvalidDomain(
f"You need to specify a value for the key "
f"'{required_key}' in the slot mapping of type '{mapping_type}' "
f"for slot '{slot_name}' in the form '{form_name}'. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)


def _validate_slot_mappings(forms: Union[Dict, List]) -> None:
if isinstance(forms, list):
if not all(isinstance(form_name, str) for form_name in forms):
raise InvalidDomain(
f"If you use the deprecated list syntax for forms, "
f"all form names have to be strings. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)

return

if not isinstance(forms, dict):
raise InvalidDomain("Forms have to be specified as dictionary.")

for form_name, slots in forms.items():
if slots is None:
continue

if not isinstance(slots, Dict):
raise InvalidDomain(
f"The slots for form '{form_name}' were specified "
f"as '{type(slots)}'. They need to be specified "
f"as dictionary. Please see {rasa.shared.constants.DOCS_URL_FORMS} "
f"for more information."
)

for slot_name, slot_mappings in slots.items():
if not isinstance(slot_mappings, list):
raise InvalidDomain(
f"The slot mappings for slot '{slot_name}' in "
f"form '{form_name}' have type "
f"'{type(slot_mappings)}'. It is required to "
f"provide a list of slot mappings. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)
for slot_mapping in slot_mappings:
SlotMapping.validate(slot_mapping, form_name, slot_name)
16 changes: 0 additions & 16 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,22 +849,6 @@ def test_extract_requested_slot_from_entity(
assert slot_values == expected_slot_values


def test_invalid_slot_mapping():
form_name = "my_form"
form = FormAction(form_name, None)
slot_name = "test"
tracker = DialogueStateTracker.from_events(
"sender", [SlotSet(REQUESTED_SLOT, slot_name)]
)

domain = Domain.from_dict(
{"forms": {form_name: {slot_name: [{"type": "invalid"}]}}}
)

with pytest.raises(InvalidDomain):
form.extract_requested_slot(tracker, domain)


@pytest.mark.parametrize(
"some_other_slot_mapping, some_slot_mapping, entities, intent, expected_slot_values",
[
Expand Down
76 changes: 75 additions & 1 deletion tests/shared/core/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
IGNORE_ENTITIES_KEY,
State,
Domain,
KEY_FORMS,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered
Expand Down Expand Up @@ -477,7 +478,7 @@ def test_merge_domain_with_forms():
forms:
my_form3:
slot1:
type: from_text
- type: from_text
"""

domain_1 = Domain.from_yaml(test_yaml_1)
Expand Down Expand Up @@ -1068,3 +1069,76 @@ def test_get_featurized_entities():
featurized_entities = domain._get_featurized_entities(user_uttered)

assert featurized_entities == {"GPE", f"GPE{ENTITY_LABEL_SEPARATOR}destination"}


@pytest.mark.parametrize(
"domain_as_dict",
[
# No forms
{KEY_FORMS: {}},
# Deprecated but still support form syntax
{KEY_FORMS: ["my form", "other form"]},
# No slot mappings
{KEY_FORMS: {"my_form": None}},
{KEY_FORMS: {"my_form": {}}},
# Valid slot mappings
{
KEY_FORMS: {
"my_form": {"slot_x": [{"type": "from_entity", "entity": "name"}]}
}
},
{KEY_FORMS: {"my_form": {"slot_x": [{"type": "from_intent", "value": 5}]}}},
{
KEY_FORMS: {
"my_form": {"slot_x": [{"type": "from_intent", "value": "some value"}]}
}
},
{KEY_FORMS: {"my_form": {"slot_x": [{"type": "from_intent", "value": False}]}}},
{
KEY_FORMS: {
"my_form": {"slot_x": [{"type": "from_trigger_intent", "value": 5}]}
}
},
{
KEY_FORMS: {
"my_form": {
"slot_x": [{"type": "from_trigger_intent", "value": "some value"}]
}
}
},
{KEY_FORMS: {"my_form": {"slot_x": [{"type": "from_text"}]}}},
],
)
def test_valid_slot_mappings(domain_as_dict: Dict[Text, Any]):
Domain.from_dict(domain_as_dict)


@pytest.mark.parametrize(
"domain_as_dict",
[
# Wrong type for slot names
{KEY_FORMS: {"my_form": []}},
{KEY_FORMS: {"my_form": 5}},
# Slot mappings not defined as list
{KEY_FORMS: {"my_form": {"slot1": {}}}},
# Unknown mapping
{KEY_FORMS: {"my_form": {"slot1": [{"type": "my slot mapping"}]}}},
# Mappings with missing keys
{
KEY_FORMS: {
"my_form": {"slot1": [{"type": "from_entity", "intent": "greet"}]}
}
},
{KEY_FORMS: {"my_form": {"slot1": [{"type": "from_intent"}]}}},
{KEY_FORMS: {"my_form": {"slot1": [{"type": "from_intent", "value": None}]}}},
{KEY_FORMS: {"my_form": {"slot1": [{"type": "from_trigger_intent"}]}}},
{
KEY_FORMS: {
"my_form": {"slot1": [{"type": "from_trigger_intent", "value": None}]}
}
},
],
)
def test_form_invalid_mappings(domain_as_dict: Dict[Text, Any]):
with pytest.raises(InvalidDomain):
Domain.from_dict(domain_as_dict)

0 comments on commit 8f248f1

Please sign in to comment.