Skip to content

Commit

Permalink
feat: DIA-1725: add model-metadata endpoint (#279)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Bernstein <matt@humansignal.com>
Co-authored-by: matt-bernstein <matt-bernstein@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent db98299 commit 6f09afa
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
19 changes: 19 additions & 0 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,22 @@ async def batch_to_batch(
return output_df.set_index(batch.index)

# TODO: cost estimate

def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=None) -> dict:
    """Look up litellm's metadata for ``provider/model_name``.

    Args:
        provider: Provider prefix used to build the litellm model identifier.
        model_name: Model (or, for Azure, deployment) name.
        auth_info: Optional keyword arguments forwarded to the Azure probe
            completion; unused for other providers.

    Returns:
        Whatever ``litellm.get_model_info`` returns for the resolved name, or
        an empty dict if any step raises (the error is logged, not re-raised).
    """
    credentials = {} if auth_info is None else auth_info
    try:
        if provider == "azure":
            # Azure is addressed by a user-chosen name, so send a one-token
            # completion and read the canonical model name off the response.
            probe = litellm.completion(
                model=f"azure/{model_name}",
                messages=[{"role": "user", "content": ""}],
                max_tokens=1,
                **credentials,
            )
            model_name = probe.model
        return litellm.get_model_info(f"{provider}/{model_name}")
    except Exception as err:
        # Best-effort lookup: callers get an empty mapping instead of an error.
        logger.error("Hit error when trying to get model metadata: %s", err)
        return {}
21 changes: 21 additions & 0 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,27 @@ async def improved_prompt(request: ImprovedPromptRequest):
)


class ModelMetadataRequestItem(BaseModel):
    """One model to look up in a ``/model-metadata`` request.

    The dumped fields are passed as keyword arguments to ``get_model_info``,
    so the field names must match that function's parameter names.
    """

    # ``model_name`` starts with "model_", which pydantic v2 releases that
    # protect the whole prefix report with a UserWarning at class creation.
    # The name is part of the request schema, so opt out of the check
    # rather than rename the field.
    model_config = {"protected_namespaces": ()}

    # Provider prefix, e.g. "azure".
    provider: str
    # Model name as known to the provider.
    model_name: str
    # Optional credentials/settings forwarded to the metadata lookup.
    auth_info: Optional[Dict[str, str]] = None

class ModelMetadataRequest(BaseModel):
    """Request body for ``POST /model-metadata``: the models to look up."""

    # One entry per model whose metadata is requested.
    models: List[ModelMetadataRequestItem]

class ModelMetadataResponse(BaseModel):
    """Payload of ``POST /model-metadata``: metadata keyed by model name."""

    # ``model_metadata`` starts with "model_", which pydantic v2 releases
    # that protect the whole prefix report with a UserWarning at class
    # creation. The name is part of the response schema, so opt out of the
    # check rather than rename the field.
    model_config = {"protected_namespaces": ()}

    # Maps each requested model name to its metadata dict (empty when the
    # lookup failed).
    model_metadata: Dict[str, Dict]

@app.post("/model-metadata", response_model=Response[ModelMetadataResponse])
async def model_metadata(request: ModelMetadataRequest):
    """Return litellm metadata for every model listed in the request.

    Args:
        request: The models to look up, each with provider, model name and
            optional auth info.

    Returns:
        A successful ``Response`` whose data maps each requested
        ``model_name`` to its metadata dict. A model whose lookup failed maps
        to an empty dict.
    """
    import asyncio
    from functools import partial

    from adala.runtimes._litellm import get_model_info

    # get_model_info is synchronous and may issue a network call, so run the
    # lookups in the default thread pool instead of blocking the event loop.
    loop = asyncio.get_running_loop()
    lookups = [
        loop.run_in_executor(None, partial(get_model_info, **item.model_dump()))
        for item in request.models
    ]
    # gather() preserves input order, so results line up with request.models.
    infos = await asyncio.gather(*lookups)

    # NOTE(review): keyed by model_name only, so two items with the same name
    # but different providers collapse to the last one (unchanged behavior).
    resp = {
        'model_metadata': {
            item.model_name: info for item, info in zip(request.models, infos)
        }
    }
    return Response[ModelMetadataResponse](
        success=True,
        data=resp
    )

if __name__ == "__main__":
    # For local debugging only: serve the app directly with uvicorn on port
    # 30001, listening on all interfaces (0.0.0.0).
    uvicorn.run("app:app", host="0.0.0.0", port=30001)

0 comments on commit 6f09afa

Please sign in to comment.