Skip to content

Commit

Permalink
Merge pull request #288 from RasaHQ/form-slots-validator-action
Browse files Browse the repository at this point in the history
FormSlotsValidatorAction added
  • Loading branch information
alwx authored Oct 5, 2020
2 parents df94fb8 + 9dc99c7 commit 57aca39
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 1 deletion.
25 changes: 25 additions & 0 deletions changelog/288.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Add the `FormSlotsValidatorAction` abstract class that can be used
to validate slots which were extracted by a Form.

Example:

```python
class MyFormValidationAction(FormValidationAction):
def name(self) -> Text:
return "some_form"

def validate_slot1(
self,
slot_value: Any,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
if slot_value == "correct_value":
return {
"slot1": "validated_value",
}
return {
"slot1": None,
}
```
59 changes: 59 additions & 0 deletions rasa_sdk/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from typing import Dict, Text, Any, List, Union, Optional, Tuple, cast

from abc import ABC
from rasa_sdk import utils
from rasa_sdk.events import SlotSet, EventType, ActiveLoop
from rasa_sdk.interfaces import Action, ActionExecutionRejection
Expand Down Expand Up @@ -691,3 +692,61 @@ def _get_entity_type_of_slot_to_fill(
return None

return entity_type


class FormValidationAction(Action, ABC):
"""An action that validates if every extracted slot is valid."""

async def validate(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> List[EventType]:
"""Validate slots by calling a validation function for each slot.
Args:
dispatcher: the dispatcher which is used to
send messages back to the user.
tracker: the conversation tracker for the current user.
domain: the bot's domain.
Returns:
`SlotSet` events for every validated slot.
"""
slots: Dict[Text, Any] = tracker.slots_to_validate()

for slot_name, slot_value in slots.items():
function_name = f"validate_{slot_name}"
validate_func = getattr(self, function_name, None)

if not validate_func:
warnings.warn(
f"Cannot validate `{slot_name}`: there is no validation function specified."
)
continue

if utils.is_coroutine_action(validate_func):
validation_output = await validate_func(
slot_value, dispatcher, tracker, domain
)
else:
validation_output = validate_func(
slot_value, dispatcher, tracker, domain
)

if validation_output:
slots.update(validation_output)
else:
warnings.warn(
f"Cannot validate `{slot_name}`: make sure the validation function returns the correct output."
)

return [SlotSet(slot, value) for slot, value in slots.items()]

async def run(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> List[EventType]:
return await self.validate(dispatcher, tracker, domain)
128 changes: 127 additions & 1 deletion tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
from typing import Type, Text, Dict, Any, List, Optional

from rasa_sdk import Tracker, ActionExecutionRejection
from rasa_sdk.types import DomainDict
from rasa_sdk.events import SlotSet, ActiveLoop
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.forms import FormAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY
from rasa_sdk.forms import (
FormAction,
FormValidationAction,
REQUESTED_SLOT,
LOOP_INTERRUPTED_KEY,
)


def test_extract_requested_slot_default():
Expand Down Expand Up @@ -1493,3 +1499,123 @@ async def test_submit(form_class: Type[FormAction]):
def test_form_deprecation():
with pytest.warns(FutureWarning):
FormAction()


class TestFormValidationAction(FormValidationAction):
def name(self) -> Text:
return "some_form"

def validate_slot1(
self,
slot_value: Any,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
if slot_value == "correct_value":
return {
"slot1": "validated_value",
}
return {
"slot1": None,
}

def validate_slot2(
self,
slot_value: Any,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
if slot_value == "correct_value":
return {
"slot2": "validated_value",
}
return {
"slot2": None,
}

async def validate_slot3(
self,
slot_value: Any,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
if slot_value == "correct_value":
return {
"slot3": "validated_value",
}
# this function doesn't return anything when the slot value is incorrect


async def test_form_validation_action():
form = TestFormValidationAction()

# tracker with active form
tracker = Tracker(
"default",
{},
{},
[SlotSet("slot1", "correct_value"), SlotSet("slot2", "incorrect_value")],
False,
None,
{"name": "some_form", "is_interrupted": False, "rejected": False},
"action_listen",
)

dispatcher = CollectingDispatcher()
events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None)
assert events == [
SlotSet("slot2", None),
SlotSet("slot1", "validated_value"),
]


async def test_form_validation_action_async():
form = TestFormValidationAction()

# tracker with active form
tracker = Tracker(
"default",
{},
{},
[SlotSet("slot3", "correct_value")],
False,
None,
{"name": "some_form", "is_interrupted": False, "rejected": False},
"action_listen",
)

dispatcher = CollectingDispatcher()
events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None)
assert events == [SlotSet("slot3", "validated_value")]


async def test_form_validation_without_validate_function():
form = TestFormValidationAction()

# tracker with active form
tracker = Tracker(
"default",
{},
{},
[
SlotSet("slot1", "correct_value"),
SlotSet("slot2", "incorrect_value"),
SlotSet("slot3", "some_value"),
],
False,
None,
{"name": "some_form", "is_interrupted": False, "rejected": False},
"action_listen",
)

dispatcher = CollectingDispatcher()
with pytest.warns(UserWarning):
events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None)
assert events == [
SlotSet("slot3", "some_value"),
SlotSet("slot2", None),
SlotSet("slot1", "validated_value"),
]

0 comments on commit 57aca39

Please sign in to comment.