Skip to content

Commit

Permalink
api/view: improve the collection of the Schema1 events
Browse files Browse the repository at this point in the history
Harmonize how we collect and prepare the Schema1 events and how we
do the payload validation and the exception handling.
  • Loading branch information
goneri committed Jun 28, 2024
1 parent 176455d commit 7585c21
Show file tree
Hide file tree
Showing 19 changed files with 792 additions and 742 deletions.
1 change: 0 additions & 1 deletion ansible_ai_connect/ai/api/data/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
class APIPayload(BaseModel):
model: str = ""
prompt: str = ""
original_prompt: str = ""
context: str = ""
userId: Optional[UUID] = None
suggestionId: Optional[UUID] = None
Expand Down
7 changes: 0 additions & 7 deletions ansible_ai_connect/ai/api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,6 @@ class InternalServerError(BaseWisdomAPIException):
default_detail = "An error occurred attempting to complete the request."


class FeedbackValidationException(WisdomBadRequest):
    """Bad-request error raised when a feedback payload fails validation.

    Presumably raised from the feedback serializer/view — confirm against
    callers (not visible here). The ``__init__`` override that previously
    existed only forwarded all arguments to ``super().__init__`` unchanged,
    so it was redundant and has been removed; construction behavior is
    identical.
    """

    default_code = "error__feedback_validation"


class FeedbackInternalServerException(BaseWisdomAPIException):
status_code = 500
default_code = "error__feedback_internal_server"
Expand Down
19 changes: 0 additions & 19 deletions ansible_ai_connect/ai/api/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,25 +368,6 @@ def get_task_names_from_tasks(tasks):
return names


def restore_original_task_names(output_yaml, prompt):
    """Rewrite the ``- name:`` lines of *output_yaml* using the task names
    found in a multi-task *prompt*.

    The i-th ``- name:`` line in the suggestion YAML gets its name replaced
    by the i-th task name extracted from the prompt. If the YAML contains
    more ``- name:`` lines than the prompt has tasks, an error is logged and
    the remaining lines are left untouched. Returns *output_yaml* unchanged
    when it is empty or when *prompt* is not a multi-task prompt.
    """
    if not output_yaml or not is_multi_task_prompt(prompt):
        return output_yaml

    prompt_tasks = get_task_names_from_prompt(prompt)
    name_lines = re.finditer(r"^- name:\s+(.*)", output_yaml, re.M)
    for index, found in enumerate(name_lines):
        if index >= len(prompt_tasks):
            # More generated tasks than prompt tasks: nothing to restore from.
            logger.error(
                "There is no match for the enumerated prompt task in the suggestion yaml"
            )
            break
        whole_line = found.group(0)
        generated_name = found.group(1)
        restored_line = whole_line.replace(generated_name, prompt_tasks[index])
        output_yaml = output_yaml.replace(whole_line, restored_line)

    return output_yaml


# List of Task keywords to filter out during prediction results parsing.
ansible_task_keywords = None
# RegExp Pattern based on ARI sources, see ansible_risk_insight/finder.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@
class DeserializeStage(PipelineElement):
def process(self, context: CompletionContext) -> None:
request = context.request
# NOTE: This line is probably useless
request._request._suggestion_id = request.data.get("suggestionId")

request_serializer = CompletionRequestSerializer(
data=request.data, context={"request": request}
)

try:
# TODO: is_valid() is already called in ai/api/views.py and we should
# reuse the validated_data here
request_serializer.is_valid(raise_exception=True)
request._request._suggestion_id = str(request_serializer.validated_data["suggestionId"])
request._request._model: str = request_serializer.validated_data.get("model", "")
request._request._ansible_extension_version = str(
request_serializer.validated_data.get("metadata", {}).get(
"ansibleExtensionVersion", None
Expand All @@ -62,6 +66,5 @@ def process(self, context: CompletionContext) -> None:
)

payload = APIPayload(**request_serializer.validated_data)
payload.original_prompt = request.data.get("prompt", "")

context.payload = payload
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from prometheus_client import Histogram
from yaml.error import MarkedYAMLError

import ansible_ai_connect.ai.api.telemetry.schema1 as schema1
from ansible_ai_connect.ai.api import formatter as fmtr
from ansible_ai_connect.ai.api.exceptions import (
PostprocessException,
process_error_count,
)
from ansible_ai_connect.ai.api.pipelines.common import PipelineElement
from ansible_ai_connect.ai.api.pipelines.completion_context import CompletionContext
from ansible_ai_connect.ai.api.utils.segment import send_segment_event
from ansible_ai_connect.ai.api.utils.segment import send_schema1_event

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,32 +79,29 @@ def write_to_segment(
if isinstance(exception, MarkedYAMLError)
else str(exception) if str(exception) else exception.__class__.__name__
)

if event_type == "ARI":
event_name = "postprocess"
event = {
"exception": exception is not None,
"problem": problem,
"duration": duration,
"recommendation": recommendation_yaml,
"truncated": truncated_yaml,
"postprocessed": postprocessed_yaml,
"details": postprocess_detail,
"suggestionId": str(suggestion_id) if suggestion_id else None,
}
if event_type == "ansible-lint":
event_name = "postprocessLint"
event = {
"exception": exception is not None,
"problem": problem,
"duration": duration,
"recommendation": recommendation_yaml,
"postprocessed": postprocessed_yaml,
"suggestionId": str(suggestion_id) if suggestion_id else None,
}
schema1_event = schema1.Postprocess()
schema1_event.details = postprocess_detail
schema1_event.truncated = truncated_yaml

elif event_type == "ansible-lint":

schema1_event = schema1.PostprocessLint()
schema1_event.postprocessed = postprocessed_yaml
schema1_event.problem = problem

schema1_event.set_user(user)
schema1_event.set_exception(exception)
schema1_event.duration = duration
schema1_event.postprocessed = postprocessed_yaml
schema1_event.problem = problem
schema1_event.recommendation = recommendation_yaml
schema1_event.suggestionId = str(suggestion_id) if suggestion_id else ""

if model_id:
event["modelName"] = model_id
send_segment_event(event, event_name, user)
schema1_event.modelName = model_id
send_schema1_event(schema1_event)


def trim_whitespace_lines(input: str):
Expand Down Expand Up @@ -139,11 +137,10 @@ def completion_post_process(context: CompletionContext):
model_id = context.model_id
suggestion_id = context.payload.suggestionId
prompt = context.payload.prompt
original_prompt = context.payload.original_prompt
payload_context = context.payload.context
original_indent = context.original_indent
post_processed_predictions = context.anonymized_predictions.copy()
is_multi_task_prompt = fmtr.is_multi_task_prompt(original_prompt)
is_multi_task_prompt = fmtr.is_multi_task_prompt(prompt)

ari_caller = apps.get_app_config("ai").get_ari_caller()
if not ari_caller:
Expand All @@ -159,16 +156,11 @@ def completion_post_process(context: CompletionContext):
f"unexpected predictions array length {len(post_processed_predictions['predictions'])}"
)

anonymized_recommendation_yaml = post_processed_predictions["predictions"][0]
recommendation_yaml = post_processed_predictions["predictions"][0]

if not anonymized_recommendation_yaml:
raise PostprocessException(
f"unexpected prediction content {anonymized_recommendation_yaml}"
)
if not recommendation_yaml:
raise PostprocessException(f"unexpected prediction content {recommendation_yaml}")

recommendation_yaml = fmtr.restore_original_task_names(
anonymized_recommendation_yaml, original_prompt
)
recommendation_problem = None
truncated_yaml = None
postprocessed_yaml = None
Expand Down Expand Up @@ -213,7 +205,7 @@ def completion_post_process(context: CompletionContext):
f"original recommendation: \n{recommendation_yaml}"
)
postprocessed_yaml, ari_results = ari_caller.postprocess(
recommendation_yaml, original_prompt, payload_context
recommendation_yaml, prompt, payload_context
)
logger.debug(
f"suggestion id: {suggestion_id}, "
Expand Down Expand Up @@ -259,7 +251,7 @@ def completion_post_process(context: CompletionContext):
f"rules_with_applied_changes: {tasks_with_applied_changes} "
f"recommendation_yaml: [{repr(recommendation_yaml)}] "
f"postprocessed_yaml: [{repr(postprocessed_yaml)}] "
f"original_prompt: [{repr(original_prompt)}] "
f"prompt: [{repr(prompt)}] "
f"payload_context: [{repr(payload_context)}] "
f"postprocess_details: [{json.dumps(postprocess_details)}] "
)
Expand All @@ -279,7 +271,7 @@ def completion_post_process(context: CompletionContext):
write_to_segment(
user,
suggestion_id,
anonymized_recommendation_yaml,
recommendation_yaml,
truncated_yaml,
postprocessed_yaml,
postprocess_details,
Expand All @@ -298,9 +290,7 @@ def completion_post_process(context: CompletionContext):
input_yaml = postprocessed_yaml if postprocessed_yaml else recommendation_yaml
# Single task predictions are missing the `- name: ` line and fail linter schema check
if not is_multi_task_prompt:
input_yaml = (
f"{original_prompt.lstrip() if ari_caller else original_prompt}{input_yaml}"
)
input_yaml = f"{prompt.lstrip() if ari_caller else prompt}{input_yaml}"
postprocessed_yaml = ansible_lint_caller.run_linter(input_yaml)
# Stripping the leading STRIP_YAML_LINE that was added by above processing
if postprocessed_yaml.startswith(STRIP_YAML_LINE):
Expand All @@ -318,7 +308,7 @@ def completion_post_process(context: CompletionContext):
)
finally:
anonymized_input_yaml = (
postprocessed_yaml if postprocessed_yaml else anonymized_recommendation_yaml
postprocessed_yaml if postprocessed_yaml else recommendation_yaml
)
write_to_segment(
user,
Expand Down Expand Up @@ -358,6 +348,8 @@ def completion_post_process(context: CompletionContext):
logger.debug(f"suggestion id: {suggestion_id}, indented recommendation: \n{indented_yaml}")

# gather data for completion segment event
# WARNING: the block below does an in-place transformation of 'tasks'; we should
# refactor the code to avoid that.
for i, task in enumerate(tasks):
if fmtr.is_multi_task_prompt(prompt):
task["prediction"] = fmtr.extract_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

def completion_pre_process(context: CompletionContext):
prompt = context.payload.prompt
original_prompt, _ = fmtr.extract_prompt_and_context(context.payload.original_prompt)
payload_context = context.payload.context

# Additional context (variables) is supported when
Expand Down Expand Up @@ -70,17 +69,6 @@ def completion_pre_process(context: CompletionContext):
context.payload.context, context.payload.prompt = fmtr.preprocess(
payload_context, prompt, ansibleFileType, additionalContext
)
if not multi_task:
# We are currently more forgiving on leading spacing of single task
# prompts than multi task prompts. In order to use the "original"
# single task prompt successfull in post-processing, we need to
# ensure its spacing aligns with the normalized context we got
# back from preprocess. We can calculate the proper spacing from the
# normalized prompt.
normalized_indent = len(context.payload.prompt) - len(context.payload.prompt.lstrip())
normalized_original_prompt = fmtr.normalize_yaml(original_prompt)
original_prompt = " " * normalized_indent + normalized_original_prompt
context.payload.original_prompt = original_prompt


class PreProcessStage(PipelineElement):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ def add_indents(vars, n):
@modify_settings()
class CompletionPreProcessTest(TestCase):
def call_completion_pre_process(self, payload, is_commercial_user, expected_context):
original_prompt = payload.get("prompt")
user = Mock(rh_user_has_seat=is_commercial_user)
request = Mock(user=user)
serializer = CompletionRequestSerializer(context={"request": request})
Expand All @@ -529,7 +528,6 @@ def call_completion_pre_process(self, payload, is_commercial_user, expected_cont
request=request,
payload=APIPayload(
prompt=data.get("prompt"),
original_prompt=original_prompt,
context=data.get("context"),
),
metadata=data.get("metadata"),
Expand Down
5 changes: 1 addition & 4 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@


class Metadata(serializers.Serializer):
class Meta:
fields = ["ansibleExtensionVersion"]

ansibleExtensionVersion = serializers.RegexField(
r"v?\d+\.\d+\.\d+",
required=False,
Expand Down Expand Up @@ -118,7 +115,6 @@ def validate_extracted_prompt(prompt, user):
raise serializers.ValidationError(
{"prompt": "requested prompt format is not supported"}
)

if "&&" in prompt:
raise serializers.ValidationError(
{"prompt": "multiple task requests should be separated by a single '&'"}
Expand Down Expand Up @@ -155,6 +151,7 @@ def validate(self, data):
data = super().validate(data)

data["prompt"], data["context"] = fmtr.extract_prompt_and_context(data["prompt"])

CompletionRequestSerializer.validate_extracted_prompt(
data["prompt"], self.context.get("request").user
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from rest_framework.response import Response
from rest_framework.status import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

from ansible_ai_connect.ai.api.exceptions import InternalServerError, ServiceUnavailable
from ansible_ai_connect.ai.api.permissions import (
IsOrganisationAdministrator,
IsOrganisationLightspeedSubscriber,
)
from ansible_ai_connect.ai.api.serializers import TelemetrySettingsRequestSerializer
from ansible_ai_connect.ai.api.utils.segment import send_segment_event
from ansible_ai_connect.ai.api.views import InternalServerError, ServiceUnavailable
from ansible_ai_connect.users.signals import user_set_telemetry_settings

logger = logging.getLogger(__name__)
Expand Down
Loading

0 comments on commit 7585c21

Please sign in to comment.