Skip to content

Commit

Permalink
Merge pull request #1077 from RasaHQ/ATO-2103-instrument-FormValidati…
Browse files Browse the repository at this point in the history
…onAction._extract_validation_events

[ATO-2103] Instrument FormValidationAction._extract_validation_events
  • Loading branch information
Tawakalt authored Feb 20, 2024
2 parents 2f3b8b7 + 7bafea3 commit 93104bf
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 13 deletions.
1 change: 1 addition & 0 deletions changelog/1077.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Instrument `ValidationAction._extract_validation_events` and `FormValidationAction._extract_validation_events` and extract `validated_events` and `slots` attributes.
3 changes: 2 additions & 1 deletion rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk.forms import ValidationAction, FormValidationAction

TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")

Expand All @@ -39,6 +39,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
validation_action_class=ValidationAction,
form_validation_action_class=FormValidationAction,
)


Expand Down
6 changes: 3 additions & 3 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def extract_attrs_for_action_executor(

def extract_attrs_for_validation_action(
self: ValidationAction,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> Dict[Text, Any]:
"""Extract the attributes for `ValidationAction.run`.
Expand Down
89 changes: 86 additions & 3 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import functools
import inspect
import json
import logging
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Text,
Type,
TypeVar,
Union,
)

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.events import EventType
from rasa_sdk.forms import ValidationAction, FormValidationAction
from rasa_sdk.tracing.instrumentation import attribute_extractors
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister
from rasa_sdk.types import DomainDict


# The `TypeVar` representing the return type for a function to be wrapped.
S = TypeVar("S")
Expand Down Expand Up @@ -72,7 +79,9 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
if attr_extractor and should_extract_args
else {}
)
if issubclass(self.__class__, ValidationAction):
if issubclass(self.__class__, FormValidationAction):
span_name = f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}"
elif issubclass(self.__class__, ValidationAction):
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"
else:
span_name = f"{self.__class__.__name__}.{fn.__name__}"
Expand Down Expand Up @@ -128,12 +137,16 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S:

ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)
ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction)
FormValidationActionType = TypeVar(
"FormValidationActionType", bound=FormValidationAction
)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
validation_action_class: Optional[Type[ValidationActionType]] = None,
form_validation_action_class: Optional[Type[FormValidationActionType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.
Expand All @@ -143,6 +156,8 @@ def instrument(
is given, no `ActionExecutor` will be instrumented.
:param validation_action_class: The `ValidationAction` to be instrumented. If `None`
is given, no `ValidationAction` will be instrumented.
:param form_validation_action_class: The `FormValidationAction` to be instrumented.
If `None` is given, no `FormValidationAction` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
Expand Down Expand Up @@ -172,8 +187,21 @@ def instrument(
"run",
attribute_extractors.extract_attrs_for_validation_action,
)
_instrument_validation_action_extract_validation_events(
tracer_provider.get_tracer(validation_action_class.__module__),
validation_action_class,
)
mark_class_as_instrumented(validation_action_class)

if form_validation_action_class is not None and not class_is_instrumented(
form_validation_action_class
):
_instrument_validation_action_extract_validation_events(
tracer_provider.get_tracer(form_validation_action_class.__module__),
form_validation_action_class,
)
mark_class_as_instrumented(form_validation_action_class)


def _instrument_method(
tracer: Tracer,
Expand Down Expand Up @@ -214,3 +242,58 @@ def mark_class_as_instrumented(instrumented_class: Type) -> None:
_mangled_instrumented_boolean_attribute_name(instrumented_class),
True,
)


def _instrument_validation_action_extract_validation_events(
tracer: Tracer,
validation_action_class: Union[Type[ValidationAction], Type[FormValidationAction]],
) -> None:
def tracing_validation_action_extract_validation_events_wrapper(
fn: Callable,
) -> Callable:
@functools.wraps(fn)
async def wrapper(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[EventType]:
if issubclass(self.__class__, FormValidationAction):
span_name = (
f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}"
)
else:
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"

with tracer.start_as_current_span(span_name) as span:
validation_events = await fn(self, dispatcher, tracker, domain)
event_names = []
slot_names = []

if validation_events:
for event in validation_events:
event_names.append(event.get("event"))
if event.get("event") == "slot":
slot_names.append(event.get("name"))

span.set_attributes(
{
"validation_events": json.dumps(
list(dict.fromkeys(event_names))
),
"slots": json.dumps(list(dict.fromkeys(slot_names))),
}
)
return validation_events

return wrapper

validation_action_class._extract_validation_events = ( # type: ignore
tracing_validation_action_extract_validation_events_wrapper(
validation_action_class._extract_validation_events
)
)

logger.debug(
f"Instrumented '{validation_action_class.__name__}._extract_validation_events'."
)
3 changes: 2 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 4
assert len(response.json) == 5

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
Expand All @@ -33,6 +33,7 @@ def test_server_list_actions_returns_200():
{"name": "custom_action_exception"},
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
{"name": "mock_form_validation_action"},
]
assert response.json == expected

Expand Down
37 changes: 36 additions & 1 deletion tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.forms import ValidationAction, FormValidationAction
from rasa_sdk.types import ActionCall, DomainDict
from rasa_sdk import Tracker

Expand Down Expand Up @@ -79,3 +79,38 @@ async def run(

def name(self) -> Text:
return "mock_validation_action"

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events


class MockFormValidationAction(FormValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events

def name(self) -> Text:
return "mock_form_validation_action"
128 changes: 128 additions & 0 deletions tests/tracing/instrumentation/test_form_validation_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Sequence, Optional

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockFormValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import ActionExecuted, SlotSet


@pytest.mark.parametrize(
"events, expected_slots_to_validate",
[
([], "[]"),
([ActionExecuted("my_form")], "[]"),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
),
],
)
@pytest.mark.asyncio
async def test_form_validation_action_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: Optional[str],
expected_slots_to_validate: Optional[str],
) -> None:
component_class = MockFormValidationAction

instrumentation.instrument(
tracer_provider,
validation_action_class=component_class,
)

mock_validation_action = component_class()
dispatcher = CollectingDispatcher()
tracker = Tracker.from_dict({"sender_id": "test", "events": events})

await mock_validation_action.run(dispatcher, tracker, {})

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
# includes the child span for `_extract_validation_events` method call
assert num_captured_spans == 2

captured_span = captured_spans[-1]

assert captured_span.name == "FormValidationAction.MockFormValidationAction.run"

expected_attributes = {
"class_name": component_class.__name__,
"sender_id": "test",
"slots_to_validate": expected_slots_to_validate,
"action_name": "mock_form_validation_action",
}

assert captured_span.attributes == expected_attributes


@pytest.mark.parametrize(
"events, slots, validation_events",
[
([], "[]", "[]"),
([ActionExecuted("my_form")], "[]", '["action"]'),
(
[SlotSet("name", "Tom")],
'["name"]',
'["slot"]',
),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
'["slot"]',
),
],
)
@pytest.mark.asyncio
async def test_form_validation_action_extract_validation_events(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: Optional[str],
slots: Optional[str],
validation_events: Optional[str],
) -> None:
component_class = MockFormValidationAction

instrumentation.instrument(
tracer_provider,
form_validation_action_class=component_class,
)

mock_form_validation_action = component_class()
dispatcher = CollectingDispatcher()
tracker = Tracker.from_dict({"sender_id": "test", "events": events})

await mock_form_validation_action._extract_validation_events(
dispatcher, tracker, {}
)

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]
expected_span_name = (
"FormValidationAction.MockFormValidationAction._extract_validation_events"
)

assert captured_span.name == expected_span_name

expected_attributes = {
"validation_events": validation_events,
"slots": slots,
}

assert captured_span.attributes == expected_attributes
Loading

0 comments on commit 93104bf

Please sign in to comment.