From c33e11dce2c2ecc8620654b8a8a5c65d406b2e17 Mon Sep 17 00:00:00 2001 From: Minaam Shahid <34722968+minaamshahid@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:27:25 +0500 Subject: [PATCH] chore: linting fixes --- flow_judge/flow_judge.py | 12 +- flow_judge/models/adapters/baseten/adapter.py | 174 +++++++++--------- flow_judge/models/adapters/baseten/data_io.py | 6 +- flow_judge/models/adapters/baseten/deploy.py | 6 +- flow_judge/models/adapters/baseten/errors.py | 9 +- .../models/adapters/baseten/management.py | 82 +++------ .../models/adapters/baseten/token_bucket.py | 3 +- 7 files changed, 138 insertions(+), 154 deletions(-) diff --git a/flow_judge/flow_judge.py b/flow_judge/flow_judge.py index 6a1f3e2..82a1048 100644 --- a/flow_judge/flow_judge.py +++ b/flow_judge/flow_judge.py @@ -1,10 +1,10 @@ import asyncio import logging -from flow_judge.models.adapters.baseten.data_io import BatchResult -from flow_judge.models.adapters.baseten.errors import FlowJudgeError from flow_judge.eval_data_types import EvalInput, EvalOutput from flow_judge.metrics import CustomMetric, Metric +from flow_judge.models.adapters.baseten.data_io import BatchResult +from flow_judge.models.adapters.baseten.errors import FlowJudgeError from flow_judge.models.common import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars from flow_judge.utils.result_writer import write_results_to_disk @@ -122,7 +122,9 @@ def __init__( if not isinstance(model, AsyncBaseFlowJudgeModel): raise ValueError("Invalid model type. Use AsyncBaseFlowJudgeModel or its subclasses.") - async def async_evaluate(self, eval_input: EvalInput, save_results: bool = False) -> EvalOutput | None: + async def async_evaluate( + self, eval_input: EvalInput, save_results: bool = False + ) -> EvalOutput | None: """Evaluate a single EvalInput object asynchronously.""" try: self._validate_inputs(eval_input) @@ -132,9 +134,7 @@ async def async_evaluate(self, eval_input: EvalInput, save_results: bool = False # If there are Baseten errors we log & return here. 
if isinstance(result, FlowJudgeError): - logger.error( - f" {result.error_type}: {result.error_message}" - ) + logger.error(f" {result.error_type}: {result.error_message}") return eval_output = EvalOutput.parse(response) diff --git a/flow_judge/models/adapters/baseten/adapter.py b/flow_judge/models/adapters/baseten/adapter.py index b8f26fe..b594b22 100644 --- a/flow_judge/models/adapters/baseten/adapter.py +++ b/flow_judge/models/adapters/baseten/adapter.py @@ -1,28 +1,38 @@ -import os import asyncio import json import logging +import os import time -from typing import Any, Union +from typing import Any import aiohttp import structlog from openai import OpenAI, OpenAIError -from tenacity import RetryError, retry, stop_after_attempt, wait_exponential, retry_if_exception_type, before_sleep_log -from flow_judge.models.adapters.base import BaseAPIAdapter -from flow_judge.models.adapters.base import AsyncBaseAPIAdapter -from flow_judge.models.adapters.baseten.token_bucket import TokenBucket -from flow_judge.models.adapters.baseten.data_io import Message, BatchResult -from flow_judge.models.adapters.baseten.management import set_scale_down_delay, wake_deployment, get_production_deployment_status -from flow_judge.models.adapters.baseten.validation import validate_baseten_signature +from tenacity import ( + RetryError, + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from flow_judge.models.adapters.base import AsyncBaseAPIAdapter, BaseAPIAdapter +from flow_judge.models.adapters.baseten.data_io import BatchResult, Message from flow_judge.models.adapters.baseten.errors import ( - FlowJudgeError, BasetenAPIError, + BasetenRateLimitError, BasetenRequestError, BasetenResponseError, - BasetenRateLimitError + FlowJudgeError, ) - +from flow_judge.models.adapters.baseten.management import ( + get_production_deployment_status, + set_scale_down_delay, + wake_deployment, +) +from flow_judge.models.adapters.baseten.token_bucket import TokenBucket +from flow_judge.models.adapters.baseten.validation import validate_baseten_signature logger = structlog.get_logger(__name__) @@ -176,7 +186,6 @@ def __init__( except KeyError as e: raise ValueError("BASETEN_API_KEY is not provided in the environment.") from e - async def _check_webhook_health(self) -> bool: """Make an async health request to the webhook before executing generation tasks. @@ -185,11 +194,10 @@ async def _check_webhook_health(self) -> bool: """ try: async with aiohttp.ClientSession() as session: - async with session.get( - f"{self.webhook_proxy_url}/health") as response: + async with session.get(f"{self.webhook_proxy_url}/health") as response: if response.status != 200: - raise BasetenRequestError("Proxy seems in a unhealthy state." - " Aborting Baseten requests." + raise BasetenRequestError( + "Proxy seems in a unhealthy state." " Aborting Baseten requests." ) return True @@ -197,11 +205,9 @@ async def _check_webhook_health(self) -> bool: raise BasetenRequestError(f"Network error while fetching token: {str(e)}") from e except ConnectionError as e: raise BasetenRequestError( - "Unable to connect to the webhook proxy." - " Make sure the correct URL is given." + "Unable to connect to the webhook proxy." " Make sure the correct URL is given." ) from e - async def _make_request(self, request_messages: list[dict[str, Any]]) -> str: """Make an asynchronous request to the Baseten model. 
@@ -249,7 +255,6 @@ async def _make_request(self, request_messages: list[dict[str, Any]]) -> str: except json.JSONDecodeError as e: raise BasetenResponseError(f"Invalid JSON response: {str(e)}") from e - async def _get_stream_token(self, request_id: str) -> str | None: """Retrieve the stream token for a given request ID. @@ -283,7 +288,6 @@ async def _get_stream_token(self, request_id: str) -> str | None: except json.JSONDecodeError as e: raise BasetenResponseError(f"Invalid JSON response for token: {str(e)}") from e - async def _fetch_stream(self, request_id: str) -> str: """Fetch and process the stream from the webhook proxy. @@ -327,14 +331,12 @@ async def _fetch_stream(self, request_id: str) -> str: resp = json.loads(split_chunks[1]) message = resp["data"]["choices"][0]["message"]["content"].strip() - signature = split_chunks[2].replace("\n\n","").split("signature=")[1] + signature = split_chunks[2].replace("\n\n", "").split("signature=")[1] except (json.JSONDecodeError, KeyError, IndexError) as e: logger.warning(f"Failed to parse chunk: {e}") - raise BasetenResponseError( - f"Invalid JSON response: {str(e)}" - ) from e - + raise BasetenResponseError(f"Invalid JSON response: {str(e)}") from e + if "data: eot" in decoded_chunk: break @@ -347,9 +349,7 @@ async def _fetch_stream(self, request_id: str) -> str: return message - async def _process_request_with_retry( - self, message: Message - ) -> Message | FlowJudgeError: + async def _process_request_with_retry(self, message: Message) -> Message | FlowJudgeError: """Process a single request message with retry and exponential backoff. Args: @@ -363,6 +363,7 @@ async def _process_request_with_retry( It will attempt to process the request up to self.max_retries times before giving up and returning a FlowJudgeError. """ + @retry( retry=retry_if_exception_type(BasetenAPIError), stop=stop_after_attempt(self.max_retries), @@ -383,7 +384,6 @@ async def _attempt_request(): return await self._fetch_stream(request_id) - try: response = await _attempt_request() message["response"] = response @@ -401,7 +401,7 @@ async def _attempt_request(): error_message=str(e), request_id=message["id"], ) - + async def _is_model_awake(self) -> bool: """Wake the deployed model. @@ -418,54 +418,52 @@ async def _is_model_awake(self) -> bool: status = await get_production_deployment_status(self.model_id, self.baseten_api_key) if status == "ACTIVE": return - + has_triggered_correctly = await wake_deployment(self.model_id, self.baseten_api_key) if not has_triggered_correctly: raise BasetenAPIError("Trigger to wake the deployed model failed.") - TIMEOUT_SECONDS = 300 + timeout_seconds = 300 # Wait for the initial trigger to switch deployment status await asyncio.sleep(3) - async def has_model_activated(start_time: Union[float, int]): + async def has_model_activated(start_time: float | int): status = await get_production_deployment_status(self.model_id, self.baseten_api_key) if status is None: - logger.warning( - "Unable to detect if model is awake. " - "Continuing anyway." - ) + logger.warning("Unable to detect if model is awake. " "Continuing anyway.") return True if status in ["BUILDING", "DEPLOYING", "LOADING_MODEL", "WAKING_UP", "UPDATING"]: logger.info("The deployed model is waking up.") - if time.time() - start_time >= TIMEOUT_SECONDS: + if time.time() - start_time >= timeout_seconds: raise BasetenAPIError("Model took too long to wake up. 
Stopping execution.") - + await asyncio.sleep(10) return await has_model_activated(start_time) - + if status in ["BUILD_FAILED", "BUILD_STOPPED", "FAILED", "UNHEALTHY"]: - raise BasetenAPIError("Model seems to be in an unhealthy state. Stopping execution.") + raise BasetenAPIError( + "Model seems to be in an unhealthy state. Stopping execution." + ) if status in ["ACTIVE"]: logger.info("The deployed model is active.") return True - + if not await has_model_activated(time.time()): raise BasetenAPIError("Unable to wake up the model.") - - + async def _initialize_state_for_request(self, scale_down_delay: int) -> None: - """Pre-steps for a single/batched request - + """Pre-steps for a single/batched request. + :param scale_down_delay: The delay in seconds to scale down the model. - - Note: + + Note: Activates the model by sendng a "wake-up" request and waits for the model to wake-up. Updates the scale-down delay value, defaults to 120secs for a single request, and 30secs for batched requests. - + Raises: BasetenAPIError for when we are unable to activate the model. """ @@ -473,45 +471,44 @@ async def _initialize_state_for_request(self, scale_down_delay: int) -> None: # Update scale down delay to 30secs for batched requests. is_scaled_down = await set_scale_down_delay( - scale_down_delay=30, - api_key=self.baseten_api_key, - model_id=self.model_id + scale_down_delay=30, api_key=self.baseten_api_key, model_id=self.model_id ) if not is_scaled_down: logger.warning("Unable to reduce scale down delay. Continuing with default.") - async def _async_fetch_response(self, prompt: str) -> Message | FlowJudgeError: + """Single async request to Baseten. + + Args: + prompt: Prompt string for the request. + + Returns: + A message dictionary or an error. + (Message | FlowJudgeError) + """ # Attempt to initialize the model state. try: await self._check_webhook_health() await self._initialize_state_for_request(scale_down_delay=120) except BasetenAPIError as e: return FlowJudgeError( - error_type=type(e).__name__, - error_message=str(e), - request_id=None - ) - + error_type=type(e).__name__, error_message=str(e), request_id=None + ) + result = await self._process_request_with_retry( - Message( - prompt=prompt, - index=1, - id=None, - response=None - )) - + Message(prompt=prompt, index=1, id=None, response=None) + ) + if isinstance(result, FlowJudgeError): return result - - return result["response"] + return result["response"] async def _async_fetch_batched_response(self, prompts: list[str]) -> BatchResult: """Process a batch of evaluation inputs asynchronously. Args: - batch (List[str]): A list of prompts to process. + prompts (List[str]): A list of prompts to process. Returns: BatchResult: An object containing successful outputs and errors. @@ -521,44 +518,43 @@ async def _async_fetch_batched_response(self, prompts: list[str]) -> BatchResult the rate limit. It aggregates results and errors into a BatchResult. """ indexed_prompts = [ - Message(index=i+1, prompt=prompt, id=None, response="") for i, prompt in enumerate(prompts) + Message(index=i + 1, prompt=prompt, id=None, response="") + for i, prompt in enumerate(prompts) ] all_results = [] - # Attempt to initialize the model state. + # Attempt to initialize the model state. 
try: await self._initialize_state_for_request(scale_down_delay=30) except BasetenAPIError as e: return BatchResult( successful_outputs=[], - errors=[FlowJudgeError( - error_type=type(e).__name__, - error_message=str(e), - request_id=None - )], + errors=[ + FlowJudgeError( + error_type=type(e).__name__, error_message=str(e), request_id=None + ) + ], success_rate=0, - total_requests=0 - ) + total_requests=0, + ) for i in range(0, len(indexed_prompts), self.batch_size): - try: await self._check_webhook_health() except BasetenAPIError as e: - all_results.append(FlowJudgeError( - error_type=type(e).__name__, - error_message=str(e), - request_id=None, - )) + all_results.append( + FlowJudgeError( + error_type=type(e).__name__, + error_message=str(e), + request_id=None, + ) + ) break batch = indexed_prompts[i : i + self.batch_size] logger.debug(f"Batch {i}: {batch}") - tasks = [ - self._process_request_with_retry(request_message) - for request_message in batch - ] + tasks = [self._process_request_with_retry(request_message) for request_message in batch] results = await asyncio.gather(*tasks) all_results.extend(results) diff --git a/flow_judge/models/adapters/baseten/data_io.py b/flow_judge/models/adapters/baseten/data_io.py index 102be51..08b1082 100644 --- a/flow_judge/models/adapters/baseten/data_io.py +++ b/flow_judge/models/adapters/baseten/data_io.py @@ -1,7 +1,9 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict + from flow_judge.models.adapters.baseten.errors import FlowJudgeError + class Message(TypedDict): """Represents a single request message for the Baseten API. @@ -14,11 +16,13 @@ class Message(TypedDict): Do not include sensitive information in the 'content' field, as it may be logged or stored for debugging purposes. """ + id: str index: int prompt: str response: str + class BatchResult(BaseModel): """Represents the result of a batch evaluation process. @@ -60,4 +64,4 @@ def check_success_rate_range(cls, v): """Placeholder.""" if not 0 <= v <= 1: raise ValueError("success_rate must be between 0 and 1") - return v \ No newline at end of file + return v diff --git a/flow_judge/models/adapters/baseten/deploy.py b/flow_judge/models/adapters/baseten/deploy.py index bd2f01a..1347a22 100644 --- a/flow_judge/models/adapters/baseten/deploy.py +++ b/flow_judge/models/adapters/baseten/deploy.py @@ -6,6 +6,7 @@ import truss from truss.api.definitions import ModelDeployment from truss.remote.baseten.error import ApiError + from flow_judge.models.adapters.baseten.management import set_scale_down_delay from .api_auth import ensure_baseten_authentication, get_baseten_api_key @@ -54,7 +55,10 @@ def _initialize_model() -> bool: if has_updated_scale_down: logger.info("Successfully updated Baseten deployed model scale down delay to 2 mins.") else: - logger.info("Unable to update Baseten deployed model scale down delay period. Continuing with default") + logger.info( + "Unable to update Baseten deployed model scale down delay period." + " Continuing with default" + ) return True except ApiError as e: logger.error( diff --git a/flow_judge/models/adapters/baseten/errors.py b/flow_judge/models/adapters/baseten/errors.py index d1fcd76..e3d5b0b 100644 --- a/flow_judge/models/adapters/baseten/errors.py +++ b/flow_judge/models/adapters/baseten/errors.py @@ -1,6 +1,8 @@ from datetime import datetime + from pydantic import BaseModel, Field, field_validator + class FlowJudgeError(BaseModel): """Represents an error encountered during the Flow Judge evaluation process. 
@@ -22,7 +24,9 @@ class FlowJudgeError(BaseModel): error_type: str = Field(..., description="Type of the error encountered") error_message: str = Field(..., description="Detailed error message") - request_id: str | None = Field(default=None, description="ID of the request that caused the error") + request_id: str | None = Field( + default=None, description="ID of the request that caused the error" + ) timestamp: datetime = Field( default_factory=datetime.now, description="Time when the error occurred" ) @@ -39,6 +43,7 @@ def check_non_empty_string(cls, v): raise ValueError("Field must not be empty or just whitespace") return v + class BasetenAPIError(Exception): """Base exception for Baseten API errors.""" @@ -60,4 +65,4 @@ class BasetenResponseError(BasetenAPIError): class BasetenRateLimitError(BasetenAPIError): """Exception for rate limit errors.""" - pass \ No newline at end of file + pass diff --git a/flow_judge/models/adapters/baseten/management.py b/flow_judge/models/adapters/baseten/management.py index 0b9117a..987f9be 100644 --- a/flow_judge/models/adapters/baseten/management.py +++ b/flow_judge/models/adapters/baseten/management.py @@ -1,9 +1,11 @@ -import aiohttp import json + +import aiohttp import structlog logger = structlog.get_logger(__name__) + def _get_management_base_url(model_id: str) -> str: """Get the base URL for the Management API. @@ -12,11 +14,8 @@ def _get_management_base_url(model_id: str) -> str: """ return f"https://api.baseten.co/v1/models/{model_id}/deployments/production" -async def set_scale_down_delay( - scale_down_delay: int, - api_key: str, - model_id: str - ) -> bool: + +async def set_scale_down_delay(scale_down_delay: int, api_key: str, model_id: str) -> bool: """Dynamically updates the cooldown period for the deployed model. :param scale_down_delay: The cooldown period in seconds. @@ -30,40 +29,29 @@ async def set_scale_down_delay( async with session.patch( url=url, headers={"Authorization": f"Api-Key {api_key}"}, - json={ - "scale_down_delay": scale_down_delay - } + json={"scale_down_delay": scale_down_delay}, ) as response: if response.status != 200: logger.warning( "Unable to update Baseten scale down delay attribute." f"Request failed with status code {response.status}" - ) + ) return False - + resp = await response.json() if "status" in resp: - return ( - True if resp["status"] - in ["ACCEPTED", "UNCHANGED", "QUEUED"] - else False - ) + return True if resp["status"] in ["ACCEPTED", "UNCHANGED", "QUEUED"] else False except aiohttp.ClientError as e: - logger.warning( - "Network error with Baseten scale_down_delay" - f" {e}" - ) + logger.warning("Network error with Baseten scale_down_delay" f" {e}") return False except Exception as e: - logger.error( - "Unexpected error occurred with Baseten scale down delay request" - f" {e}" - ) + logger.error("Unexpected error occurred with Baseten scale down delay request" f" {e}") return False - + + async def wake_deployment(model_id: str, api_key: str) -> bool: """Activates the Baseten model. - + :param model_id: The ID of the deployed model. :returns: True if success, False if failed. :rtype: bool @@ -72,31 +60,24 @@ async def wake_deployment(model_id: str, api_key: str) -> bool: try: async with aiohttp.ClientSession() as session: async with session.post( - url=url, - headers={"Authorization": f"Api-Key {api_key}"}, - json={} + url=url, headers={"Authorization": f"Api-Key {api_key}"}, json={} ) as response: if response.status != 202: logger.warning( "Unable to activate Baseten model." 
f"Request failed with status code {response.status}" - ) + ) return False - + return True except aiohttp.ClientError as e: - logger.warning( - "Network error with Baseten model activation." - f" {e}" - ) + logger.warning("Network error with Baseten model activation." f" {e}") return False except Exception as e: - logger.error( - "Unexpected error occurred with Baseten model activation." - f" {e}" - ) + logger.error("Unexpected error occurred with Baseten model activation." f" {e}") return False - + + async def get_production_deployment_status(model_id: str, api_key: str) -> str | None: """Get model production deployment_id by it's model_id. @@ -114,26 +95,19 @@ async def get_production_deployment_status(model_id: str, api_key: str) -> str | logger.warning( "Unable to get model deployment details" f"Request failed with status {response.status}" - ) + ) return None - + re = await response.json() return re["status"] except (json.JSONDecodeError, KeyError, IndexError) as e: - logger.warning( - "Unable to parse response for Model deployment info request." - f" {e}" - ) - return None + logger.warning("Unable to parse response for Model deployment info request." f" {e}") + return None except aiohttp.ClientError as e: - logger.warning( - "Network error with Baseten model deployment information." - f" {e}" - ) + logger.warning("Network error with Baseten model deployment information." f" {e}") return None except Exception as e: logger.error( - "Unexpected error occurred with Baseten model deployment info request." - f" {e}" + "Unexpected error occurred with Baseten model deployment info request." f" {e}" ) - return None \ No newline at end of file + return None diff --git a/flow_judge/models/adapters/baseten/token_bucket.py b/flow_judge/models/adapters/baseten/token_bucket.py index b1537ba..efe548d 100644 --- a/flow_judge/models/adapters/baseten/token_bucket.py +++ b/flow_judge/models/adapters/baseten/token_bucket.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass, field + @dataclass class TokenBucket: """Implements a token bucket algorithm for rate limiting. @@ -43,4 +44,4 @@ def consume(self, tokens: int = 1) -> bool: if self.tokens >= tokens: self.tokens -= tokens return True - return False \ No newline at end of file + return False