From a5df827fc8ba3865b1fed7d45731037ed365c831 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Mon, 19 Feb 2024 16:12:32 +0100 Subject: [PATCH] instrument ValidationAction._extract_validation_events --- rasa_sdk/tracing/config.py | 3 +- .../instrumentation/instrumentation.py | 84 ++++++++++++++++++- 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/rasa_sdk/tracing/config.py b/rasa_sdk/tracing/config.py index a0079a0b1..080f2eb7a 100644 --- a/rasa_sdk/tracing/config.py +++ b/rasa_sdk/tracing/config.py @@ -14,7 +14,7 @@ from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config from rasa_sdk.tracing.instrumentation import instrumentation from rasa_sdk.executor import ActionExecutor -from rasa_sdk.forms import ValidationAction +from rasa_sdk.forms import ValidationAction, FormValidationAction TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk") @@ -39,6 +39,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None: tracer_provider=tracer_provider, action_executor_class=ActionExecutor, validation_action_class=ValidationAction, + form_validation_action_class=FormValidationAction, ) diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index fe7918f12..f7c589bb0 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -1,23 +1,30 @@ import functools import inspect +import json import logging from typing import ( Any, Awaitable, Callable, Dict, + List, Optional, Text, Type, TypeVar, + Union, ) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import Tracer -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.forms import ValidationAction +from rasa_sdk import Tracker +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.events import EventType +from rasa_sdk.forms import ValidationAction, FormValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister +from rasa_sdk.types import DomainDict + # The `TypeVar` representing the return type for a function to be wrapped. S = TypeVar("S") @@ -72,7 +79,9 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S: if attr_extractor and should_extract_args else {} ) - if issubclass(self.__class__, ValidationAction): + if issubclass(self.__class__, FormValidationAction): + span_name = f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}" + elif issubclass(self.__class__, ValidationAction): span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" else: span_name = f"{self.__class__.__name__}.{fn.__name__}" @@ -128,12 +137,16 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S: ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor) ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction) +FormValidationActionType = TypeVar( + "FormValidationActionType", bound=FormValidationAction +) def instrument( tracer_provider: TracerProvider, action_executor_class: Optional[Type[ActionExecutorType]] = None, validation_action_class: Optional[Type[ValidationActionType]] = None, + form_validation_action_class: Optional[Type[FormValidationActionType]] = None, ) -> None: """Substitute methods to be traced by their traced counterparts. @@ -143,6 +156,8 @@ def instrument( is given, no `ActionExecutor` will be instrumented. :param validation_action_class: The `ValidationAction` to be instrumented. If `None` is given, no `ValidationAction` will be instrumented. + :param form_validation_action_class: The `FormValidationAction` to be instrumented. + If `None` is given, no `FormValidationAction` will be instrumented. """ if action_executor_class is not None and not class_is_instrumented( action_executor_class @@ -172,8 +187,21 @@ def instrument( "run", attribute_extractors.extract_attrs_for_validation_action, ) + _instrument_validation_action_extract_validation_events( + tracer_provider.get_tracer(validation_action_class.__module__), + validation_action_class, + ) mark_class_as_instrumented(validation_action_class) + if form_validation_action_class is not None and not class_is_instrumented( + form_validation_action_class + ): + _instrument_validation_action_extract_validation_events( + tracer_provider.get_tracer(form_validation_action_class.__module__), + form_validation_action_class, + ) + mark_class_as_instrumented(form_validation_action_class) + def _instrument_method( tracer: Tracer, @@ -214,3 +242,53 @@ def mark_class_as_instrumented(instrumented_class: Type) -> None: _mangled_instrumented_boolean_attribute_name(instrumented_class), True, ) + + +def _instrument_validation_action_extract_validation_events( + tracer: Tracer, + validation_action_class: Union[Type[ValidationAction], Type[FormValidationAction]], +) -> None: + def tracing_validation_action_extract_validation_events_wrapper( + fn: Callable, + ) -> Callable: + @functools.wraps(fn) + async def wrapper( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> List[EventType]: + with tracer.start_as_current_span( + f"{validation_action_class.__name__}.{self.__class__.__name__}.{fn.__name__}" + ) as span: + validation_events = await fn(self, dispatcher, tracker, domain) + event_names = [] + slot_names = [] + + if validation_events: + for event in validation_events: + event_names.append(event.get("event")) + if event.get("event") == "slot": + slot_names.append(event.get("name")) + + span.set_attributes( + { + "validation_events": json.dumps( + list(dict.fromkeys(event_names)) + ), + "slots": json.dumps(list(dict.fromkeys(slot_names))), + } + ) + return validation_events + + return wrapper + + validation_action_class._extract_validation_events = ( # type: ignore + tracing_validation_action_extract_validation_events_wrapper( + validation_action_class._extract_validation_events + ) + ) + + logger.debug( + f"Instrumented '{validation_action_class.__name__}._extract_validation_events'." + )