Skip to content

Commit

Permalink
implement PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 20, 2024
1 parent b1b2421 commit 630bdfb
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
17 changes: 11 additions & 6 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,18 @@ def tracing_validation_action_extract_validation_events_wrapper(
@functools.wraps(fn)
async def wrapper(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[EventType]:
with tracer.start_as_current_span(
f"{validation_action_class.__name__}.{self.__class__.__name__}.{fn.__name__}"
) as span:
if issubclass(self.__class__, FormValidationAction):
span_name = (
f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}"
)
else:
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"

with tracer.start_as_current_span(span_name) as span:
validation_events = await fn(self, dispatcher, tracker, domain)
event_names = []
slot_names = []
Expand Down
8 changes: 8 additions & 0 deletions tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ async def run(
def name(self) -> Text:
return "mock_validation_action"

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events


class MockFormValidationAction(FormValidationAction):
def __init__(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions tests/tracing/instrumentation/test_form_validation_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from tests.tracing.instrumentation.conftest import MockFormValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import SlotSet
from rasa_sdk.events import ActionExecuted, SlotSet


@pytest.mark.parametrize(
"events, expected_slots_to_validate",
[
([], "[]"),
([ActionExecuted("my_form")], "[]"),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
Expand Down Expand Up @@ -68,6 +69,7 @@ async def test_form_validation_action_run(
"events, slots, validation_events",
[
([], "[]", "[]"),
([ActionExecuted("my_form")], "[]", '["action"]'),
(
[SlotSet("name", "Tom")],
'["name"]',
Expand Down Expand Up @@ -113,7 +115,7 @@ async def test_form_validation_action_extract_validation_events(

captured_span = captured_spans[-1]
expected_span_name = (
"MockFormValidationAction.MockFormValidationAction._extract_validation_events"
"FormValidationAction.MockFormValidationAction._extract_validation_events"
)

assert captured_span.name == expected_span_name
Expand Down
9 changes: 5 additions & 4 deletions tests/tracing/instrumentation/test_validation_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import (
MockValidationAction,
MockFormValidationAction,
)
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import SlotSet, EventType
from rasa_sdk.events import SlotSet, EventType, ActionExecuted


@pytest.mark.parametrize(
"events, expected_slots_to_validate",
[
([], "[]"),
([ActionExecuted("my_form")], "[]"),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
Expand Down Expand Up @@ -70,6 +70,7 @@ async def test_validation_action_run(
"events, slots, validation_events",
[
([], "[]", "[]"),
([ActionExecuted("my_form")], "[]", '["action"]'),
(
[SlotSet("name", "Tom")],
'["name"]',
Expand All @@ -91,7 +92,7 @@ async def test_validation_action_extract_validation_events(
slots: Optional[str],
validation_events: Optional[str],
) -> None:
component_class = MockFormValidationAction
component_class = MockValidationAction

instrumentation.instrument(
tracer_provider,
Expand All @@ -115,7 +116,7 @@ async def test_validation_action_extract_validation_events(

captured_span = captured_spans[-1]
expected_span_name = (
"MockFormValidationAction.MockFormValidationAction._extract_validation_events"
"ValidationAction.MockValidationAction._extract_validation_events"
)

assert captured_span.name == expected_span_name
Expand Down

0 comments on commit 630bdfb

Please sign in to comment.