From 83819da7c2aa1878bad7b4b1b3a0cfdac3062aca Mon Sep 17 00:00:00 2001 From: alwx Date: Thu, 1 Oct 2020 15:51:09 +0200 Subject: [PATCH 01/16] FormSlotsValidatorAction added --- rasa_sdk/forms.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 84dac4e73..212eacae1 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -691,3 +691,45 @@ def _get_entity_type_of_slot_to_fill( return None return entity_type + + +class FormSlotsValidatorAction(Action): + def name(self) -> Text: + """Unique identifier of this action.""" + + raise NotImplementedError("An action must implement a name") + + def run( + self, dispatcher, tracker: Tracker, domain: Dict + ) -> List[EventType]: + """Execute the side effects of this action. + + Args: + dispatcher: the dispatcher which is used to + send messages back to the user. Use + ``dispatcher.utter_message()`` for sending messages. + tracker: the state tracker for the current + user. You can access slot values using + ``tracker.get_slot(slot_name)``, the most recent user message + is ``tracker.latest_message.text`` and any other + ``rasa_sdk.Tracker`` property. + domain: the bot's domain + Returns: + A dictionary of ``rasa_sdk.events.Event`` instances that is + returned through the endpoint + """ + + slots_to_validate: Dict[Text, Any] = tracker.form_slots_to_validate() + validation_events = [] + + for slot_name, slot_value in slots_to_validate.items(): + function_name = f"validate_{slot_name}" + fn = getattr(self, function_name) + if fn(tracker, domain, slot_value): + validation_events.append(SlotSet(slot_name, slot_value)) + else: + # Return a `SlotSet` event with value `None` to indicate that this + # slot still needs to be filled. + validation_events.append(SlotSet(slot_name, None)) + + return validation_events From a8ee5b3548c27956f2bebbdda5ba1ca8fc5a645f Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 10:24:50 +0200 Subject: [PATCH 02/16] FormSlotsValidatorAction tests --- rasa_sdk/forms.py | 23 ++++-------- rasa_sdk/interfaces.py | 6 +-- rasa_sdk/knowledge_base/actions.py | 17 +++------ tests/test_actions.py | 10 +---- tests/test_forms.py | 59 +++++++++++++++++++++++++++++- 5 files changed, 75 insertions(+), 40 deletions(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 212eacae1..f5bfd70e6 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -672,8 +672,7 @@ def __str__(self) -> Text: return f"FormAction('{self.name()}')" def _get_entity_type_of_slot_to_fill( - self, - slot_to_fill: Optional[Text], + self, slot_to_fill: Optional[Text], ) -> Optional[Text]: if not slot_to_fill: return None @@ -694,29 +693,23 @@ def _get_entity_type_of_slot_to_fill( class FormSlotsValidatorAction(Action): + """An action that validates if every extracted slot is valid.""" + def name(self) -> Text: """Unique identifier of this action.""" raise NotImplementedError("An action must implement a name") - def run( - self, dispatcher, tracker: Tracker, domain: Dict - ) -> List[EventType]: + async def run(self, dispatcher, tracker: "Tracker", domain: Dict) -> List[EventType]: """Execute the side effects of this action. Args: dispatcher: the dispatcher which is used to - send messages back to the user. Use - ``dispatcher.utter_message()`` for sending messages. - tracker: the state tracker for the current - user. You can access slot values using - ``tracker.get_slot(slot_name)``, the most recent user message - is ``tracker.latest_message.text`` and any other - ``rasa_sdk.Tracker`` property. - domain: the bot's domain + send messages back to the user. + tracker: the state tracker for the current user. + domain: the bot's domain. Returns: - A dictionary of ``rasa_sdk.events.Event`` instances that is - returned through the endpoint + A dictionary of ``rasa_sdk.events.Event`` instances. """ slots_to_validate: Dict[Text, Any] = tracker.form_slots_to_validate() diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index a19ddd267..5f7591008 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -245,6 +245,7 @@ def slots_to_validate(self) -> Dict[Text, Any]: slots: Dict[Text, Any] = {} + print(self.events) for event in reversed(self.events): # The `FormAction` in Rasa Open Source will append all slot candidates # at the end of the tracker events. @@ -267,10 +268,7 @@ def name(self) -> Text: raise NotImplementedError("An action must implement a name") async def run( - self, - dispatcher, - tracker: Tracker, - domain: "DomainDict", + self, dispatcher, tracker: Tracker, domain: "DomainDict", ) -> List[Dict[Text, Any]]: """Execute the side effects of this action. diff --git a/rasa_sdk/knowledge_base/actions.py b/rasa_sdk/knowledge_base/actions.py index 0176cc961..c8a28e201 100644 --- a/rasa_sdk/knowledge_base/actions.py +++ b/rasa_sdk/knowledge_base/actions.py @@ -92,10 +92,8 @@ async def utter_objects( if utils.is_coroutine_action( self.knowledge_base.get_representation_function_of_object ): - repr_function = ( - await self.knowledge_base.get_representation_function_of_object( - object_type - ) + repr_function = await self.knowledge_base.get_representation_function_of_object( + object_type ) else: # see https://github.com/python/mypy/issues/5206 @@ -114,10 +112,7 @@ async def utter_objects( ) async def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: "DomainDict", + self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: "DomainDict", ) -> List[Dict[Text, Any]]: """ Executes this action. If the user ask a question about an attribute, @@ -273,10 +268,8 @@ async def _query_attribute( if utils.is_coroutine_action( self.knowledge_base.get_representation_function_of_object ): - repr_function = ( - await self.knowledge_base.get_representation_function_of_object( - object_type - ) + repr_function = await self.knowledge_base.get_representation_function_of_object( + object_type ) else: # see https://github.com/python/mypy/issues/5206 diff --git a/tests/test_actions.py b/tests/test_actions.py index 3be07e750..de6d82644 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -12,10 +12,7 @@ def name(cls) -> Text: return "custom_async_action" async def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, + self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: DomainDict, ) -> List[Dict[Text, Any]]: return [SlotSet("test", "foo"), SlotSet("test2", "boo")] @@ -25,10 +22,7 @@ def name(cls) -> Text: return "custom_action" def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, + self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: DomainDict, ) -> List[Dict[Text, Any]]: return [SlotSet("test", "bar")] diff --git a/tests/test_forms.py b/tests/test_forms.py index b7a7d15db..f5a0668ec 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -5,7 +5,7 @@ from rasa_sdk import Tracker, ActionExecutionRejection from rasa_sdk.events import SlotSet, ActiveLoop from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.forms import FormAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY +from rasa_sdk.forms import FormAction, FormSlotsValidatorAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY def test_extract_requested_slot_default(): @@ -1493,3 +1493,60 @@ async def test_submit(form_class: Type[FormAction]): def test_form_deprecation(): with pytest.warns(FutureWarning): FormAction() + + +class TestFormSlotValidator(FormSlotsValidatorAction): + def name(self) -> Text: + return "some_form" + + @staticmethod + def validate_slot1(tracker: Tracker, domain: Dict, slot_value: Any) -> bool: + return slot_value == "correct_value" + + @staticmethod + def validate_slot2(tracker: Tracker, domain: Dict, slot_value: Any) -> bool: + return slot_value == "correct_value" + + +async def test_form_slot_validator(): + form = TestFormSlotValidator() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("slot1", "correct_value"), SlotSet("slot2", "incorrect_value")], + False, + None, + {"name": "some_form", "validate": True, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + # check that the form was activated and validation was performed + assert events == [ + SlotSet("slot2", None), + SlotSet("slot1", "correct_value"), + ] + + +async def test_form_slot_validator_missing_method(): + form = TestFormSlotValidator() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("slot1", "correct_value"), SlotSet("slot2", "incorrect_value"), SlotSet("slot3", "some_value")], + False, + None, + {"name": "some_form", "validate": True, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + with pytest.raises(AttributeError): + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) \ No newline at end of file From 3474274884180c2e3f2d6f1d2069d58b1a663f9d Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 10:26:23 +0200 Subject: [PATCH 03/16] Black reformatting; minor test name change --- rasa_sdk/forms.py | 4 +++- tests/test_forms.py | 10 +++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index f5bfd70e6..524aeb5f8 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -700,7 +700,9 @@ def name(self) -> Text: raise NotImplementedError("An action must implement a name") - async def run(self, dispatcher, tracker: "Tracker", domain: Dict) -> List[EventType]: + async def run( + self, dispatcher, tracker: "Tracker", domain: Dict + ) -> List[EventType]: """Execute the side effects of this action. Args: diff --git a/tests/test_forms.py b/tests/test_forms.py index f5a0668ec..e64a88d39 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1532,7 +1532,7 @@ async def test_form_slot_validator(): ] -async def test_form_slot_validator_missing_method(): +async def test_form_slot_validator_attribute_error(): form = TestFormSlotValidator() # tracker with active form @@ -1540,7 +1540,11 @@ async def test_form_slot_validator_missing_method(): "default", {}, {}, - [SlotSet("slot1", "correct_value"), SlotSet("slot2", "incorrect_value"), SlotSet("slot3", "some_value")], + [ + SlotSet("slot1", "correct_value"), + SlotSet("slot2", "incorrect_value"), + SlotSet("slot3", "some_value"), + ], False, None, {"name": "some_form", "validate": True, "rejected": False}, @@ -1549,4 +1553,4 @@ async def test_form_slot_validator_missing_method(): dispatcher = CollectingDispatcher() with pytest.raises(AttributeError): - events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) \ No newline at end of file + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) From ea2fc8ed4246552472cc2c9f09d9105f260d7a66 Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 10:30:31 +0200 Subject: [PATCH 04/16] Changelog entry --- changelog/288.improvement.md | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changelog/288.improvement.md diff --git a/changelog/288.improvement.md b/changelog/288.improvement.md new file mode 100644 index 000000000..d832bae19 --- /dev/null +++ b/changelog/288.improvement.md @@ -0,0 +1,2 @@ +Add the `FormSlotsValidatorAction` abstract class that can be used +to validate slots which were extracted by a Form. \ No newline at end of file From b470378322b5de3b3187ed5678a4bef214bb134e Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 10:35:48 +0200 Subject: [PATCH 05/16] Test fix --- tests/test_forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_forms.py b/tests/test_forms.py index e64a88d39..58fc91215 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1553,4 +1553,4 @@ async def test_form_slot_validator_attribute_error(): dispatcher = CollectingDispatcher() with pytest.raises(AttributeError): - events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) From 6ce562ba2e7884654f15ed6a44a102884e1fd338 Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 11:13:49 +0200 Subject: [PATCH 06/16] Black reformatting --- rasa_sdk/forms.py | 3 ++- rasa_sdk/interfaces.py | 5 ++++- rasa_sdk/knowledge_base/actions.py | 17 ++++++++++++----- tests/test_actions.py | 10 ++++++++-- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 524aeb5f8..084f74ecc 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -672,7 +672,8 @@ def __str__(self) -> Text: return f"FormAction('{self.name()}')" def _get_entity_type_of_slot_to_fill( - self, slot_to_fill: Optional[Text], + self, + slot_to_fill: Optional[Text], ) -> Optional[Text]: if not slot_to_fill: return None diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index 5f7591008..aea3c829c 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -268,7 +268,10 @@ def name(self) -> Text: raise NotImplementedError("An action must implement a name") async def run( - self, dispatcher, tracker: Tracker, domain: "DomainDict", + self, + dispatcher, + tracker: Tracker, + domain: "DomainDict", ) -> List[Dict[Text, Any]]: """Execute the side effects of this action. diff --git a/rasa_sdk/knowledge_base/actions.py b/rasa_sdk/knowledge_base/actions.py index c8a28e201..0176cc961 100644 --- a/rasa_sdk/knowledge_base/actions.py +++ b/rasa_sdk/knowledge_base/actions.py @@ -92,8 +92,10 @@ async def utter_objects( if utils.is_coroutine_action( self.knowledge_base.get_representation_function_of_object ): - repr_function = await self.knowledge_base.get_representation_function_of_object( - object_type + repr_function = ( + await self.knowledge_base.get_representation_function_of_object( + object_type + ) ) else: # see https://github.com/python/mypy/issues/5206 @@ -112,7 +114,10 @@ async def utter_objects( ) async def run( - self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: "DomainDict", + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: "DomainDict", ) -> List[Dict[Text, Any]]: """ Executes this action. If the user ask a question about an attribute, @@ -268,8 +273,10 @@ async def _query_attribute( if utils.is_coroutine_action( self.knowledge_base.get_representation_function_of_object ): - repr_function = await self.knowledge_base.get_representation_function_of_object( - object_type + repr_function = ( + await self.knowledge_base.get_representation_function_of_object( + object_type + ) ) else: # see https://github.com/python/mypy/issues/5206 diff --git a/tests/test_actions.py b/tests/test_actions.py index de6d82644..3be07e750 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -12,7 +12,10 @@ def name(cls) -> Text: return "custom_async_action" async def run( - self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: DomainDict, + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, ) -> List[Dict[Text, Any]]: return [SlotSet("test", "foo"), SlotSet("test2", "boo")] @@ -22,7 +25,10 @@ def name(cls) -> Text: return "custom_action" def run( - self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: DomainDict, + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, ) -> List[Dict[Text, Any]]: return [SlotSet("test", "bar")] From 7ff7b94059d882a1d7417c0f059bf1797e2a8271 Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 11:21:36 +0200 Subject: [PATCH 07/16] Types update --- rasa_sdk/forms.py | 5 ++++- tests/test_forms.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 084f74ecc..ef7e02da0 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -702,7 +702,10 @@ def name(self) -> Text: raise NotImplementedError("An action must implement a name") async def run( - self, dispatcher, tracker: "Tracker", domain: Dict + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", ) -> List[EventType]: """Execute the side effects of this action. diff --git a/tests/test_forms.py b/tests/test_forms.py index 58fc91215..feb360e48 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -3,6 +3,7 @@ from typing import Type, Text, Dict, Any, List, Optional from rasa_sdk import Tracker, ActionExecutionRejection +from rasa_sdk.types import DomainDict from rasa_sdk.events import SlotSet, ActiveLoop from rasa_sdk.executor import CollectingDispatcher from rasa_sdk.forms import FormAction, FormSlotsValidatorAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY @@ -1500,11 +1501,11 @@ def name(self) -> Text: return "some_form" @staticmethod - def validate_slot1(tracker: Tracker, domain: Dict, slot_value: Any) -> bool: + def validate_slot1(tracker: Tracker, domain: "DomainDict", slot_value: Any) -> bool: return slot_value == "correct_value" @staticmethod - def validate_slot2(tracker: Tracker, domain: Dict, slot_value: Any) -> bool: + def validate_slot2(tracker: Tracker, domain: "DomainDict", slot_value: Any) -> bool: return slot_value == "correct_value" From ed852a6e7dce15ec441a668a30e650824a4d86ca Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 16:33:53 +0200 Subject: [PATCH 08/16] Update --- rasa_sdk/forms.py | 76 ++++++++++++++++++++++++++++++++++-------- rasa_sdk/interfaces.py | 1 - tests/test_forms.py | 38 +++++++++++++-------- 3 files changed, 88 insertions(+), 27 deletions(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index ef7e02da0..782b93fd2 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -693,21 +693,42 @@ def _get_entity_type_of_slot_to_fill( return entity_type -class FormSlotsValidatorAction(Action): - """An action that validates if every extracted slot is valid.""" +class FormValidationAction(Action): + """An action that validates if every extracted slot is valid. + + Example of usage: + ``` + class MyFormValidationAction(FormValidationAction): + def name(self) -> Text: + return "some_form" + + def validate_slot1( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) + return slot_value == "correct_value" + ``` + """ def name(self) -> Text: - """Unique identifier of this action.""" + """Unique identifier of this action. + + Returns: + Name of the action. + """ raise NotImplementedError("An action must implement a name") - async def run( + async def validate( self, dispatcher: "CollectingDispatcher", tracker: "Tracker", domain: "DomainDict", ) -> List[EventType]: - """Execute the side effects of this action. + """Validate slots by calling a validation function for each slot. Args: dispatcher: the dispatcher which is used to @@ -715,20 +736,49 @@ async def run( tracker: the state tracker for the current user. domain: the bot's domain. Returns: - A dictionary of ``rasa_sdk.events.Event`` instances. + A dictionary of `rasa_sdk.events.Event` instances. """ - slots_to_validate: Dict[Text, Any] = tracker.form_slots_to_validate() - validation_events = [] + validated_events = [] for slot_name, slot_value in slots_to_validate.items(): function_name = f"validate_{slot_name}" - fn = getattr(self, function_name) - if fn(tracker, domain, slot_value): - validation_events.append(SlotSet(slot_name, slot_value)) + validate_func = getattr(self, function_name, lambda *x: False) + + if utils.is_coroutine_action(validate_func): + validation_output = await validate_func( + slot_value, dispatcher, tracker, domain + ) + else: + validation_output = validate_func(slot_value, dispatcher, tracker, domain) + + if validation_output: + validated_events.append(SlotSet(slot_name, slot_value)) else: # Return a `SlotSet` event with value `None` to indicate that this # slot still needs to be filled. - validation_events.append(SlotSet(slot_name, None)) + warnings.warn( + f"Cannot validate `{slot_name}`: make sure the validation function is specified and " + f"returns `True`" + ) + + return validated_events + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> List[EventType]: + """Execute the side effects of this action. + + Args: + dispatcher: the dispatcher which is used to + send messages back to the user. + tracker: the state tracker for the current user. + domain: the bot's domain. + Returns: + A dictionary of `rasa_sdk.events.Event` instances. + """ - return validation_events + return await self.validate(dispatcher, tracker, domain) diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index aea3c829c..a19ddd267 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -245,7 +245,6 @@ def slots_to_validate(self) -> Dict[Text, Any]: slots: Dict[Text, Any] = {} - print(self.events) for event in reversed(self.events): # The `FormAction` in Rasa Open Source will append all slot candidates # at the end of the tracker events. diff --git a/tests/test_forms.py b/tests/test_forms.py index feb360e48..19f01cd08 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -6,7 +6,7 @@ from rasa_sdk.types import DomainDict from rasa_sdk.events import SlotSet, ActiveLoop from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.forms import FormAction, FormSlotsValidatorAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY +from rasa_sdk.forms import FormAction, FormValidationAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY def test_extract_requested_slot_default(): @@ -1496,21 +1496,31 @@ def test_form_deprecation(): FormAction() -class TestFormSlotValidator(FormSlotsValidatorAction): +class TestFormValidationAction(FormValidationAction): def name(self) -> Text: return "some_form" - @staticmethod - def validate_slot1(tracker: Tracker, domain: "DomainDict", slot_value: Any) -> bool: + def validate_slot1( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> bool: return slot_value == "correct_value" - @staticmethod - def validate_slot2(tracker: Tracker, domain: "DomainDict", slot_value: Any) -> bool: + def validate_slot2( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> bool: return slot_value == "correct_value" -async def test_form_slot_validator(): - form = TestFormSlotValidator() +async def test_form_validation_action(): + form = TestFormValidationAction() # tracker with active form tracker = Tracker( @@ -1528,13 +1538,12 @@ async def test_form_slot_validator(): events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) # check that the form was activated and validation was performed assert events == [ - SlotSet("slot2", None), SlotSet("slot1", "correct_value"), ] -async def test_form_slot_validator_attribute_error(): - form = TestFormSlotValidator() +async def test_form_validation_action_attribute_error(): + form = TestFormValidationAction() # tracker with active form tracker = Tracker( @@ -1553,5 +1562,8 @@ async def test_form_slot_validator_attribute_error(): ) dispatcher = CollectingDispatcher() - with pytest.raises(AttributeError): - await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + # check that the form was activated and validation was performed + assert events == [ + SlotSet("slot1", "correct_value"), + ] From 69c28a42c7360a82d18d42a608af40d22aae5677 Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 16:34:23 +0200 Subject: [PATCH 09/16] Code style update --- rasa_sdk/forms.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 782b93fd2..d7e859837 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -711,7 +711,7 @@ def validate_slot1( ) return slot_value == "correct_value" ``` - """ + """ def name(self) -> Text: """Unique identifier of this action. @@ -750,7 +750,9 @@ async def validate( slot_value, dispatcher, tracker, domain ) else: - validation_output = validate_func(slot_value, dispatcher, tracker, domain) + validation_output = validate_func( + slot_value, dispatcher, tracker, domain + ) if validation_output: validated_events.append(SlotSet(slot_name, slot_value)) From c02d12fc30b54b4cae8d323406bc01cec194cfa3 Mon Sep 17 00:00:00 2001 From: alwx Date: Fri, 2 Oct 2020 16:40:37 +0200 Subject: [PATCH 10/16] Small update: type annotation --- rasa_sdk/forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index d7e859837..5f97aa3e1 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -708,7 +708,7 @@ def validate_slot1( dispatcher: "CollectingDispatcher", tracker: "Tracker", domain: "DomainDict", - ) + ) -> bool: return slot_value == "correct_value" ``` """ From 4f99ddae2fbab837cb69d1facb6098c3ae36aae0 Mon Sep 17 00:00:00 2001 From: alwx Date: Mon, 5 Oct 2020 10:14:24 +0200 Subject: [PATCH 11/16] Updated logic --- changelog/288.improvement.md | 22 +++++++++++++++++++++- rasa_sdk/forms.py | 32 ++++++-------------------------- tests/test_forms.py | 19 +++++++++++++------ 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/changelog/288.improvement.md b/changelog/288.improvement.md index d832bae19..d6ebc16b5 100644 --- a/changelog/288.improvement.md +++ b/changelog/288.improvement.md @@ -1,2 +1,22 @@ Add the `FormSlotsValidatorAction` abstract class that can be used -to validate slots which were extracted by a Form. \ No newline at end of file +to validate slots which were extracted by a Form. + +Example: + +```python +class MyFormValidationAction(FormValidationAction): + def name(self) -> Text: + return "some_form" + + def validate_slot1( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot1": "validated_value", + } +``` \ No newline at end of file diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 5f97aa3e1..8a4ebb937 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -694,24 +694,7 @@ def _get_entity_type_of_slot_to_fill( class FormValidationAction(Action): - """An action that validates if every extracted slot is valid. - - Example of usage: - ``` - class MyFormValidationAction(FormValidationAction): - def name(self) -> Text: - return "some_form" - - def validate_slot1( - self, - slot_value: Any, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> bool: - return slot_value == "correct_value" - ``` - """ + """An action that validates if every extracted slot is valid.""" def name(self) -> Text: """Unique identifier of this action. @@ -738,10 +721,9 @@ async def validate( Returns: A dictionary of `rasa_sdk.events.Event` instances. """ - slots_to_validate: Dict[Text, Any] = tracker.form_slots_to_validate() - validated_events = [] + slots: Dict[Text, Any] = tracker.form_slots_to_validate() - for slot_name, slot_value in slots_to_validate.items(): + for slot_name, slot_value in slots.items(): function_name = f"validate_{slot_name}" validate_func = getattr(self, function_name, lambda *x: False) @@ -755,16 +737,14 @@ async def validate( ) if validation_output: - validated_events.append(SlotSet(slot_name, slot_value)) + slots.update(validation_output) else: - # Return a `SlotSet` event with value `None` to indicate that this - # slot still needs to be filled. warnings.warn( f"Cannot validate `{slot_name}`: make sure the validation function is specified and " - f"returns `True`" + f"returns a list of `rasa_sdk.events.Event` instances." ) - return validated_events + return [SlotSet(slot, value) for slot, value in slots.items()] async def run( self, diff --git a/tests/test_forms.py b/tests/test_forms.py index 19f01cd08..dc8e1dfcf 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1506,8 +1506,11 @@ def validate_slot1( dispatcher: "CollectingDispatcher", tracker: "Tracker", domain: "DomainDict", - ) -> bool: - return slot_value == "correct_value" + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot1": "validated_value", + } def validate_slot2( self, @@ -1515,8 +1518,11 @@ def validate_slot2( dispatcher: "CollectingDispatcher", tracker: "Tracker", domain: "DomainDict", - ) -> bool: - return slot_value == "correct_value" + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot2": "validated_value", + } async def test_form_validation_action(): @@ -1538,7 +1544,8 @@ async def test_form_validation_action(): events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) # check that the form was activated and validation was performed assert events == [ - SlotSet("slot1", "correct_value"), + SlotSet("slot1", "validated_value"), + SlotSet("slot2", "incorrect_value") ] @@ -1565,5 +1572,5 @@ async def test_form_validation_action_attribute_error(): events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) # check that the form was activated and validation was performed assert events == [ - SlotSet("slot1", "correct_value"), + SlotSet("slot1", "validated_value"), ] From 1c74bbb7755841d9649feb9d7dd0169cccfdd838 Mon Sep 17 00:00:00 2001 From: alwx Date: Mon, 5 Oct 2020 10:28:27 +0200 Subject: [PATCH 12/16] Check warning --- tests/test_forms.py | 50 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/tests/test_forms.py b/tests/test_forms.py index dc8e1dfcf..7a986bce5 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1524,6 +1524,18 @@ def validate_slot2( "slot2": "validated_value", } + async def validate_slot3( + self, + slot_value: Any, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> Dict[Text, Any]: + if slot_value == "correct_value": + return { + "slot3": "validated_value", + } + async def test_form_validation_action(): form = TestFormValidationAction() @@ -1544,8 +1556,31 @@ async def test_form_validation_action(): events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) # check that the form was activated and validation was performed assert events == [ + SlotSet("slot2", "incorrect_value"), SlotSet("slot1", "validated_value"), - SlotSet("slot2", "incorrect_value") + ] + + +async def test_form_validation_action_async(): + form = TestFormValidationAction() + + # tracker with active form + tracker = Tracker( + "default", + {}, + {}, + [SlotSet("slot3", "correct_value")], + False, + None, + {"name": "some_form", "validate": True, "rejected": False}, + "action_listen", + ) + + dispatcher = CollectingDispatcher() + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + # check that the form was activated and validation was performed + assert events == [ + SlotSet("slot3", "validated_value") ] @@ -1569,8 +1604,11 @@ async def test_form_validation_action_attribute_error(): ) dispatcher = CollectingDispatcher() - events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) - # check that the form was activated and validation was performed - assert events == [ - SlotSet("slot1", "validated_value"), - ] + with pytest.warns(None): + events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) + # check that the form was activated and validation was performed + assert events == [ + SlotSet("slot3", "some_value"), + SlotSet("slot2", "incorrect_value"), + SlotSet("slot1", "validated_value"), + ] From f07893d9308b854a50ab05eecb553bd0f03cfd98 Mon Sep 17 00:00:00 2001 From: alwx Date: Mon, 5 Oct 2020 13:09:29 +0200 Subject: [PATCH 13/16] Update --- changelog/288.improvement.md | 3 +++ rasa_sdk/forms.py | 38 ++++++++++++------------------------ tests/test_forms.py | 26 ++++++++++++------------ 3 files changed, 29 insertions(+), 38 deletions(-) diff --git a/changelog/288.improvement.md b/changelog/288.improvement.md index d6ebc16b5..ef345b237 100644 --- a/changelog/288.improvement.md +++ b/changelog/288.improvement.md @@ -19,4 +19,7 @@ class MyFormValidationAction(FormValidationAction): return { "slot1": "validated_value", } + return { + "slot1": None, + } ``` \ No newline at end of file diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 8a4ebb937..098e50a25 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -3,6 +3,7 @@ import warnings from typing import Dict, Text, Any, List, Union, Optional, Tuple, cast +from abc import ABC from rasa_sdk import utils from rasa_sdk.events import SlotSet, EventType, ActiveLoop from rasa_sdk.interfaces import Action, ActionExecutionRejection @@ -693,18 +694,9 @@ def _get_entity_type_of_slot_to_fill( return entity_type -class FormValidationAction(Action): +class FormValidationAction(Action, ABC): """An action that validates if every extracted slot is valid.""" - def name(self) -> Text: - """Unique identifier of this action. - - Returns: - Name of the action. - """ - - raise NotImplementedError("An action must implement a name") - async def validate( self, dispatcher: "CollectingDispatcher", @@ -716,16 +708,22 @@ async def validate( Args: dispatcher: the dispatcher which is used to send messages back to the user. - tracker: the state tracker for the current user. + tracker: the conversation tracker for the current user. domain: the bot's domain. Returns: - A dictionary of `rasa_sdk.events.Event` instances. + `SlotSet` events for every validated slot. """ slots: Dict[Text, Any] = tracker.form_slots_to_validate() for slot_name, slot_value in slots.items(): function_name = f"validate_{slot_name}" - validate_func = getattr(self, function_name, lambda *x: False) + validate_func = getattr(self, function_name, None) + + if not validate_func: + warnings.warn( + f"Cannot validate `{slot_name}`: there is no validation function specified." + ) + continue if utils.is_coroutine_action(validate_func): validation_output = await validate_func( @@ -740,8 +738,7 @@ async def validate( slots.update(validation_output) else: warnings.warn( - f"Cannot validate `{slot_name}`: make sure the validation function is specified and " - f"returns a list of `rasa_sdk.events.Event` instances." + f"Cannot validate `{slot_name}`: make sure the validation function returns the correct output." ) return [SlotSet(slot, value) for slot, value in slots.items()] @@ -752,15 +749,4 @@ async def run( tracker: "Tracker", domain: "DomainDict", ) -> List[EventType]: - """Execute the side effects of this action. - - Args: - dispatcher: the dispatcher which is used to - send messages back to the user. - tracker: the state tracker for the current user. - domain: the bot's domain. - Returns: - A dictionary of `rasa_sdk.events.Event` instances. - """ - return await self.validate(dispatcher, tracker, domain) diff --git a/tests/test_forms.py b/tests/test_forms.py index 7a986bce5..f66e5c43f 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1511,6 +1511,9 @@ def validate_slot1( return { "slot1": "validated_value", } + return { + "slot1": None, + } def validate_slot2( self, @@ -1523,6 +1526,9 @@ def validate_slot2( return { "slot2": "validated_value", } + return { + "slot2": None, + } async def validate_slot3( self, @@ -1535,6 +1541,7 @@ async def validate_slot3( return { "slot3": "validated_value", } + # this function doesn't return anything when the slot value is incorrect async def test_form_validation_action(): @@ -1548,15 +1555,14 @@ async def test_form_validation_action(): [SlotSet("slot1", "correct_value"), SlotSet("slot2", "incorrect_value")], False, None, - {"name": "some_form", "validate": True, "rejected": False}, + {"name": "some_form", "is_interrupted": False, "rejected": False}, "action_listen", ) dispatcher = CollectingDispatcher() events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) - # check that the form was activated and validation was performed assert events == [ - SlotSet("slot2", "incorrect_value"), + SlotSet("slot2", None), SlotSet("slot1", "validated_value"), ] @@ -1572,16 +1578,13 @@ async def test_form_validation_action_async(): [SlotSet("slot3", "correct_value")], False, None, - {"name": "some_form", "validate": True, "rejected": False}, + {"name": "some_form", "is_interrupted": False, "rejected": False}, "action_listen", ) dispatcher = CollectingDispatcher() events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) - # check that the form was activated and validation was performed - assert events == [ - SlotSet("slot3", "validated_value") - ] + assert events == [SlotSet("slot3", "validated_value")] async def test_form_validation_action_attribute_error(): @@ -1599,16 +1602,15 @@ async def test_form_validation_action_attribute_error(): ], False, None, - {"name": "some_form", "validate": True, "rejected": False}, + {"name": "some_form", "is_interrupted": False, "rejected": False}, "action_listen", ) dispatcher = CollectingDispatcher() - with pytest.warns(None): + with pytest.warns(UserWarning): events = await form.run(dispatcher=dispatcher, tracker=tracker, domain=None) - # check that the form was activated and validation was performed assert events == [ SlotSet("slot3", "some_value"), - SlotSet("slot2", "incorrect_value"), + SlotSet("slot2", None), SlotSet("slot1", "validated_value"), ] From 620dec22ee8be8066a28816048a9c6cb75571c96 Mon Sep 17 00:00:00 2001 From: alwx Date: Mon, 5 Oct 2020 14:02:15 +0200 Subject: [PATCH 14/16] Naming update --- tests/test_forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_forms.py b/tests/test_forms.py index f66e5c43f..af2ea248a 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1587,7 +1587,7 @@ async def test_form_validation_action_async(): assert events == [SlotSet("slot3", "validated_value")] -async def test_form_validation_action_attribute_error(): +async def test_form_validation_without_validate_function(): form = TestFormValidationAction() # tracker with active form From 02e30bde1c767ec06c868372dfc5165f857bb4f5 Mon Sep 17 00:00:00 2001 From: alwx Date: Mon, 5 Oct 2020 14:05:11 +0200 Subject: [PATCH 15/16] Post-rebase update --- rasa_sdk/forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa_sdk/forms.py b/rasa_sdk/forms.py index 098e50a25..82df25534 100644 --- a/rasa_sdk/forms.py +++ b/rasa_sdk/forms.py @@ -713,7 +713,7 @@ async def validate( Returns: `SlotSet` events for every validated slot. """ - slots: Dict[Text, Any] = tracker.form_slots_to_validate() + slots: Dict[Text, Any] = tracker.slots_to_validate() for slot_name, slot_value in slots.items(): function_name = f"validate_{slot_name}" From 9dc99c7a652ea0f937ab9efb1766c57a8d975e85 Mon Sep 17 00:00:00 2001 From: alwx Date: Mon, 5 Oct 2020 14:17:22 +0200 Subject: [PATCH 16/16] Code quality update --- tests/test_forms.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_forms.py b/tests/test_forms.py index af2ea248a..6fdf01c11 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -6,7 +6,12 @@ from rasa_sdk.types import DomainDict from rasa_sdk.events import SlotSet, ActiveLoop from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.forms import FormAction, FormValidationAction, REQUESTED_SLOT, LOOP_INTERRUPTED_KEY +from rasa_sdk.forms import ( + FormAction, + FormValidationAction, + REQUESTED_SLOT, + LOOP_INTERRUPTED_KEY, +) def test_extract_requested_slot_default():