Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add validation for slot mappings #7210

Merged
merged 11 commits into from
Nov 12, 2020
2 changes: 2 additions & 0 deletions changelog/7122.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add validations for [slot mappings](forms.mdx#slot-mappings).
If a slot mapping is not valid, an `InvalidDomain` error is raised.
21 changes: 9 additions & 12 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from enum import Enum
from typing import Text, List, Optional, Union, Any, Dict, Tuple, Set
import logging
import json

from rasa.core.actions import action
from rasa.core.actions.loops import LoopAction
from rasa.core.channels import OutputChannel
from rasa.shared.core.domain import Domain, InvalidDomain
from rasa.shared.core.domain import Domain, InvalidDomain, SlotMapping

from rasa.core.actions.action import ActionExecutionRejection, RemoteAction
from rasa.shared.core.constants import (
Expand All @@ -29,20 +28,18 @@
logger = logging.getLogger(__name__)


class SlotMapping(Enum):
FROM_ENTITY = 0
FROM_INTENT = 1
FROM_TRIGGER_INTENT = 2
FROM_TEXT = 3

def __str__(self) -> Text:
return self.name.lower()


class FormAction(LoopAction):
"""Action which implements and executes the form logic."""

def __init__(
self, form_name: Text, action_endpoint: Optional[EndpointConfig]
) -> None:
"""Creates a `FormAction`.

Args:
form_name: Name of the form.
action_endpoint: Endpoint to execute custom actions.
"""
self._form_name = form_name
self.action_endpoint = action_endpoint
# creating it requires domain, which we don't have in init
Expand Down
101 changes: 101 additions & 0 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
from enum import Enum
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -168,6 +169,9 @@ def from_dict(cls, data: Dict) -> "Domain":
additional_arguments = data.get("config", {})
session_config = cls._get_session_config(data.get(SESSION_CONFIG_KEY, {}))
intents = data.get(KEY_INTENTS, {})
forms = data.get(KEY_FORMS, {})

_validate_slot_mappings(forms)

return cls(
intents,
Expand Down Expand Up @@ -1458,3 +1462,100 @@ def slot_mapping_for_form(self, form_name: Text) -> Dict[Text, Any]:
The slot mapping or an empty dictionary in case no mapping was found.
"""
return self.forms.get(form_name, {})


class SlotMapping(Enum):
"""Defines the available slot mappings."""

FROM_ENTITY = 0
FROM_INTENT = 1
FROM_TRIGGER_INTENT = 2
FROM_TEXT = 3

def __str__(self) -> Text:
"""Returns a string representation of the object."""
return self.name.lower()

@staticmethod
def validate(mapping: Dict[Text, Any], form_name: Text, slot_name: Text) -> None:
"""Validates a slot mapping.

Args:
mapping: The mapping which is validated.
form_name: The name of the form which uses this slot mapping.
slot_name: The name of the slot which is mapped by this mapping.

Raises:
InvalidDomain: In case the slot mapping is not valid.
"""
if not isinstance(mapping, dict):
raise InvalidDomain(
f"Please make sure that the slot mappings for slot '{slot_name}' in "
f"your form '{form_name}' are valid dictionaries. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)

validations = {
str(SlotMapping.FROM_ENTITY): ["entity"],
str(SlotMapping.FROM_INTENT): ["value"],
str(SlotMapping.FROM_TRIGGER_INTENT): ["value"],
str(SlotMapping.FROM_TEXT): [],
}

mapping_type = mapping.get("type")
required_keys = validations.get(mapping_type)

if required_keys is None:
raise InvalidDomain(
f"Your form '{form_name}' uses an invalid slot mapping of type "
wochinge marked this conversation as resolved.
Show resolved Hide resolved
f"'{mapping_type}' for slot '{slot_name}'. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)

for required_key in required_keys:
if mapping.get(required_key) is None:
raise InvalidDomain(
f"You need to specify a value for the key "
f"'{required_key}' in the slot mapping of type '{mapping_type}' "
f"for slot '{slot_name}' in the form '{form_name}'. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)


def _validate_slot_mappings(forms: Union[Dict, List]) -> None:
if isinstance(forms, list):
if not all(isinstance(form_name, str) for form_name in forms):
raise InvalidDomain(
f"If you use the deprecated list syntax for forms, "
f"all form names have to be strings. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)

return

if not isinstance(forms, dict):
raise InvalidDomain("Forms have to be specified as dictionary.")

for form_name, slots in forms.items():
if slots is None:
continue

if not isinstance(slots, Dict):
raise InvalidDomain(
f"The slots for form '{form_name}' were specified "
f"as '{type(slots)}'. They need to be specified "
f"as dictionary. Please see {rasa.shared.constants.DOCS_URL_FORMS} "
f"for more information."
)

for slot_name, slot_mappings in slots.items():
if not isinstance(slot_mappings, list):
raise InvalidDomain(
f"The slot mappings for slot '{slot_name}' in "
f"form '{form_name}' have type "
f"'{type(slot_mappings)}'. It is required to "
f"provide a list of slot mappings. Please see "
f"{rasa.shared.constants.DOCS_URL_FORMS} for more information."
)
for slot_mapping in slot_mappings:
SlotMapping.validate(slot_mapping, form_name, slot_name)
16 changes: 0 additions & 16 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,22 +849,6 @@ def test_extract_requested_slot_from_entity(
assert slot_values == expected_slot_values


def test_invalid_slot_mapping():
form_name = "my_form"
form = FormAction(form_name, None)
slot_name = "test"
tracker = DialogueStateTracker.from_events(
"sender", [SlotSet(REQUESTED_SLOT, slot_name)]
)

domain = Domain.from_dict(
{"forms": {form_name: {slot_name: [{"type": "invalid"}]}}}
)

with pytest.raises(InvalidDomain):
form.extract_requested_slot(tracker, domain)


@pytest.mark.parametrize(
"some_other_slot_mapping, some_slot_mapping, entities, intent, expected_slot_values",
[
Expand Down
76 changes: 75 additions & 1 deletion tests/shared/core/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
IGNORE_ENTITIES_KEY,
State,
Domain,
KEY_FORMS,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered
Expand Down Expand Up @@ -477,7 +478,7 @@ def test_merge_domain_with_forms():
forms:
my_form3:
slot1:
type: from_text
- type: from_text
"""

domain_1 = Domain.from_yaml(test_yaml_1)
Expand Down Expand Up @@ -1068,3 +1069,76 @@ def test_get_featurized_entities():
featurized_entities = domain._get_featurized_entities(user_uttered)

assert featurized_entities == {"GPE", f"GPE{ENTITY_LABEL_SEPARATOR}destination"}


@pytest.mark.parametrize(
"domain_as_dict",
[
# No forms
{KEY_FORMS: {}},
# Deprecated but still support form syntax
{KEY_FORMS: ["my form", "other form"]},
# No slot mappings
{KEY_FORMS: {"my_form": None}},
{KEY_FORMS: {"my_form": {}}},
# Valid slot mappings
{
KEY_FORMS: {
"my_form": {"slot_x": [{"type": "from_entity", "entity": "name"}]}
}
},
{KEY_FORMS: {"my_form": {"slot_x": [{"type": "from_intent", "value": 5}]}}},
{
KEY_FORMS: {
"my_form": {"slot_x": [{"type": "from_intent", "value": "some value"}]}
}
},
{KEY_FORMS: {"my_form": {"slot_x": [{"type": "from_intent", "value": False}]}}},
{
KEY_FORMS: {
"my_form": {"slot_x": [{"type": "from_trigger_intent", "value": 5}]}
}
},
{
KEY_FORMS: {
"my_form": {
"slot_x": [{"type": "from_trigger_intent", "value": "some value"}]
}
}
},
{KEY_FORMS: {"my_form": {"slot_x": [{"type": "from_text"}]}}},
],
)
def test_valid_slot_mappings(domain_as_dict: Dict[Text, Any]):
Domain.from_dict(domain_as_dict)


@pytest.mark.parametrize(
"domain_as_dict",
[
# Wrong type for slot names
{KEY_FORMS: {"my_form": []}},
{KEY_FORMS: {"my_form": 5}},
# Slot mappings not defined as list
{KEY_FORMS: {"my_form": {"slot1": {}}}},
# Unknown mapping
{KEY_FORMS: {"my_form": {"slot1": [{"type": "my slot mapping"}]}}},
# Mappings with missing keys
{
KEY_FORMS: {
"my_form": {"slot1": [{"type": "from_entity", "intent": "greet"}]}
}
},
{KEY_FORMS: {"my_form": {"slot1": [{"type": "from_intent"}]}}},
{KEY_FORMS: {"my_form": {"slot1": [{"type": "from_intent", "value": None}]}}},
{KEY_FORMS: {"my_form": {"slot1": [{"type": "from_trigger_intent"}]}}},
{
KEY_FORMS: {
"my_form": {"slot1": [{"type": "from_trigger_intent", "value": None}]}
}
},
],
)
def test_form_invalid_mappings(domain_as_dict: Dict[Text, Any]):
with pytest.raises(InvalidDomain):
Domain.from_dict(domain_as_dict)