From c88cf098f90b5ea574f8e16cc66940a6ce14923c Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Thu, 15 Feb 2024 12:00:28 +0100 Subject: [PATCH 1/5] instrument ActionExecutor._create_api_response --- .../instrumentation/attribute_extractors.py | 16 +++++++++++- .../instrumentation/instrumentation.py | 25 +++++++++++++++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 80183ac3a..2d03030f3 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -1,6 +1,6 @@ import json -from typing import Any, Dict, Text +from typing import Any, Dict, Text, List from rasa_sdk.executor import ActionExecutor, CollectingDispatcher from rasa_sdk.forms import ValidationAction from rasa_sdk.types import ActionCall, DomainDict @@ -56,3 +56,17 @@ def extract_attrs_for_validation_action( "slots_to_validate": json.dumps(list(slots_to_validate)), "action_name": self.name(), } + + +def extract_attrs_for_create_api_response( + events: List[Dict[Text, Any]], + messages: List[Dict[Text, Any]], +) -> Dict[Text, Any]: + """Extract the attributes for `ActionExecutor.run`. + + :param events: A list of events. + :param messsages: A list of bot responses. + :return: A dictionary containing the attributes. + """ + slot_names = [event.get("name") for event in events if event.get("event") == "slot"] + return {"slots": json.dumps(slot_names)} diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index f28fd2b40..23568d08b 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -99,14 +99,23 @@ def traceable( @functools.wraps(fn) def wrapper(self: T, *args: Any, **kwargs: Any) -> S: - attrs = ( - attr_extractor(self, *args, **kwargs) - if attr_extractor and should_extract_args - else {} - ) + if fn.__name__ == "_create_api_response": + attrs = ( + attr_extractor(*args, **kwargs) + if attr_extractor and should_extract_args + else {} + ) + else: + attrs = ( + attr_extractor(self, *args, **kwargs) + if attr_extractor and should_extract_args + else {} + ) with tracer.start_as_current_span( f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs ): + if fn.__name__ == "_create_api_response": + return fn(*args, **kwargs) return fn(self, *args, **kwargs) return wrapper @@ -140,6 +149,12 @@ def instrument( "run", attribute_extractors.extract_attrs_for_action_executor, ) + _instrument_method( + tracer, + action_executor_class, + "_create_api_response", + attribute_extractors.extract_attrs_for_create_api_response, + ) mark_class_as_instrumented(action_executor_class) ActionExecutorTracerRegister().register_tracer(tracer) From 85dcf9b28afd46458f981100113aaaadff9834d4 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Thu, 15 Feb 2024 12:01:47 +0100 Subject: [PATCH 2/5] add test --- tests/tracing/instrumentation/conftest.py | 8 +++- .../instrumentation/test_action_executor.py | 46 ++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 6435c1859..796113063 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -1,5 +1,5 @@ import pytest -from typing import Text +from typing import Any, Dict, Text, List from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor @@ -47,6 +47,12 @@ def fail_if_undefined(self, method_name: Text) -> None: async def run(self, action_call: ActionCall) -> None: pass + @staticmethod + def _create_api_response( + events: List[Dict[Text, Any]], messages: List[Dict[Text, Any]] + ) -> None: + pass + class MockValidationAction(ValidationAction): def __init__(self) -> None: diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index 05ab9f44c..c9c1bfd21 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -1,11 +1,12 @@ import pytest -from typing import Any, Dict, Sequence, Text, Optional +from typing import Any, Dict, Sequence, Text, Optional, List from unittest.mock import Mock from pytest import MonkeyPatch from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry import trace +from rasa_sdk.events import ActionExecuted, SlotSet from rasa_sdk.tracing.instrumentation import instrumentation from tests.tracing.instrumentation.conftest import MockActionExecutor @@ -88,3 +89,46 @@ def test_instrument_action_executor_run_registers_tracer( assert tracer is not None assert tracer == mock_tracer + + +@pytest.mark.parametrize( + "events, expected", + [ + ([], {"slots": "[]"}), + ([ActionExecuted("my_form")], {"slots": "[]"}), + ( + [ActionExecuted("my_form"), SlotSet("my_slot", "some_value")], + {"slots": '["my_slot"]'}, + ), + ], +) +def test_tracing_action_executor_create_api_response( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[List], + expected: Dict[Text, Any], +) -> None: + component_class = MockActionExecutor + + instrumentation.instrument( + tracer_provider, + action_executor_class=component_class, + ) + + mock_action_executor = component_class() + + mock_action_executor._create_api_response(events, [{"text": "hello"}]) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + + assert captured_span.name == "MockActionExecutor._create_api_response" + + assert captured_span.attributes == expected From 7f106b1da1c1d979f3d6d1585a304e3cbd51c566 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Thu, 15 Feb 2024 12:42:23 +0100 Subject: [PATCH 3/5] add changelog entry --- changelog/1076.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1076.improvement.md diff --git a/changelog/1076.improvement.md b/changelog/1076.improvement.md new file mode 100644 index 000000000..a4cf6b6a4 --- /dev/null +++ b/changelog/1076.improvement.md @@ -0,0 +1 @@ +Instrument `ActionExecutor._create_api_response` and extract `slots` attribute. \ No newline at end of file From 005cfa0471baa0fba10ecc4ad80d3018cc9de214 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Thu, 15 Feb 2024 17:41:08 +0100 Subject: [PATCH 4/5] implement pr feedback --- .../instrumentation/attribute_extractors.py | 21 ++++++++-- .../instrumentation/instrumentation.py | 11 +++-- .../instrumentation/test_action_executor.py | 41 ++++++++++++++++--- 3 files changed, 62 insertions(+), 11 deletions(-) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 2d03030f3..432c04603 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -58,7 +58,7 @@ def extract_attrs_for_validation_action( } -def extract_attrs_for_create_api_response( +def extract_attrs_for_action_executor_create_api_response( events: List[Dict[Text, Any]], messages: List[Dict[Text, Any]], ) -> Dict[Text, Any]: @@ -68,5 +68,20 @@ def extract_attrs_for_create_api_response( :param messsages: A list of bot responses. :return: A dictionary containing the attributes. """ - slot_names = [event.get("name") for event in events if event.get("event") == "slot"] - return {"slots": json.dumps(slot_names)} + event_names = [] + slot_names = [] + + for event in events: + event_names.append(event.get("event")) + if event.get("event") == "slot" and event.get("name") != "requested_slot": + slot_names.append(event.get("name")) + utters = [ + message.get("response") for message in messages if message.get("response") + ] + + return { + "events": json.dumps(list(dict.fromkeys(event_names))), + "slots": json.dumps(list(dict.fromkeys(slot_names))), + "utters": json.dumps(utters), + "message_count": len(messages), + } diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index 23568d08b..fe7918f12 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -99,7 +99,9 @@ def traceable( @functools.wraps(fn) def wrapper(self: T, *args: Any, **kwargs: Any) -> S: - if fn.__name__ == "_create_api_response": + # the conditional statement is needed because + # _create_api_response is a static method + if isinstance(self, ActionExecutor) and fn.__name__ == "_create_api_response": attrs = ( attr_extractor(*args, **kwargs) if attr_extractor and should_extract_args @@ -114,7 +116,10 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S: with tracer.start_as_current_span( f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs ): - if fn.__name__ == "_create_api_response": + if ( + isinstance(self, ActionExecutor) + and fn.__name__ == "_create_api_response" + ): return fn(*args, **kwargs) return fn(self, *args, **kwargs) @@ -153,7 +158,7 @@ def instrument( tracer, action_executor_class, "_create_api_response", - attribute_extractors.extract_attrs_for_create_api_response, + attribute_extractors.extract_attrs_for_action_executor_create_api_response, ) mark_class_as_instrumented(action_executor_class) ActionExecutorTracerRegister().register_tracer(tracer) diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index c9c1bfd21..5d593ff65 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -13,6 +13,16 @@ from rasa_sdk.types import ActionCall from rasa_sdk import Tracker from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister +from rasa_sdk.executor import CollectingDispatcher + + +dispatcher1 = CollectingDispatcher() +dispatcher1.utter_message(template="utter_greet") +dispatcher2 = CollectingDispatcher() +dispatcher2.utter_message("Hello") +dispatcher3 = CollectingDispatcher() +dispatcher3.utter_message("Hello") +dispatcher3.utter_message(template="utter_greet") @pytest.mark.parametrize( @@ -92,13 +102,33 @@ def test_instrument_action_executor_run_registers_tracer( @pytest.mark.parametrize( - "events, expected", + "events, messages, expected", [ - ([], {"slots": "[]"}), - ([ActionExecuted("my_form")], {"slots": "[]"}), + ([], [], {"events": "[]", "slots": "[]", "utters": "[]", "message_count": 0}), + ( + [ActionExecuted("my_form")], + dispatcher2.messages, + {"events": '["action"]', "slots": "[]", "utters": "[]", "message_count": 1}, + ), ( [ActionExecuted("my_form"), SlotSet("my_slot", "some_value")], - {"slots": '["my_slot"]'}, + dispatcher1.messages, + { + "events": '["action", "slot"]', + "slots": '["my_slot"]', + "utters": '["utter_greet"]', + "message_count": 1, + }, + ), + ( + [SlotSet("my_slot", "some_value")], + dispatcher3.messages, + { + "events": '["slot"]', + "slots": '["my_slot"]', + "utters": '["utter_greet"]', + "message_count": 2, + }, ), ], ) @@ -107,6 +137,7 @@ def test_tracing_action_executor_create_api_response( span_exporter: InMemorySpanExporter, previous_num_captured_spans: int, events: Optional[List], + messages: Optional[List], expected: Dict[Text, Any], ) -> None: component_class = MockActionExecutor @@ -118,7 +149,7 @@ def test_tracing_action_executor_create_api_response( mock_action_executor = component_class() - mock_action_executor._create_api_response(events, [{"text": "hello"}]) + mock_action_executor._create_api_response(events, messages) captured_spans: Sequence[ ReadableSpan From 32c3b23e5ca89a6eb95b42bd7d641c6b0fab7e04 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 16 Feb 2024 12:02:33 +0100 Subject: [PATCH 5/5] update test and changelog entry --- changelog/1076.improvement.md | 2 +- .../instrumentation/test_action_executor.py | 50 +++++++++++++------ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/changelog/1076.improvement.md b/changelog/1076.improvement.md index a4cf6b6a4..1e3250832 100644 --- a/changelog/1076.improvement.md +++ b/changelog/1076.improvement.md @@ -1 +1 @@ -Instrument `ActionExecutor._create_api_response` and extract `slots` attribute. \ No newline at end of file +Instrument `ActionExecutor._create_api_response` and extract `slots`, `events`, `utters` and `message_count` attributes. \ No newline at end of file diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index 5d593ff65..95598798a 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -1,6 +1,6 @@ import pytest -from typing import Any, Dict, Sequence, Text, Optional, List +from typing import Any, Dict, Sequence, Text, Optional, List, Callable from unittest.mock import Mock from pytest import MonkeyPatch from opentelemetry.sdk.trace import ReadableSpan, TracerProvider @@ -16,13 +16,28 @@ from rasa_sdk.executor import CollectingDispatcher -dispatcher1 = CollectingDispatcher() -dispatcher1.utter_message(template="utter_greet") -dispatcher2 = CollectingDispatcher() -dispatcher2.utter_message("Hello") -dispatcher3 = CollectingDispatcher() -dispatcher3.utter_message("Hello") -dispatcher3.utter_message(template="utter_greet") +def get_dispatcher0(): + dispatcher = CollectingDispatcher() + return dispatcher + + +def get_dispatcher1(): + dispatcher = CollectingDispatcher() + dispatcher.utter_message(template="utter_greet") + return dispatcher + + +def get_dispatcher2(): + dispatcher = CollectingDispatcher() + dispatcher.utter_message("Hello") + return dispatcher + + +def get_dispatcher3(): + dispatcher = CollectingDispatcher() + dispatcher.utter_message("Hello") + dispatcher.utter_message(template="utter_greet") + return dispatcher @pytest.mark.parametrize( @@ -102,17 +117,21 @@ def test_instrument_action_executor_run_registers_tracer( @pytest.mark.parametrize( - "events, messages, expected", + "events, get_dispatcher, expected", [ - ([], [], {"events": "[]", "slots": "[]", "utters": "[]", "message_count": 0}), + ( + [], + get_dispatcher0, + {"events": "[]", "slots": "[]", "utters": "[]", "message_count": 0}, + ), ( [ActionExecuted("my_form")], - dispatcher2.messages, + get_dispatcher2, {"events": '["action"]', "slots": "[]", "utters": "[]", "message_count": 1}, ), ( [ActionExecuted("my_form"), SlotSet("my_slot", "some_value")], - dispatcher1.messages, + get_dispatcher1, { "events": '["action", "slot"]', "slots": '["my_slot"]', @@ -122,7 +141,7 @@ def test_instrument_action_executor_run_registers_tracer( ), ( [SlotSet("my_slot", "some_value")], - dispatcher3.messages, + get_dispatcher3, { "events": '["slot"]', "slots": '["my_slot"]', @@ -137,7 +156,7 @@ def test_tracing_action_executor_create_api_response( span_exporter: InMemorySpanExporter, previous_num_captured_spans: int, events: Optional[List], - messages: Optional[List], + get_dispatcher: Callable, expected: Dict[Text, Any], ) -> None: component_class = MockActionExecutor @@ -149,7 +168,8 @@ def test_tracing_action_executor_create_api_response( mock_action_executor = component_class() - mock_action_executor._create_api_response(events, messages) + dispatcher = get_dispatcher() + mock_action_executor._create_api_response(events, dispatcher.messages) captured_spans: Sequence[ ReadableSpan