Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-2188] Allow access to Dialogue stack from custom actions #1078

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1078.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a `stack` property to the `Tracker` class which corresponds to the dialogue stack.
ancalita marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def from_dict(cls, state: "TrackerState") -> "Tracker":
state.get("followup_action"),
state.get("active_loop", state.get("active_form", {})),
state.get("latest_action_name"),
state.get("stack", []),
)

def __init__(
Expand All @@ -45,6 +46,7 @@ def __init__(
followup_action: Optional[Text],
active_loop: Dict[Text, Any],
latest_action_name: Optional[Text],
stack: List[Dict[Text, Any]] = [],
) -> None:
"""Initialize the tracker."""

Expand All @@ -66,6 +68,7 @@ def __init__(
self.latest_message = latest_message if latest_message else {}
self.active_loop = active_loop
self.latest_action_name = latest_action_name
self.stack = stack
ancalita marked this conversation as resolved.
Show resolved Hide resolved

@property
def active_form(self) -> Dict[Text, Any]:
Expand Down Expand Up @@ -93,6 +96,7 @@ def current_state(self) -> Dict[Text, Any]:
"latest_input_channel": self.get_latest_input_channel(),
"active_loop": self.active_loop,
"latest_action_name": self.latest_action_name,
"stack": self.stack,
}

def current_slot_values(self) -> Dict[Text, Any]:
Expand Down Expand Up @@ -196,6 +200,7 @@ def copy(self) -> "Tracker":
self.followup_action,
self.active_loop,
self.latest_action_name,
self.stack,
)

def last_executed_action_has(self, name: Text, skip: int = 0) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions rasa_sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class TrackerState(TypedDict):
active_form: Dict[Text, Any]
# the name of the previously executed action or text of e2e action
latest_action_name: Optional[Text]
# the current dialogue stack
stack: List[Dict[Text, Any]]


class DomainDict(TypedDict):
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
from sanic import Sanic

Sanic.test_mode = True


def get_stack():
dialogue_stack = [
{
"frame_id": "CP6JP9GQ",
"flow_id": "check_balance",
"step_id": "0_check_balance",
"frame_type": "regular",
"type": "flow",
}
]
return dialogue_stack
13 changes: 13 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,16 @@ def run(
domain: DomainDict,
) -> List[Dict[Text, Any]]:
raise Exception("test exception")


class CustomActionWithDialogueStack(Action):
def name(cls) -> Text:
return "custom_action_with_dialogue_stack"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("stack", tracker.stack)]
26 changes: 25 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Dict, List, Text
import json
import logging
import zlib
Expand All @@ -6,6 +7,7 @@

import rasa_sdk.endpoint as ep
from rasa_sdk.events import SlotSet
from tests.conftest import get_stack

# noinspection PyTypeChecker
app = ep.create_app(None)
Expand All @@ -23,14 +25,15 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 5
assert len(response.json) == 6

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
# defined in tests/test_actions.py
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
{"name": "custom_action_with_dialogue_stack"},
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
{"name": "mock_form_validation_action"},
Expand Down Expand Up @@ -119,6 +122,27 @@ def test_server_webhook_custom_action_encoded_data_returns_200():
assert response.status == 200


@pytest.mark.parametrize(
"stack_state, dialogue_stack",
[
({}, []),
({"stack": get_stack()}, get_stack()),
],
)
def test_server_webhook_custom_action_with_dialogue_stack_returns_200(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]]
):
data = {
"next_action": "custom_action_with_dialogue_stack",
"tracker": {"sender_id": "1", "conversation_id": "default", **stack_state},
}
_, response = app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")

assert events == [SlotSet("stack", dialogue_stack)]
assert response.status == 200


# ENSURE THIS IS ALWAYS THE LAST TEST FOR OTHER TESTS TO RUN
# because the call to sys.exit() terminates pytest process
def test_endpoint_exit_for_unknown_actions_package():
Expand Down
22 changes: 21 additions & 1 deletion tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict
from typing import Any, Dict, List, Text

import pytest
from rasa_sdk.events import SlotSet

from rasa_sdk.interfaces import Tracker
from tests.conftest import get_stack


@pytest.mark.parametrize(
Expand Down Expand Up @@ -61,3 +62,22 @@ def test_tracker_with_slots():

assert tracker.slots["my slot"] == 5
assert tracker.slots["my slot 2"] is None


@pytest.mark.parametrize(
"stack_state, dialogue_stack",
[
({}, []),
({"stack": get_stack()}, get_stack()),
],
)
def test_stack_in_tracker_state(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]]
):

state = {"events": [], "sender_id": "old", "active_loop": {}, **stack_state}
tracker = Tracker.from_dict(state)

assert tracker.stack == dialogue_stack
assert tracker.copy().stack == dialogue_stack
assert tracker.current_state()["stack"] == dialogue_stack
1 change: 0 additions & 1 deletion tests/tracing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_get_tracer_and_context() -> None:
app = ep.create_app(None)
request, _ = app.test_client.post("/webhook", data=json.dumps(data))
tracer, context, span_name = get_tracer_and_context(None, request)
print(type(tracer))
Tawakalt marked this conversation as resolved.
Show resolved Hide resolved

assert isinstance(tracer, ProxyTracer)
assert span_name == "create_app.webhook"
Expand Down
Loading