Skip to content

Commit

Permalink
api/view: improve the collect of the Schema1 events
Browse files Browse the repository at this point in the history
Harmonize how we collect we prepare the Schema1 events and how we
do the payload validation and the exception handling.
  • Loading branch information
goneri committed Jun 28, 2024
1 parent 176455d commit 7327124
Show file tree
Hide file tree
Showing 13 changed files with 733 additions and 602 deletions.
7 changes: 0 additions & 7 deletions ansible_ai_connect/ai/api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,6 @@ class InternalServerError(BaseWisdomAPIException):
default_detail = "An error occurred attempting to complete the request."


class FeedbackValidationException(WisdomBadRequest):
default_code = "error__feedback_validation"

def __init__(self, detail, *args, **kwargs):
super().__init__(detail, *args, **kwargs)


class FeedbackInternalServerException(BaseWisdomAPIException):
status_code = 500
default_code = "error__feedback_internal_server"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
class DeserializeStage(PipelineElement):
def process(self, context: CompletionContext) -> None:
request = context.request
# NOTE: This line is probably useless
request._request._suggestion_id = request.data.get("suggestionId")

request_serializer = CompletionRequestSerializer(
Expand All @@ -38,7 +39,9 @@ def process(self, context: CompletionContext) -> None:

try:
request_serializer.is_valid(raise_exception=True)
print(request_serializer.validated_data)
request._request._suggestion_id = str(request_serializer.validated_data["suggestionId"])
request._request._model: str = request_serializer.validated_data.get("model", "")
request._request._ansible_extension_version = str(
request_serializer.validated_data.get("metadata", {}).get(
"ansibleExtensionVersion", None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from prometheus_client import Histogram
from yaml.error import MarkedYAMLError

import ansible_ai_connect.ai.api.telemetry.schema1 as schema1
from ansible_ai_connect.ai.api import formatter as fmtr
from ansible_ai_connect.ai.api.exceptions import (
PostprocessException,
process_error_count,
)
from ansible_ai_connect.ai.api.pipelines.common import PipelineElement
from ansible_ai_connect.ai.api.pipelines.completion_context import CompletionContext
from ansible_ai_connect.ai.api.utils.segment import send_segment_event
from ansible_ai_connect.ai.api.utils.segment import send_schema1_event

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,32 +79,29 @@ def write_to_segment(
if isinstance(exception, MarkedYAMLError)
else str(exception) if str(exception) else exception.__class__.__name__
)

if event_type == "ARI":
event_name = "postprocess"
event = {
"exception": exception is not None,
"problem": problem,
"duration": duration,
"recommendation": recommendation_yaml,
"truncated": truncated_yaml,
"postprocessed": postprocessed_yaml,
"details": postprocess_detail,
"suggestionId": str(suggestion_id) if suggestion_id else None,
}
if event_type == "ansible-lint":
event_name = "postprocessLint"
event = {
"exception": exception is not None,
"problem": problem,
"duration": duration,
"recommendation": recommendation_yaml,
"postprocessed": postprocessed_yaml,
"suggestionId": str(suggestion_id) if suggestion_id else None,
}
schema1_event = schema1.Postprocess()
schema1_event.details = postprocess_detail
schema1_event.truncated = truncated_yaml

elif event_type == "ansible-lint":

schema1_event = schema1.PostprocessLint()
schema1_event.postprocessed = postprocessed_yaml
schema1_event.problem = problem

schema1_event.set_user(user)
schema1_event.set_exception(exception)
schema1_event.duration = duration
schema1_event.postprocessed = postprocessed_yaml
schema1_event.problem = problem
schema1_event.recommendation = recommendation_yaml
schema1_event.suggestionId = str(suggestion_id) if suggestion_id else ""

if model_id:
event["modelName"] = model_id
send_segment_event(event, event_name, user)
schema1_event.modelName = model_id
send_schema1_event(schema1_event)


def trim_whitespace_lines(input: str):
Expand Down Expand Up @@ -358,6 +356,8 @@ def completion_post_process(context: CompletionContext):
logger.debug(f"suggestion id: {suggestion_id}, indented recommendation: \n{indented_yaml}")

# gather data for completion segment event
# WARNING: the block below do inplace transformation of 'tasks', we should refact the
# code to avoid that.
for i, task in enumerate(tasks):
if fmtr.is_multi_task_prompt(prompt):
task["prediction"] = fmtr.extract_task(
Expand Down
19 changes: 3 additions & 16 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@


class Metadata(serializers.Serializer):
class Meta:
fields = ["ansibleExtensionVersion"]

ansibleExtensionVersion = serializers.RegexField(
r"v?\d+\.\d+\.\d+",
required=False,
Expand Down Expand Up @@ -111,14 +108,8 @@ class CompletionRequestSerializer(Metadata):
model = serializers.CharField(required=False, allow_blank=True)

@staticmethod
def validate_extracted_prompt(prompt, user):
def validate_extracted_prompt(prompt):
if fmtr.is_multi_task_prompt(prompt):
# Multi-task is commercial-only
if user.rh_user_has_seat is False:
raise serializers.ValidationError(
{"prompt": "requested prompt format is not supported"}
)

if "&&" in prompt:
raise serializers.ValidationError(
{"prompt": "multiple task requests should be separated by a single '&'"}
Expand Down Expand Up @@ -146,18 +137,14 @@ def validate_extracted_prompt(prompt, user):
raise serializers.ValidationError({"prompt": "prompt contains a dictionary"})

def validate_model(self, value):
user = self.context.get("request").user
if settings.ANSIBLE_AI_ENABLE_TECH_PREVIEW and user.rh_user_has_seat is False:
raise serializers.ValidationError("user is not entitled to customized model")
return value

def validate(self, data):
data = super().validate(data)

data["prompt"], data["context"] = fmtr.extract_prompt_and_context(data["prompt"])
CompletionRequestSerializer.validate_extracted_prompt(
data["prompt"], self.context.get("request").user
)

CompletionRequestSerializer.validate_extracted_prompt(data["prompt"])

# If suggestion ID was not included in the request, set a random UUID to it.
if data.get("suggestionId") is None:
Expand Down
Loading

0 comments on commit 7327124

Please sign in to comment.