Skip to content

Commit

Permalink
update test and changelog entry
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 16, 2024
1 parent 005cfa0 commit 32c3b23
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
2 changes: 1 addition & 1 deletion changelog/1076.improvement.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Instrument `ActionExecutor._create_api_response` and extract `slots` attribute.
Instrument `ActionExecutor._create_api_response` and extract `slots`, `events`, `utters` and `message_count` attributes.
50 changes: 35 additions & 15 deletions tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from typing import Any, Dict, Sequence, Text, Optional, List
from typing import Any, Dict, Sequence, Text, Optional, List, Callable
from unittest.mock import Mock
from pytest import MonkeyPatch
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
Expand All @@ -16,13 +16,28 @@
from rasa_sdk.executor import CollectingDispatcher


dispatcher1 = CollectingDispatcher()
dispatcher1.utter_message(template="utter_greet")
dispatcher2 = CollectingDispatcher()
dispatcher2.utter_message("Hello")
dispatcher3 = CollectingDispatcher()
dispatcher3.utter_message("Hello")
dispatcher3.utter_message(template="utter_greet")
def get_dispatcher0():
dispatcher = CollectingDispatcher()
return dispatcher


def get_dispatcher1():
dispatcher = CollectingDispatcher()
dispatcher.utter_message(template="utter_greet")
return dispatcher


def get_dispatcher2():
dispatcher = CollectingDispatcher()
dispatcher.utter_message("Hello")
return dispatcher


def get_dispatcher3():
dispatcher = CollectingDispatcher()
dispatcher.utter_message("Hello")
dispatcher.utter_message(template="utter_greet")
return dispatcher


@pytest.mark.parametrize(
Expand Down Expand Up @@ -102,17 +117,21 @@ def test_instrument_action_executor_run_registers_tracer(


@pytest.mark.parametrize(
"events, messages, expected",
"events, get_dispatcher, expected",
[
([], [], {"events": "[]", "slots": "[]", "utters": "[]", "message_count": 0}),
(
[],
get_dispatcher0,
{"events": "[]", "slots": "[]", "utters": "[]", "message_count": 0},
),
(
[ActionExecuted("my_form")],
dispatcher2.messages,
get_dispatcher2,
{"events": '["action"]', "slots": "[]", "utters": "[]", "message_count": 1},
),
(
[ActionExecuted("my_form"), SlotSet("my_slot", "some_value")],
dispatcher1.messages,
get_dispatcher1,
{
"events": '["action", "slot"]',
"slots": '["my_slot"]',
Expand All @@ -122,7 +141,7 @@ def test_instrument_action_executor_run_registers_tracer(
),
(
[SlotSet("my_slot", "some_value")],
dispatcher3.messages,
get_dispatcher3,
{
"events": '["slot"]',
"slots": '["my_slot"]',
Expand All @@ -137,7 +156,7 @@ def test_tracing_action_executor_create_api_response(
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: Optional[List],
messages: Optional[List],
get_dispatcher: Callable,
expected: Dict[Text, Any],
) -> None:
component_class = MockActionExecutor
Expand All @@ -149,7 +168,8 @@ def test_tracing_action_executor_create_api_response(

mock_action_executor = component_class()

mock_action_executor._create_api_response(events, messages)
dispatcher = get_dispatcher()
mock_action_executor._create_api_response(events, dispatcher.messages)

captured_spans: Sequence[
ReadableSpan
Expand Down

0 comments on commit 32c3b23

Please sign in to comment.