Skip to content

Commit

Permalink
instrument ValidationAction._extract_validation_events
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 19, 2024
1 parent 2f3b8b7 commit a5df827
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
3 changes: 2 additions & 1 deletion rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk.forms import ValidationAction, FormValidationAction

TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")

Expand All @@ -39,6 +39,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
validation_action_class=ValidationAction,
form_validation_action_class=FormValidationAction,
)


Expand Down
84 changes: 81 additions & 3 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import functools
import inspect
import json
import logging
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Text,
Type,
TypeVar,
Union,
)

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.events import EventType
from rasa_sdk.forms import ValidationAction, FormValidationAction
from rasa_sdk.tracing.instrumentation import attribute_extractors
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister
from rasa_sdk.types import DomainDict


# The `TypeVar` representing the return type for a function to be wrapped.
S = TypeVar("S")
Expand Down Expand Up @@ -72,7 +79,9 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
if attr_extractor and should_extract_args
else {}
)
if issubclass(self.__class__, ValidationAction):
if issubclass(self.__class__, FormValidationAction):
span_name = f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}"
elif issubclass(self.__class__, ValidationAction):
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"
else:
span_name = f"{self.__class__.__name__}.{fn.__name__}"
Expand Down Expand Up @@ -128,12 +137,16 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S:

ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)
ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction)
FormValidationActionType = TypeVar(
"FormValidationActionType", bound=FormValidationAction
)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
validation_action_class: Optional[Type[ValidationActionType]] = None,
form_validation_action_class: Optional[Type[FormValidationActionType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.
Expand All @@ -143,6 +156,8 @@ def instrument(
is given, no `ActionExecutor` will be instrumented.
:param validation_action_class: The `ValidationAction` to be instrumented. If `None`
is given, no `ValidationAction` will be instrumented.
:param form_validation_action_class: The `FormValidationAction` to be instrumented.
If `None` is given, no `FormValidationAction` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
Expand Down Expand Up @@ -172,8 +187,21 @@ def instrument(
"run",
attribute_extractors.extract_attrs_for_validation_action,
)
_instrument_validation_action_extract_validation_events(
tracer_provider.get_tracer(validation_action_class.__module__),
validation_action_class,
)
mark_class_as_instrumented(validation_action_class)

if form_validation_action_class is not None and not class_is_instrumented(
form_validation_action_class
):
_instrument_validation_action_extract_validation_events(
tracer_provider.get_tracer(form_validation_action_class.__module__),
form_validation_action_class,
)
mark_class_as_instrumented(form_validation_action_class)


def _instrument_method(
tracer: Tracer,
Expand Down Expand Up @@ -214,3 +242,53 @@ def mark_class_as_instrumented(instrumented_class: Type) -> None:
_mangled_instrumented_boolean_attribute_name(instrumented_class),
True,
)


def _instrument_validation_action_extract_validation_events(
tracer: Tracer,
validation_action_class: Union[Type[ValidationAction], Type[FormValidationAction]],
) -> None:
def tracing_validation_action_extract_validation_events_wrapper(
fn: Callable,
) -> Callable:
@functools.wraps(fn)
async def wrapper(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> List[EventType]:
with tracer.start_as_current_span(
f"{validation_action_class.__name__}.{self.__class__.__name__}.{fn.__name__}"
) as span:
validation_events = await fn(self, dispatcher, tracker, domain)
event_names = []
slot_names = []

if validation_events:
for event in validation_events:
event_names.append(event.get("event"))
if event.get("event") == "slot":
slot_names.append(event.get("name"))

span.set_attributes(
{
"validation_events": json.dumps(
list(dict.fromkeys(event_names))
),
"slots": json.dumps(list(dict.fromkeys(slot_names))),
}
)
return validation_events

return wrapper

validation_action_class._extract_validation_events = ( # type: ignore
tracing_validation_action_extract_validation_events_wrapper(
validation_action_class._extract_validation_events
)
)

logger.debug(
f"Instrumented '{validation_action_class.__name__}._extract_validation_events'."
)

0 comments on commit a5df827

Please sign in to comment.