From 630bdfb4142df48b2e94ac0c6ebc1a80f85ac852 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Tue, 20 Feb 2024 10:12:50 +0100 Subject: [PATCH] implement PR feedback --- .../tracing/instrumentation/instrumentation.py | 17 +++++++++++------ tests/tracing/instrumentation/conftest.py | 8 ++++++++ .../test_form_validation_action.py | 6 ++++-- .../instrumentation/test_validation_action.py | 9 +++++---- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index f7c589bb0..47f896475 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -254,13 +254,18 @@ def tracing_validation_action_extract_validation_events_wrapper( @functools.wraps(fn) async def wrapper( self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, ) -> List[EventType]: - with tracer.start_as_current_span( - f"{validation_action_class.__name__}.{self.__class__.__name__}.{fn.__name__}" - ) as span: + if issubclass(self.__class__, FormValidationAction): + span_name = ( + f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}" + ) + else: + span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" + + with tracer.start_as_current_span(span_name) as span: validation_events = await fn(self, dispatcher, tracker, domain) event_names = [] slot_names = [] diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index b194c74a5..5b7570f4b 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -80,6 +80,14 @@ async def run( def name(self) -> Text: return "mock_validation_action" + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + class MockFormValidationAction(FormValidationAction): def __init__(self) -> None: diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py index 6d55d24fb..66374c6e3 100644 --- a/tests/tracing/instrumentation/test_form_validation_action.py +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -8,13 +8,14 @@ from tests.tracing.instrumentation.conftest import MockFormValidationAction from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.events import SlotSet +from rasa_sdk.events import ActionExecuted, SlotSet @pytest.mark.parametrize( "events, expected_slots_to_validate", [ ([], "[]"), + ([ActionExecuted("my_form")], "[]"), ( [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], '["name", "address"]', @@ -68,6 +69,7 @@ async def test_form_validation_action_run( "events, slots, validation_events", [ ([], "[]", "[]"), + ([ActionExecuted("my_form")], "[]", '["action"]'), ( [SlotSet("name", "Tom")], '["name"]', @@ -113,7 +115,7 @@ async def test_form_validation_action_extract_validation_events( captured_span = captured_spans[-1] expected_span_name = ( - "MockFormValidationAction.MockFormValidationAction._extract_validation_events" + "FormValidationAction.MockFormValidationAction._extract_validation_events" ) assert captured_span.name == expected_span_name diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 267ee578f..860d5032d 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -7,17 +7,17 @@ from rasa_sdk.tracing.instrumentation import instrumentation from tests.tracing.instrumentation.conftest import ( MockValidationAction, - MockFormValidationAction, ) from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.events import SlotSet, EventType +from rasa_sdk.events import SlotSet, EventType, ActionExecuted @pytest.mark.parametrize( "events, expected_slots_to_validate", [ ([], "[]"), + ([ActionExecuted("my_form")], "[]"), ( [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], '["name", "address"]', @@ -70,6 +70,7 @@ async def test_validation_action_run( "events, slots, validation_events", [ ([], "[]", "[]"), + ([ActionExecuted("my_form")], "[]", '["action"]'), ( [SlotSet("name", "Tom")], '["name"]', @@ -91,7 +92,7 @@ async def test_validation_action_extract_validation_events( slots: Optional[str], validation_events: Optional[str], ) -> None: - component_class = MockFormValidationAction + component_class = MockValidationAction instrumentation.instrument( tracer_provider, @@ -115,7 +116,7 @@ async def test_validation_action_extract_validation_events( captured_span = captured_spans[-1] expected_span_name = ( - "MockFormValidationAction.MockFormValidationAction._extract_validation_events" + "ValidationAction.MockValidationAction._extract_validation_events" ) assert captured_span.name == expected_span_name