Skip to content

Commit

Permalink
feat: add cost estimate endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
pakelley committed Oct 7, 2024
1 parent 952bf76 commit 5868a19
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aiokafka.errors import UnknownTopicOrPartitionError
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import litellm
from pydantic import BaseModel, SerializeAsAny, field_validator
from redis import Redis
import time
Expand Down Expand Up @@ -76,6 +77,19 @@ class BatchSubmitted(BaseModel):
job_id: str


class CostEstimate(BaseModel):
    """Estimated cost, in USD, split into prompt and completion parts.

    The ``/estimate-cost`` handler sets all three fields to ``None`` when the
    estimate could not be computed.
    """

    # Cost attributed to prompt (input) tokens, in USD.
    prompt_cost_usd: Optional[float]
    # Cost attributed to completion (output) tokens, in USD.
    completion_cost_usd: Optional[float]
    # Sum of the prompt and completion costs, in USD.
    total_cost_usd: Optional[float]


class CostEstimateRequest(BaseModel):
    """Request body for the ``/estimate-cost`` endpoint."""

    # Template string; rendered once per substitution with ``str.format``.
    prompt: str
    # One mapping of template keyword arguments per prompt to be priced.
    substitutions: List[Dict]
    # Model name handed to litellm for token counting and model lookup.
    model: str
    # Output field names; the visible estimators only use how many there are.
    output_fields: List[str]


class Status(Enum):
PENDING = "Pending"
INPROGRESS = "InProgress"
Expand Down Expand Up @@ -214,6 +228,88 @@ async def submit_batch(batch: BatchData):
return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


def get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
    """Estimate the prompt-side token count for one rendered user prompt.

    Args:
        string: The rendered user prompt text.
        model: Model name passed to ``litellm.token_counter``.
        output_fields: Expected output field names; only their count is used.

    Returns:
        Tokens counted for ``string`` plus a hard-coded allowance for the
        system / function-call part of the request.
    """
    text_token_count = litellm.token_counter(model=model, text=string)
    # FIXME: function-call tokens are surprisingly hard to measure and add
    # little value, so the allowance is hard-coded until something like
    # litellm can compute it. Today that would mean scraping the instructor
    # logs for the function-call payload and then, at best, feeding it to an
    # OpenAI-specific third-party library for a token estimate.
    per_field_allowance = 6 * len(output_fields)
    overhead_token_count = 56 + per_field_allowance
    return text_token_count + overhead_token_count


def get_completion_tokens(model: str, output_fields: List[str]) -> int:
    """Estimate how many completion tokens a single response will use.

    Args:
        model: Model name looked up through ``litellm.get_model_info``
            (always with ``custom_llm_provider="openai"``).
        output_fields: Expected output field names; only their count is used.

    Returns:
        ``4 * len(output_fields)``, capped at the model's ``max_tokens``.

    Raises:
        ValueError: If litellm reports no (or a zero) ``max_tokens`` for
            ``model``.
    """
    model_info = litellm.get_model_info(model=model, custom_llm_provider="openai")
    max_tokens = model_info.get("max_tokens")
    if not max_tokens:
        # A message is required here: the caller logs the exception text, and
        # a bare ValueError would leave that log line empty.
        raise ValueError(
            f"litellm reports no max_tokens for model {model!r}; "
            "cannot estimate completion tokens"
        )
    # Extremely rough heuristic, from testing on some anecdotal examples.
    return min(max_tokens, 4 * len(output_fields))


def _estimate_cost(user_prompt: str, model: str, output_fields: List[str]):
    """Estimate the USD cost of a single rendered prompt.

    Args:
        user_prompt: The rendered user prompt text.
        model: Model name used for token counting and for pricing.
        output_fields: Expected output field names, used by the token
            heuristics.

    Returns:
        A ``(prompt_cost, completion_cost, total_cost)`` tuple in USD, where
        ``total_cost`` is the sum of the first two.
    """
    prompt_tokens = get_prompt_tokens(user_prompt, model, output_fields)
    completion_tokens = get_completion_tokens(model, output_fields)
    # Price with the requested model. This was previously hard-coded to
    # "gpt-3.5-turbo", so every model was billed at that model's rates.
    prompt_cost, completion_cost = litellm.cost_per_token(
        model=model,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    total_cost = prompt_cost + completion_cost

    return prompt_cost, completion_cost, total_cost


@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
    request: CostEstimateRequest,
):
    """
    Estimates the cost, in USD, of running a prompt over a set of inputs.
    The prompt template is rendered once per entry in `substitutions`; each
    rendered prompt is priced on its own and the costs are summed.
    Args:
        request (CostEstimateRequest): Prompt template, substitutions, model
            name and output fields to price.
    Returns:
        Response[CostEstimate]: Cumulative prompt, completion and total cost.
            All three fields are None if the estimate fails for any reason
            (the error is logged rather than raised).
    """
    try:
        # Render every prompt up front so a bad substitution (e.g. a missing
        # key) fails before any token counting is done.
        user_prompts = [
            request.prompt.format(**substitution)
            for substitution in request.substitutions
        ]
        cumulative_prompt_cost = 0
        cumulative_completion_cost = 0
        cumulative_total_cost = 0
        for user_prompt in user_prompts:
            prompt_cost, completion_cost, total_cost = _estimate_cost(
                user_prompt=user_prompt,
                model=request.model,
                output_fields=request.output_fields,
            )
            cumulative_prompt_cost += prompt_cost
            cumulative_completion_cost += completion_cost
            cumulative_total_cost += total_cost
        return Response[CostEstimate](
            data=CostEstimate(
                prompt_cost_usd=cumulative_prompt_cost,
                completion_cost_usd=cumulative_completion_cost,
                total_cost_usd=cumulative_total_cost,
            )
        )

    except Exception as e:
        # Deliberately best-effort: a failed estimate is reported as null
        # costs instead of an HTTP error.
        logger.error("Failed to estimate cost: %s", e)
        return Response[CostEstimate](
            data=CostEstimate(
                prompt_cost_usd=None,
                completion_cost_usd=None,
                total_cost_usd=None,
            )
        )


@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status(job_id):
"""
Expand Down

0 comments on commit 5868a19

Please sign in to comment.