From a5df827fc8ba3865b1fed7d45731037ed365c831 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Mon, 19 Feb 2024 16:12:32 +0100 Subject: [PATCH 1/4] instrument ValidationAction._extract_validation_events --- rasa_sdk/tracing/config.py | 3 +- .../instrumentation/instrumentation.py | 84 ++++++++++++++++++- 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/rasa_sdk/tracing/config.py b/rasa_sdk/tracing/config.py index a0079a0b1..080f2eb7a 100644 --- a/rasa_sdk/tracing/config.py +++ b/rasa_sdk/tracing/config.py @@ -14,7 +14,7 @@ from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config from rasa_sdk.tracing.instrumentation import instrumentation from rasa_sdk.executor import ActionExecutor -from rasa_sdk.forms import ValidationAction +from rasa_sdk.forms import ValidationAction, FormValidationAction TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk") @@ -39,6 +39,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None: tracer_provider=tracer_provider, action_executor_class=ActionExecutor, validation_action_class=ValidationAction, + form_validation_action_class=FormValidationAction, ) diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index fe7918f12..f7c589bb0 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -1,23 +1,30 @@ import functools import inspect +import json import logging from typing import ( Any, Awaitable, Callable, Dict, + List, Optional, Text, Type, TypeVar, + Union, ) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import Tracer -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.forms import ValidationAction +from rasa_sdk import Tracker +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.events import EventType +from rasa_sdk.forms import ValidationAction, FormValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister +from rasa_sdk.types import DomainDict + # The `TypeVar` representing the return type for a function to be wrapped. S = TypeVar("S") @@ -72,7 +79,9 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S: if attr_extractor and should_extract_args else {} ) - if issubclass(self.__class__, ValidationAction): + if issubclass(self.__class__, FormValidationAction): + span_name = f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}" + elif issubclass(self.__class__, ValidationAction): span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" else: span_name = f"{self.__class__.__name__}.{fn.__name__}" @@ -128,12 +137,16 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S: ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor) ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction) +FormValidationActionType = TypeVar( + "FormValidationActionType", bound=FormValidationAction +) def instrument( tracer_provider: TracerProvider, action_executor_class: Optional[Type[ActionExecutorType]] = None, validation_action_class: Optional[Type[ValidationActionType]] = None, + form_validation_action_class: Optional[Type[FormValidationActionType]] = None, ) -> None: """Substitute methods to be traced by their traced counterparts. @@ -143,6 +156,8 @@ def instrument( is given, no `ActionExecutor` will be instrumented. :param validation_action_class: The `ValidationAction` to be instrumented. If `None` is given, no `ValidationAction` will be instrumented. + :param form_validation_action_class: The `FormValidationAction` to be instrumented. + If `None` is given, no `FormValidationAction` will be instrumented. """ if action_executor_class is not None and not class_is_instrumented( action_executor_class @@ -172,8 +187,21 @@ def instrument( "run", attribute_extractors.extract_attrs_for_validation_action, ) + _instrument_validation_action_extract_validation_events( + tracer_provider.get_tracer(validation_action_class.__module__), + validation_action_class, + ) mark_class_as_instrumented(validation_action_class) + if form_validation_action_class is not None and not class_is_instrumented( + form_validation_action_class + ): + _instrument_validation_action_extract_validation_events( + tracer_provider.get_tracer(form_validation_action_class.__module__), + form_validation_action_class, + ) + mark_class_as_instrumented(form_validation_action_class) + def _instrument_method( tracer: Tracer, @@ -214,3 +242,53 @@ def mark_class_as_instrumented(instrumented_class: Type) -> None: _mangled_instrumented_boolean_attribute_name(instrumented_class), True, ) + + +def _instrument_validation_action_extract_validation_events( + tracer: Tracer, + validation_action_class: Union[Type[ValidationAction], Type[FormValidationAction]], +) -> None: + def tracing_validation_action_extract_validation_events_wrapper( + fn: Callable, + ) -> Callable: + @functools.wraps(fn) + async def wrapper( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> List[EventType]: + with tracer.start_as_current_span( + f"{validation_action_class.__name__}.{self.__class__.__name__}.{fn.__name__}" + ) as span: + validation_events = await fn(self, dispatcher, tracker, domain) + event_names = [] + slot_names = [] + + if validation_events: + for event in validation_events: + event_names.append(event.get("event")) + if event.get("event") == "slot": + slot_names.append(event.get("name")) + + span.set_attributes( + { + "validation_events": json.dumps( + list(dict.fromkeys(event_names)) + ), + "slots": json.dumps(list(dict.fromkeys(slot_names))), + } + ) + return validation_events + + return wrapper + + validation_action_class._extract_validation_events = ( # type: ignore + tracing_validation_action_extract_validation_events_wrapper( + validation_action_class._extract_validation_events + ) + ) + + logger.debug( + f"Instrumented '{validation_action_class.__name__}._extract_validation_events'." + ) From 0bbc2eba445eda31d839d07a4c62e4c0f6a38e63 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Mon, 19 Feb 2024 17:00:20 +0100 Subject: [PATCH 2/4] add tests --- tests/test_endpoint.py | 3 +- tests/tracing/instrumentation/conftest.py | 29 +++- .../test_form_validation_action.py | 126 ++++++++++++++++++ .../instrumentation/test_validation_action.py | 71 +++++++++- 4 files changed, 224 insertions(+), 5 deletions(-) create mode 100644 tests/tracing/instrumentation/test_form_validation_action.py diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 29678d3a9..06e779e80 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -23,7 +23,7 @@ def test_server_health_returns_200(): def test_server_list_actions_returns_200(): request, response = app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 4 + assert len(response.json) == 5 # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS expected = [ @@ -33,6 +33,7 @@ def test_server_list_actions_returns_200(): {"name": "custom_action_exception"}, # defined in tests/tracing/instrumentation/conftest.py {"name": "mock_validation_action"}, + {"name": "mock_form_validation_action"}, ] assert response.json == expected diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 796113063..b194c74a5 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -6,7 +6,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.executor import ActionExecutor, CollectingDispatcher -from rasa_sdk.forms import ValidationAction +from rasa_sdk.forms import ValidationAction, FormValidationAction from rasa_sdk.types import ActionCall, DomainDict from rasa_sdk import Tracker @@ -79,3 +79,30 @@ async def run( def name(self) -> Text: return "mock_validation_action" + + +class MockFormValidationAction(FormValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + def name(self) -> Text: + return "mock_form_validation_action" diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py new file mode 100644 index 000000000..6d55d24fb --- /dev/null +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -0,0 +1,126 @@ +from typing import Sequence, Optional + +import pytest +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from rasa_sdk.tracing.instrumentation import instrumentation +from tests.tracing.instrumentation.conftest import MockFormValidationAction +from rasa_sdk import Tracker +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.events import SlotSet + + +@pytest.mark.parametrize( + "events, expected_slots_to_validate", + [ + ([], "[]"), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + ), + ], +) +@pytest.mark.asyncio +async def test_form_validation_action_run( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + expected_slots_to_validate: Optional[str], +) -> None: + component_class = MockFormValidationAction + + instrumentation.instrument( + tracer_provider, + validation_action_class=component_class, + ) + + mock_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_validation_action.run(dispatcher, tracker, {}) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + # includes the child span for `_extract_validation_events` method call + assert num_captured_spans == 2 + + captured_span = captured_spans[-1] + + assert captured_span.name == "FormValidationAction.MockFormValidationAction.run" + + expected_attributes = { + "class_name": component_class.__name__, + "sender_id": "test", + "slots_to_validate": expected_slots_to_validate, + "action_name": "mock_form_validation_action", + } + + assert captured_span.attributes == expected_attributes + + +@pytest.mark.parametrize( + "events, slots, validation_events", + [ + ([], "[]", "[]"), + ( + [SlotSet("name", "Tom")], + '["name"]', + '["slot"]', + ), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + '["slot"]', + ), + ], +) +@pytest.mark.asyncio +async def test_form_validation_action_extract_validation_events( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + slots: Optional[str], + validation_events: Optional[str], +) -> None: + component_class = MockFormValidationAction + + instrumentation.instrument( + tracer_provider, + form_validation_action_class=component_class, + ) + + mock_form_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_form_validation_action._extract_validation_events( + dispatcher, tracker, {} + ) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + expected_span_name = ( + "MockFormValidationAction.MockFormValidationAction._extract_validation_events" + ) + + assert captured_span.name == expected_span_name + + expected_attributes = { + "validation_events": validation_events, + "slots": slots, + } + + assert captured_span.attributes == expected_attributes diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index ce407d8db..267ee578f 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -1,11 +1,14 @@ -from typing import List, Sequence +from typing import List, Optional, Sequence import pytest from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import MockValidationAction +from tests.tracing.instrumentation.conftest import ( + MockValidationAction, + MockFormValidationAction, +) from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher from rasa_sdk.events import SlotSet, EventType @@ -22,7 +25,7 @@ ], ) @pytest.mark.asyncio -async def test_tracing_action_executor_run( +async def test_validation_action_run( tracer_provider: TracerProvider, span_exporter: InMemorySpanExporter, previous_num_captured_spans: int, @@ -61,3 +64,65 @@ async def test_tracing_action_executor_run( } assert captured_span.attributes == expected_attributes + + +@pytest.mark.parametrize( + "events, slots, validation_events", + [ + ([], "[]", "[]"), + ( + [SlotSet("name", "Tom")], + '["name"]', + '["slot"]', + ), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + '["slot"]', + ), + ], +) +@pytest.mark.asyncio +async def test_validation_action_extract_validation_events( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + slots: Optional[str], + validation_events: Optional[str], +) -> None: + component_class = MockFormValidationAction + + instrumentation.instrument( + tracer_provider, + form_validation_action_class=component_class, + ) + + mock_form_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_form_validation_action._extract_validation_events( + dispatcher, tracker, {} + ) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + expected_span_name = ( + "MockFormValidationAction.MockFormValidationAction._extract_validation_events" + ) + + assert captured_span.name == expected_span_name + + expected_attributes = { + "validation_events": validation_events, + "slots": slots, + } + + assert captured_span.attributes == expected_attributes From b1b2421cdb763234bc7d63007bac4d67a84028f3 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Mon, 19 Feb 2024 17:02:57 +0100 Subject: [PATCH 3/4] add changelog entry --- changelog/1077.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1077.improvement.md diff --git a/changelog/1077.improvement.md b/changelog/1077.improvement.md new file mode 100644 index 000000000..bea999c4f --- /dev/null +++ b/changelog/1077.improvement.md @@ -0,0 +1 @@ +Instrument `ValidationAction._extract_validation_events` and `FormValidationAction._extract_validation_events` and extract `validated_events` and `slots` attributes. \ No newline at end of file From 7bafea39183694465d65ae49917f00cf39984773 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Tue, 20 Feb 2024 10:12:50 +0100 Subject: [PATCH 4/4] implement PR feedback --- .../instrumentation/attribute_extractors.py | 6 +++--- .../tracing/instrumentation/instrumentation.py | 17 +++++++++++------ tests/tracing/instrumentation/conftest.py | 8 ++++++++ .../test_form_validation_action.py | 6 ++++-- .../instrumentation/test_validation_action.py | 9 +++++---- 5 files changed, 31 insertions(+), 15 deletions(-) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 432c04603..ff389438c 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -36,9 +36,9 @@ def extract_attrs_for_action_executor( def extract_attrs_for_validation_action( self: ValidationAction, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, ) -> Dict[Text, Any]: """Extract the attributes for `ValidationAction.run`. diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index f7c589bb0..47f896475 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -254,13 +254,18 @@ def tracing_validation_action_extract_validation_events_wrapper( @functools.wraps(fn) async def wrapper( self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, ) -> List[EventType]: - with tracer.start_as_current_span( - f"{validation_action_class.__name__}.{self.__class__.__name__}.{fn.__name__}" - ) as span: + if issubclass(self.__class__, FormValidationAction): + span_name = ( + f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}" + ) + else: + span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" + + with tracer.start_as_current_span(span_name) as span: validation_events = await fn(self, dispatcher, tracker, domain) event_names = [] slot_names = [] diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index b194c74a5..5b7570f4b 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -80,6 +80,14 @@ async def run( def name(self) -> Text: return "mock_validation_action" + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + class MockFormValidationAction(FormValidationAction): def __init__(self) -> None: diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py index 6d55d24fb..66374c6e3 100644 --- a/tests/tracing/instrumentation/test_form_validation_action.py +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -8,13 +8,14 @@ from tests.tracing.instrumentation.conftest import MockFormValidationAction from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.events import SlotSet +from rasa_sdk.events import ActionExecuted, SlotSet @pytest.mark.parametrize( "events, expected_slots_to_validate", [ ([], "[]"), + ([ActionExecuted("my_form")], "[]"), ( [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], '["name", "address"]', @@ -68,6 +69,7 @@ async def test_form_validation_action_run( "events, slots, validation_events", [ ([], "[]", "[]"), + ([ActionExecuted("my_form")], "[]", '["action"]'), ( [SlotSet("name", "Tom")], '["name"]', @@ -113,7 +115,7 @@ async def test_form_validation_action_extract_validation_events( captured_span = captured_spans[-1] expected_span_name = ( - "MockFormValidationAction.MockFormValidationAction._extract_validation_events" + "FormValidationAction.MockFormValidationAction._extract_validation_events" ) assert captured_span.name == expected_span_name diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 267ee578f..860d5032d 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -7,17 +7,17 @@ from rasa_sdk.tracing.instrumentation import instrumentation from tests.tracing.instrumentation.conftest import ( MockValidationAction, - MockFormValidationAction, ) from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.events import SlotSet, EventType +from rasa_sdk.events import SlotSet, EventType, ActionExecuted @pytest.mark.parametrize( "events, expected_slots_to_validate", [ ([], "[]"), + ([ActionExecuted("my_form")], "[]"), ( [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], '["name", "address"]', @@ -70,6 +70,7 @@ async def test_validation_action_run( "events, slots, validation_events", [ ([], "[]", "[]"), + ([ActionExecuted("my_form")], "[]", '["action"]'), ( [SlotSet("name", "Tom")], '["name"]', @@ -91,7 +92,7 @@ async def test_validation_action_extract_validation_events( slots: Optional[str], validation_events: Optional[str], ) -> None: - component_class = MockFormValidationAction + component_class = MockValidationAction instrumentation.instrument( tracer_provider, @@ -115,7 +116,7 @@ async def test_validation_action_extract_validation_events( captured_span = captured_spans[-1] expected_span_name = ( - "MockFormValidationAction.MockFormValidationAction._extract_validation_events" + "ValidationAction.MockValidationAction._extract_validation_events" ) assert captured_span.name == expected_span_name