diff --git a/changelog/1078.improvement.md b/changelog/1078.improvement.md new file mode 100644 index 000000000..9c03daa59 --- /dev/null +++ b/changelog/1078.improvement.md @@ -0,0 +1 @@ +Add a `stack` property to the `Tracker` class which corresponds to the dialogue stack. \ No newline at end of file diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index d4ec22194..79b81c597 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -33,6 +33,7 @@ def from_dict(cls, state: "TrackerState") -> "Tracker": state.get("followup_action"), state.get("active_loop", state.get("active_form", {})), state.get("latest_action_name"), + state.get("stack", []), ) def __init__( @@ -45,6 +46,7 @@ def __init__( followup_action: Optional[Text], active_loop: Dict[Text, Any], latest_action_name: Optional[Text], + stack: List[Dict[Text, Any]] = [], ) -> None: """Initialize the tracker.""" @@ -66,6 +68,7 @@ def __init__( self.latest_message = latest_message if latest_message else {} self.active_loop = active_loop self.latest_action_name = latest_action_name + self.stack = stack @property def active_form(self) -> Dict[Text, Any]: @@ -93,6 +96,7 @@ def current_state(self) -> Dict[Text, Any]: "latest_input_channel": self.get_latest_input_channel(), "active_loop": self.active_loop, "latest_action_name": self.latest_action_name, + "stack": self.stack, } def current_slot_values(self) -> Dict[Text, Any]: @@ -196,6 +200,7 @@ def copy(self) -> "Tracker": self.followup_action, self.active_loop, self.latest_action_name, + self.stack, ) def last_executed_action_has(self, name: Text, skip: int = 0) -> bool: diff --git a/rasa_sdk/types.py b/rasa_sdk/types.py index b5a5054d9..870643ded 100644 --- a/rasa_sdk/types.py +++ b/rasa_sdk/types.py @@ -26,6 +26,8 @@ class TrackerState(TypedDict): active_form: Dict[Text, Any] # the name of the previously executed action or text of e2e action latest_action_name: Optional[Text] + # the current dialogue stack + stack: List[Dict[Text, Any]] class DomainDict(TypedDict): diff --git a/tests/conftest.py b/tests/conftest.py index 1d7466097..4b3c5abc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,16 @@ from sanic import Sanic Sanic.test_mode = True + + +def get_stack(): + dialogue_stack = [ + { + "frame_id": "CP6JP9GQ", + "flow_id": "check_balance", + "step_id": "0_check_balance", + "frame_type": "regular", + "type": "flow", + } + ] + return dialogue_stack diff --git a/tests/test_actions.py b/tests/test_actions.py index e16acbed6..e054d361c 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -43,3 +43,16 @@ def run( domain: DomainDict, ) -> List[Dict[Text, Any]]: raise Exception("test exception") + + +class CustomActionWithDialogueStack(Action): + def name(cls) -> Text: + return "custom_action_with_dialogue_stack" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("stack", tracker.stack)] diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 06e779e80..42f48b5bf 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Text import json import logging import zlib @@ -6,6 +7,7 @@ import rasa_sdk.endpoint as ep from rasa_sdk.events import SlotSet +from tests.conftest import get_stack # noinspection PyTypeChecker app = ep.create_app(None) @@ -23,7 +25,7 @@ def test_server_health_returns_200(): def test_server_list_actions_returns_200(): request, response = app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 5 + assert len(response.json) == 6 # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS expected = [ @@ -31,6 +33,7 @@ def test_server_list_actions_returns_200(): {"name": "custom_async_action"}, {"name": "custom_action"}, {"name": "custom_action_exception"}, + {"name": "custom_action_with_dialogue_stack"}, # defined in tests/tracing/instrumentation/conftest.py {"name": "mock_validation_action"}, {"name": "mock_form_validation_action"}, @@ -119,6 +122,27 @@ def test_server_webhook_custom_action_encoded_data_returns_200(): assert response.status == 200 +@pytest.mark.parametrize( + "stack_state, dialogue_stack", + [ + ({}, []), + ({"stack": get_stack()}, get_stack()), + ], +) +def test_server_webhook_custom_action_with_dialogue_stack_returns_200( + stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]] +): + data = { + "next_action": "custom_action_with_dialogue_stack", + "tracker": {"sender_id": "1", "conversation_id": "default", **stack_state}, + } + _, response = app.test_client.post("/webhook", data=json.dumps(data)) + events = response.json.get("events") + + assert events == [SlotSet("stack", dialogue_stack)] + assert response.status == 200 + + # ENSURE THIS IS ALWAYS THE LAST TEST FOR OTHER TESTS TO RUN # because the call to sys.exit() terminates pytest process def test_endpoint_exit_for_unknown_actions_package(): diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 8b319e496..9d706793a 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -1,9 +1,10 @@ -from typing import Dict +from typing import Any, Dict, List, Text import pytest from rasa_sdk.events import SlotSet from rasa_sdk.interfaces import Tracker +from tests.conftest import get_stack @pytest.mark.parametrize( @@ -61,3 +62,22 @@ def test_tracker_with_slots(): assert tracker.slots["my slot"] == 5 assert tracker.slots["my slot 2"] is None + + +@pytest.mark.parametrize( + "stack_state, dialogue_stack", + [ + ({}, []), + ({"stack": get_stack()}, get_stack()), + ], +) +def test_stack_in_tracker_state( + stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]] +): + + state = {"events": [], "sender_id": "old", "active_loop": {}, **stack_state} + tracker = Tracker.from_dict(state) + + assert tracker.stack == dialogue_stack + assert tracker.copy().stack == dialogue_stack + assert tracker.current_state()["stack"] == dialogue_stack diff --git a/tests/tracing/test_utils.py b/tests/tracing/test_utils.py index a0ece6cea..f4ebe5729 100644 --- a/tests/tracing/test_utils.py +++ b/tests/tracing/test_utils.py @@ -70,7 +70,6 @@ def test_get_tracer_and_context() -> None: app = ep.create_app(None) request, _ = app.test_client.post("/webhook", data=json.dumps(data)) tracer, context, span_name = get_tracer_and_context(None, request) - print(type(tracer)) assert isinstance(tracer, ProxyTracer) assert span_name == "create_app.webhook"