Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ATO 207 run form validation on form activation #11326

Merged
merged 7 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/11326.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Revert change in #10295 that removed running the form validation action on activation of the form before the loop is active.
12 changes: 3 additions & 9 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,7 @@ def _update_slot_values(
domain: Domain,
slot_values: Dict[Text, Any],
) -> Dict[Text, Any]:
slot_mappings = self.get_mappings_for_slot(event.key, domain)

for mapping in slot_mappings:
slot_values[event.key] = event.value
slot_values[event.key] = event.value

return slot_values

Expand Down Expand Up @@ -534,11 +531,8 @@ async def _validate_if_required(
- the form is called after `action_listen`
- form validation was not cancelled
"""
# No active_loop means there are no form filled slots to validate yet
if not tracker.active_loop:
return []

needs_validation = (
# no active_loop means that it is called during activation
needs_validation = not tracker.active_loop or (
tracker.latest_action_name == ACTION_LISTEN_NAME
and not tracker.is_active_loop_interrupted
)
Expand Down
99 changes: 79 additions & 20 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.utils.endpoints import EndpointConfig

ACTION_SERVER_URL = "http://my-action-server:5055/webhook"


async def test_activate():
tracker = DialogueStateTracker.from_events(sender_id="bla", evts=[])
Expand Down Expand Up @@ -66,6 +68,68 @@ async def test_activate():
domain,
)
assert events[:-1] == [ActiveLoop(form_name), SlotSet(REQUESTED_SLOT, slot_name)]


async def test_activate_with_custom_slot_mapping():
tracker = DialogueStateTracker.from_events(sender_id="bla", evts=[])
form_name = "my_form"
action_server = EndpointConfig(ACTION_SERVER_URL)
action = FormAction(form_name, action_server)
domain_required_slot_name = "num_people"
slot_set_by_remote_custom_extraction_method = "some_slot"
slot_value_set_by_remote_custom_extraction_method = "anything"
domain = textwrap.dedent(
f"""
slots:
{domain_required_slot_name}:
type: float
mappings:
- type: from_entity
entity: number
{slot_set_by_remote_custom_extraction_method}:
type: any
mappings:
- type: custom
forms:
{form_name}:
{REQUIRED_SLOTS_KEY}:
- {domain_required_slot_name}
responses:
utter_ask_num_people:
- text: "How many people?"
actions:
- validate_{form_name}
"""
)
domain = Domain.from_yaml(domain)

form_validation_events = [
{
"event": "slot",
"timestamp": None,
"name": slot_set_by_remote_custom_extraction_method,
"value": slot_value_set_by_remote_custom_extraction_method,
}
]
with aioresponses() as mocked:
mocked.post(
ACTION_SERVER_URL,
payload={"events": form_validation_events},
)
events = await action.run(
CollectingOutputChannel(),
TemplatedNaturalLanguageGenerator(domain.responses),
tracker,
domain,
)
assert events[:-1] == [
ActiveLoop(form_name),
SlotSet(
slot_set_by_remote_custom_extraction_method,
slot_value_set_by_remote_custom_extraction_method,
),
SlotSet(REQUESTED_SLOT, domain_required_slot_name),
]
assert isinstance(events[-1], BotUttered)


Expand Down Expand Up @@ -112,7 +176,7 @@ async def test_activate_with_prefilled_slot():
tracker = DialogueStateTracker.from_events(
sender_id="bla", evts=[SlotSet(slot_name, slot_value)]
)
form_name = "my form"
form_name = "my_form"
action = FormAction(form_name, None)

next_slot_to_request = "next slot to request"
Expand Down Expand Up @@ -575,12 +639,10 @@ async def test_validate_slots(
assert slot_events == [SlotSet(slot_name, slot_value), SlotSet("num_tables", 5)]
tracker.update_with_events(slot_events, domain)

action_server_url = "http:/my-action-server:5055/webhook"

with aioresponses() as mocked:
mocked.post(action_server_url, payload={"events": validate_return_events})
mocked.post(ACTION_SERVER_URL, payload={"events": validate_return_events})

action_server = EndpointConfig(action_server_url)
action_server = EndpointConfig(ACTION_SERVER_URL)
action = FormAction(form_name, action_server)

events = await action.run(
Expand Down Expand Up @@ -633,8 +695,6 @@ async def test_request_correct_slots_after_unhappy_path_with_custom_required_slo
],
)

action_server_url = "http://my-action-server:5055/webhook"

# Custom form validation action changes the order of the requested slots
validate_return_events = [
{"event": "slot", "name": REQUESTED_SLOT, "value": slot_name_2}
Expand All @@ -644,9 +704,9 @@ async def test_request_correct_slots_after_unhappy_path_with_custom_required_slo
expected_events = [SlotSet(REQUESTED_SLOT, slot_name_2)]

with aioresponses() as mocked:
mocked.post(action_server_url, payload={"events": validate_return_events})
mocked.post(ACTION_SERVER_URL, payload={"events": validate_return_events})

action_server = EndpointConfig(action_server_url)
action_server = EndpointConfig(ACTION_SERVER_URL)
action = FormAction(form_name, action_server)

events = await action.run(
Expand Down Expand Up @@ -694,12 +754,11 @@ async def test_no_slots_extracted_with_custom_slot_mappings(custom_events: List[
- validate_{form_name}
"""
domain = Domain.from_yaml(domain)
action_server_url = "http:/my-action-server:5055/webhook"

with aioresponses() as mocked:
mocked.post(action_server_url, payload={"events": custom_events})
mocked.post(ACTION_SERVER_URL, payload={"events": custom_events})

action_server = EndpointConfig(action_server_url)
action_server = EndpointConfig(ACTION_SERVER_URL)
action = FormAction(form_name, action_server)

with pytest.raises(ActionExecutionRejection):
Expand Down Expand Up @@ -736,20 +795,18 @@ async def test_validate_slots_on_activation_with_other_action_after_user_utteran
- validate_{form_name}
"""
domain = Domain.from_yaml(domain)
action_server_url = "http:/my-action-server:5055/webhook"

expected_slot_value = "✅"
with aioresponses() as mocked:
mocked.post(
action_server_url,
ACTION_SERVER_URL,
payload={
"events": [
{"event": "slot", "name": slot_name, "value": expected_slot_value}
]
},
)

action_server = EndpointConfig(action_server_url)
action_server = EndpointConfig(ACTION_SERVER_URL)
action_extract_slots = ActionExtractSlots(action_endpoint=None)
slot_events = await action_extract_slots.run(
CollectingOutputChannel(),
Expand All @@ -761,6 +818,10 @@ async def test_validate_slots_on_activation_with_other_action_after_user_utteran

form_action = FormAction(form_name, action_server)

mocked.post(
ACTION_SERVER_URL,
payload={"events": []},
)
events = await form_action.run(
CollectingOutputChannel(),
TemplatedNaturalLanguageGenerator(domain.responses),
Expand Down Expand Up @@ -1663,17 +1724,15 @@ async def test_action_extract_slots_custom_mapping_with_condition():
sender_id="test_id", evts=events, slots=domain.slots
)

action_server_url = "http:/my-action-server:5055/webhook"

with aioresponses() as mocked:
mocked.post(
action_server_url,
ACTION_SERVER_URL,
payload={
"events": [{"event": "slot", "name": "custom_slot", "value": "test"}]
},
)

action_server = EndpointConfig(action_server_url)
action_server = EndpointConfig(ACTION_SERVER_URL)
action_extract_slots = ActionExtractSlots(action_server)
events = await action_extract_slots.run(
CollectingOutputChannel(),
Expand Down