diff --git a/changelog/288.improvement.md b/changelog/288.improvement.md new file mode 100644 index 000000000..ef345b237 --- /dev/null +++ b/changelog/288.improvement.md @@ -0,0 +1,25 @@ +Add the `FormSlotsValidatorAction` abstract class that can be used +to validate slots which were extracted by a Form. + +Example: + +```python +class MyFormValidationAction(FormValidationAction): + def name(self) -> Text: + return "some_form" + + def validate_slot1( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot1": "validated_value", + } + return { + "slot1": None, + } +``` \ No newline at end of file diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 84dac4e73..82df25534 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -3,6 +3,7 @@ import warnings from typing import Dict, Text, Any, List, Union, Optional, Tuple, cast +from abc import ABC from rasa_sdk import utils from rasa_sdk.events import SlotSet, EventType, ActiveLoop from rasa_sdk.interfaces import Action, ActionExecutionRejection @@ -691,3 +692,61 @@ def _get_entity_type_of_slot_to_fill( return None return entity_type + + +class FormValidationAction(Action, ABC): + """An action that validates if every extracted slot is valid.""" + + async def validate( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> List[EventType]: + """Validate slots by calling a validation function for each slot. + + Args: + dispatcher: the dispatcher which is used to + send messages back to the user. + tracker: the conversation tracker for the current user. + domain: the bot's domain. + Returns: + `SlotSet` events for every validated slot. + """ + slots: Dict[Text, Any] = tracker.slots_to_validate() + + for slot_name, slot_value in slots.items(): + function_name = f"validate_{slot_name}" + validate_func = getattr(self, function_name, None) + + if not validate_func: + warnings.warn( + f"Cannot validate `{slot_name}`: there is no validation function specified." + ) + continue + + if utils.is_coroutine_action(validate_func): + validation_output = await validate_func( + slot_value, dispatcher, tracker, domain + ) + else: + validation_output = validate_func( + slot_value, dispatcher, tracker, domain + ) + + if validation_output: + slots.update(validation_output) + else: + warnings.warn( + f"Cannot validate `{slot_name}`: make sure the validation function returns the correct output." + ) + + return [SlotSet(slot, value) for slot, value in slots.items()] + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> List[EventType]: + return await self.validate(dispatcher, tracker, domain) diff --git a/tests/test_forms.py b/tests/test_forms.py index b7a7d15db..6fdf01c11 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -3,9 +3,15 @@ from typing import Type, Text, Dict, Any, List, Optional from rasa_sdk import Tracker, ActionExecutionRejection +from rasa_sdk.types import DomainDict from rasa_sdk.events import SlotSet, ActiveLoop from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.forms import FormAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY +from rasa_sdk.forms import ( + FormAction, + FormValidationAction, + REQUESTED_SLOT, + LOOP_INTERRUPTED_KEY, +) def test_extract_requested_slot_default(): @@ -1493,3 +1499,123 @@ async def test_submit(form_class: Type[FormAction]): def test_form_deprecation(): with pytest.warns(FutureWarning): FormAction() + + +class TestFormValidationAction(FormValidationAction): + def name(self) -> Text: + return "some_form" + + def validate_slot1( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot1": "validated_value", + } + return { + "slot1": None, + } + + def validate_slot2( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot2": "validated_value", + } + return { + "slot2": None, + } + + async def validate_slot3( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot3": "validated_value", + } + # this function doesn't return anything when the slot value is incorrect + + +async def test_form_validation_action(): + form = TestFormValidationAction() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("slot1", "correct_value"), SlotSet("slot2", "incorrect_value")], + False, + None, + {"name": "some_form", "is_interrupted": False, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + assert events == [ + SlotSet("slot2", None), + SlotSet("slot1", "validated_value"), + ] + + +async def test_form_validation_action_async(): + form = TestFormValidationAction() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("slot3", "correct_value")], + False, + None, + {"name": "some_form", "is_interrupted": False, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + assert events == [SlotSet("slot3", "validated_value")] + + +async def test_form_validation_without_validate_function(): + form = TestFormValidationAction() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [ + SlotSet("slot1", "correct_value"), + SlotSet("slot2", "incorrect_value"), + SlotSet("slot3", "some_value"), + ], + False, + None, + {"name": "some_form", "is_interrupted": False, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + with pytest.warns(UserWarning): + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + assert events == [ + SlotSet("slot3", "some_value"), + SlotSet("slot2", None), + SlotSet("slot1", "validated_value"), + ]