Commit

chore: linting fixes
minaamshahid authored and sariola committed Oct 21, 2024
1 parent b78dd82 commit c33e11d
Showing 7 changed files with 138 additions and 154 deletions.
12 changes: 6 additions & 6 deletions flow_judge/flow_judge.py
@@ -1,10 +1,10 @@
import asyncio
import logging

from flow_judge.models.adapters.baseten.data_io import BatchResult
from flow_judge.models.adapters.baseten.errors import FlowJudgeError
from flow_judge.eval_data_types import EvalInput, EvalOutput
from flow_judge.metrics import CustomMetric, Metric
from flow_judge.models.adapters.baseten.data_io import BatchResult
from flow_judge.models.adapters.baseten.errors import FlowJudgeError
from flow_judge.models.common import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel
from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars
from flow_judge.utils.result_writer import write_results_to_disk
@@ -122,7 +122,9 @@ def __init__(
if not isinstance(model, AsyncBaseFlowJudgeModel):
raise ValueError("Invalid model type. Use AsyncBaseFlowJudgeModel or its subclasses.")

async def async_evaluate(self, eval_input: EvalInput, save_results: bool = False) -> EvalOutput | None:
async def async_evaluate(
self, eval_input: EvalInput, save_results: bool = False
) -> EvalOutput | None:
"""Evaluate a single EvalInput object asynchronously."""
try:
self._validate_inputs(eval_input)
@@ -132,9 +134,7 @@ async def async_evaluate(self, eval_input: EvalInput, save_results: bool = False

# If there are Baseten errors we log & return here.
if isinstance(result, FlowJudgeError):
logger.error(
f" {result.error_type}: {result.error_message}"
)
logger.error(f" {result.error_type}: {result.error_message}")
return

eval_output = EvalOutput.parse(response)
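
A minimal usage sketch of the reformatted async_evaluate signature above (the judge object here is illustrative and not part of this commit):

from flow_judge.eval_data_types import EvalInput, EvalOutput


async def run_single_eval(judge, eval_input: EvalInput) -> EvalOutput | None:
    # Baseten-side failures are logged inside async_evaluate (see the hunk above)
    # and surface to the caller as None rather than as an exception.
    result = await judge.async_evaluate(eval_input, save_results=False)
    if result is None:
        print("Evaluation failed; check the logged FlowJudgeError for details.")
    return result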
174 changes: 85 additions & 89 deletions flow_judge/models/adapters/baseten/adapter.py
@@ -1,28 +1,38 @@
import os
import asyncio
import json
import logging
import os
import time
from typing import Any, Union
from typing import Any

import aiohttp
import structlog
from openai import OpenAI, OpenAIError
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential, retry_if_exception_type, before_sleep_log
from flow_judge.models.adapters.base import BaseAPIAdapter
from flow_judge.models.adapters.base import AsyncBaseAPIAdapter
from flow_judge.models.adapters.baseten.token_bucket import TokenBucket
from flow_judge.models.adapters.baseten.data_io import Message, BatchResult
from flow_judge.models.adapters.baseten.management import set_scale_down_delay, wake_deployment, get_production_deployment_status
from flow_judge.models.adapters.baseten.validation import validate_baseten_signature
from tenacity import (
RetryError,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from flow_judge.models.adapters.base import AsyncBaseAPIAdapter, BaseAPIAdapter
from flow_judge.models.adapters.baseten.data_io import BatchResult, Message
from flow_judge.models.adapters.baseten.errors import (
FlowJudgeError,
BasetenAPIError,
BasetenRateLimitError,
BasetenRequestError,
BasetenResponseError,
BasetenRateLimitError
FlowJudgeError,
)

from flow_judge.models.adapters.baseten.management import (
get_production_deployment_status,
set_scale_down_delay,
wake_deployment,
)
from flow_judge.models.adapters.baseten.token_bucket import TokenBucket
from flow_judge.models.adapters.baseten.validation import validate_baseten_signature

logger = structlog.get_logger(__name__)

@@ -176,7 +186,6 @@ def __init__(
except KeyError as e:
raise ValueError("BASETEN_API_KEY is not provided in the environment.") from e


async def _check_webhook_health(self) -> bool:
"""Make an async health request to the webhook before executing generation tasks.
@@ -185,23 +194,20 @@ async def _check_webhook_health(self) -> bool:
"""
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.webhook_proxy_url}/health") as response:
async with session.get(f"{self.webhook_proxy_url}/health") as response:
if response.status != 200:
raise BasetenRequestError("Proxy seems in a unhealthy state."
" Aborting Baseten requests."
raise BasetenRequestError(
"Proxy seems in a unhealthy state." " Aborting Baseten requests."
)
return True

except aiohttp.ClientError as e:
raise BasetenRequestError(f"Network error while fetching token: {str(e)}") from e
except ConnectionError as e:
raise BasetenRequestError(
"Unable to connect to the webhook proxy."
" Make sure the correct URL is given."
"Unable to connect to the webhook proxy." " Make sure the correct URL is given."
) from e


async def _make_request(self, request_messages: list[dict[str, Any]]) -> str:
"""Make an asynchronous request to the Baseten model.
@@ -249,7 +255,6 @@ async def _make_request(self, request_messages: list[dict[str, Any]]) -> str:
except json.JSONDecodeError as e:
raise BasetenResponseError(f"Invalid JSON response: {str(e)}") from e


async def _get_stream_token(self, request_id: str) -> str | None:
"""Retrieve the stream token for a given request ID.
@@ -283,7 +288,6 @@ async def _get_stream_token(self, request_id: str) -> str | None:
except json.JSONDecodeError as e:
raise BasetenResponseError(f"Invalid JSON response for token: {str(e)}") from e


async def _fetch_stream(self, request_id: str) -> str:
"""Fetch and process the stream from the webhook proxy.
@@ -327,14 +331,12 @@ async def _fetch_stream(self, request_id: str) -> str:
resp = json.loads(split_chunks[1])

message = resp["data"]["choices"][0]["message"]["content"].strip()
signature = split_chunks[2].replace("\n\n","").split("signature=")[1]
signature = split_chunks[2].replace("\n\n", "").split("signature=")[1]

except (json.JSONDecodeError, KeyError, IndexError) as e:
logger.warning(f"Failed to parse chunk: {e}")
raise BasetenResponseError(
f"Invalid JSON response: {str(e)}"
) from e

raise BasetenResponseError(f"Invalid JSON response: {str(e)}") from e

if "data: eot" in decoded_chunk:
break

@@ -347,9 +349,7 @@ async def _fetch_stream(self, request_id: str) -> str:

return message

async def _process_request_with_retry(
self, message: Message
) -> Message | FlowJudgeError:
async def _process_request_with_retry(self, message: Message) -> Message | FlowJudgeError:
"""Process a single request message with retry and exponential backoff.
Args:
@@ -363,6 +363,7 @@ async def _process_request_with_retry(
It will attempt to process the request up to self.max_retries times
before giving up and returning a FlowJudgeError.
"""

@retry(
retry=retry_if_exception_type(BasetenAPIError),
stop=stop_after_attempt(self.max_retries),
@@ -383,7 +384,6 @@ async def _attempt_request():

return await self._fetch_stream(request_id)


try:
response = await _attempt_request()
message["response"] = response
@@ -401,7 +401,7 @@ async def _attempt_request():
error_message=str(e),
request_id=message["id"],
)

async def _is_model_awake(self) -> bool:
"""Wake the deployed model.
@@ -418,100 +418,97 @@ async def _is_model_awake(self) -> bool:
status = await get_production_deployment_status(self.model_id, self.baseten_api_key)
if status == "ACTIVE":
return

has_triggered_correctly = await wake_deployment(self.model_id, self.baseten_api_key)
if not has_triggered_correctly:
raise BasetenAPIError("Trigger to wake the deployed model failed.")

TIMEOUT_SECONDS = 300
timeout_seconds = 300

# Wait for the initial trigger to switch deployment status
await asyncio.sleep(3)

async def has_model_activated(start_time: Union[float, int]):
async def has_model_activated(start_time: float | int):
status = await get_production_deployment_status(self.model_id, self.baseten_api_key)
if status is None:
logger.warning(
"Unable to detect if model is awake. "
"Continuing anyway."
)
logger.warning("Unable to detect if model is awake. " "Continuing anyway.")
return True
if status in ["BUILDING", "DEPLOYING", "LOADING_MODEL", "WAKING_UP", "UPDATING"]:
logger.info("The deployed model is waking up.")

if time.time() - start_time >= TIMEOUT_SECONDS:
if time.time() - start_time >= timeout_seconds:
raise BasetenAPIError("Model took too long to wake up. Stopping execution.")

await asyncio.sleep(10)
return await has_model_activated(start_time)

if status in ["BUILD_FAILED", "BUILD_STOPPED", "FAILED", "UNHEALTHY"]:
raise BasetenAPIError("Model seems to be in an unhealthy state. Stopping execution.")
raise BasetenAPIError(
"Model seems to be in an unhealthy state. Stopping execution."
)

if status in ["ACTIVE"]:
logger.info("The deployed model is active.")
return True

if not await has_model_activated(time.time()):
raise BasetenAPIError("Unable to wake up the model.")



async def _initialize_state_for_request(self, scale_down_delay: int) -> None:
"""Pre-steps for a single/batched request
"""Pre-steps for a single/batched request.
:param scale_down_delay: The delay in seconds to scale down the model.
Note:
Note:
Activates the model by sendng a "wake-up" request and waits for
the model to wake-up. Updates the scale-down delay value, defaults to
120secs for a single request, and 30secs for batched requests.
Raises:
BasetenAPIError for when we are unable to activate the model.
"""
await self._is_model_awake()

# Update scale down delay to 30secs for batched requests.
is_scaled_down = await set_scale_down_delay(
scale_down_delay=30,
api_key=self.baseten_api_key,
model_id=self.model_id
scale_down_delay=30, api_key=self.baseten_api_key, model_id=self.model_id
)
if not is_scaled_down:
logger.warning("Unable to reduce scale down delay. Continuing with default.")


async def _async_fetch_response(self, prompt: str) -> Message | FlowJudgeError:
"""Single async request to Baseten.
Args:
prompt: Prompt string for the request.
Returns:
A message dictionary or an error.
(Message | FlowJudgeError)
"""
# Attempt to initialize the model state.
try:
await self._check_webhook_health()
await self._initialize_state_for_request(scale_down_delay=120)
except BasetenAPIError as e:
return FlowJudgeError(
error_type=type(e).__name__,
error_message=str(e),
request_id=None
)

error_type=type(e).__name__, error_message=str(e), request_id=None
)

result = await self._process_request_with_retry(
Message(
prompt=prompt,
index=1,
id=None,
response=None
))

Message(prompt=prompt, index=1, id=None, response=None)
)

if isinstance(result, FlowJudgeError):
return result

return result["response"]

return result["response"]

async def _async_fetch_batched_response(self, prompts: list[str]) -> BatchResult:
"""Process a batch of evaluation inputs asynchronously.
Args:
batch (List[str]): A list of prompts to process.
prompts (List[str]): A list of prompts to process.
Returns:
BatchResult: An object containing successful outputs and errors.
Expand All @@ -521,44 +518,43 @@ async def _async_fetch_batched_response(self, prompts: list[str]) -> BatchResult
the rate limit. It aggregates results and errors into a BatchResult.
"""
indexed_prompts = [
Message(index=i+1, prompt=prompt, id=None, response="") for i, prompt in enumerate(prompts)
Message(index=i + 1, prompt=prompt, id=None, response="")
for i, prompt in enumerate(prompts)
]

all_results = []

# Attempt to initialize the model state.
# Attempt to initialize the model state.
try:
await self._initialize_state_for_request(scale_down_delay=30)
except BasetenAPIError as e:
return BatchResult(
successful_outputs=[],
errors=[FlowJudgeError(
error_type=type(e).__name__,
error_message=str(e),
request_id=None
)],
errors=[
FlowJudgeError(
error_type=type(e).__name__, error_message=str(e), request_id=None
)
],
success_rate=0,
total_requests=0
)
total_requests=0,
)

for i in range(0, len(indexed_prompts), self.batch_size):

try:
await self._check_webhook_health()
except BasetenAPIError as e:
all_results.append(FlowJudgeError(
error_type=type(e).__name__,
error_message=str(e),
request_id=None,
))
all_results.append(
FlowJudgeError(
error_type=type(e).__name__,
error_message=str(e),
request_id=None,
)
)
break

batch = indexed_prompts[i : i + self.batch_size]
logger.debug(f"Batch {i}: {batch}")
tasks = [
self._process_request_with_retry(request_message)
for request_message in batch
]
tasks = [self._process_request_with_retry(request_message) for request_message in batch]

results = await asyncio.gather(*tasks)
all_results.extend(results)
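
As the _async_fetch_batched_response docstring describes, prompts are sliced by batch_size and each slice is awaited concurrently with asyncio.gather. A minimal self-contained sketch of that slicing pattern (illustrative names, not the adapter's code):

import asyncio


async def process_one(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for the per-request round trip
    return f"response for: {prompt}"


async def process_in_batches(prompts: list[str], batch_size: int) -> list[str]:
    results: list[str] = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i : i + batch_size]
        # Each slice is awaited as a unit, which caps in-flight requests at batch_size.
        results.extend(await asyncio.gather(*(process_one(p) for p in batch)))
    return results


if __name__ == "__main__":
    print(asyncio.run(process_in_batches([f"p{i}" for i in range(5)], batch_size=2)))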
