Skip to content

Commit

Permalink
feat: Implement custom error handling for Inference Gateway API
Browse files Browse the repository at this point in the history
Signed-off-by: Eden Reich <eden.reich@gmail.com>
  • Loading branch information
edenreich committed Jan 28, 2025
1 parent 8000794 commit ab8c732
Showing 1 changed file with 138 additions and 19 deletions.
157 changes: 138 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,58 @@
//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
//! with various LLM providers through a unified interface.
use reqwest::blocking::Client;
use reqwest::{blocking::Client, StatusCode};
use serde::{Deserialize, Serialize};
use std::{error::Error, fmt};

/// Errors surfaced by the Inference Gateway SDK.
///
/// The string-carrying variants hold the `error` message extracted from the
/// JSON body the gateway returns alongside the matching HTTP status code.
#[derive(Debug)]
pub enum GatewayError {
    /// The gateway replied with HTTP 401 (Unauthorized).
    Unauthorized(String),
    /// The gateway replied with HTTP 400 (Bad Request).
    BadRequest(String),
    /// The gateway replied with HTTP 500 (Internal Server Error).
    InternalError(String),
    /// A failure reported by `reqwest` while sending the request or decoding
    /// the response.
    RequestError(reqwest::Error),
    /// Any failure not covered by the variants above, e.g. an HTTP status
    /// code the client does not handle explicitly.
    Other(Box<dyn Error + Send + Sync>),
}

impl fmt::Display for GatewayError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
Self::BadRequest(msg) => write!(f, "Bad request: {}", msg),
Self::InternalError(msg) => write!(f, "Internal server error: {}", msg),
Self::RequestError(e) => write!(f, "Request error: {}", e),
Self::Other(e) => write!(f, "Other error: {}", e),
}
}
}

impl Error for GatewayError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::RequestError(e) => Some(e),
Self::Other(e) => Some(e.as_ref()),
_ => None,
}
}
}

impl From<reqwest::Error> for GatewayError {
fn from(err: reqwest::Error) -> Self {
Self::RequestError(err)
}
}

/// JSON body of an error reply from the gateway, e.g. `{"error":"Invalid token"}`.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Message describing what went wrong.
    error: String,
}

/// Represents a model available through a provider
#[derive(Debug, Serialize, Deserialize)]
pub struct Model {
Expand Down Expand Up @@ -120,11 +168,10 @@ pub struct InferenceGatewayClient {
/// Core API interface for the Inference Gateway
pub trait InferenceGatewayAPI {
/// Lists available models from all providers
fn list_models(&self) -> Result<Vec<ProviderModels>, Box<dyn Error>>;
fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError>;

/// Lists available models by a specific provider
fn list_models_by_provider(&self, provider: Provider)
-> Result<ProviderModels, Box<dyn Error>>;
fn list_models_by_provider(&self, provider: Provider) -> Result<ProviderModels, GatewayError>;

/// Generates content using a specified model
///
Expand All @@ -137,7 +184,7 @@ pub trait InferenceGatewayAPI {
provider: Provider,
model: &str,
messages: Vec<Message>,
) -> Result<GenerateResponse, Box<dyn Error>>;
) -> Result<GenerateResponse, GatewayError>;

/// Checks if the API is available
fn health_check(&self) -> Result<bool, Box<dyn Error>>;
Expand Down Expand Up @@ -167,52 +214,104 @@ impl InferenceGatewayClient {
}

impl InferenceGatewayAPI for InferenceGatewayClient {
fn list_models(&self) -> Result<Vec<ProviderModels>, Box<dyn Error>> {
fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError> {
let url = format!("{}/llms", self.base_url);
let mut request = self.client.get(&url);
if let Some(token) = &self.token {
request = self.client.get(&url).bearer_auth(token);
request = request.bearer_auth(token);
}

let response = request.send()?;
let models = response.json()?;
Ok(models)

match response.status() {
StatusCode::OK => Ok(response.json()?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
)))),
}
}

fn list_models_by_provider(
&self,
provider: Provider,
) -> Result<ProviderModels, Box<dyn Error>> {
fn list_models_by_provider(&self, provider: Provider) -> Result<ProviderModels, GatewayError> {
let url = format!("{}/llms/{}", self.base_url, provider);
let mut request = self.client.get(&url);
if let Some(token) = &self.token {
request = self.client.get(&url).bearer_auth(token);
}

let response = request.send()?;
let models: ProviderModels = response.json()?;
Ok(models)

match response.status() {
StatusCode::OK => Ok(response.json()?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
)))),
}
}

fn generate_content(
&self,
provider: Provider,
model: &str,
messages: Vec<Message>,
) -> Result<GenerateResponse, Box<dyn Error>> {
) -> Result<GenerateResponse, GatewayError> {
let url = format!("{}/llms/{}/generate", self.base_url, provider);
let mut request = self.client.post(&url);
if let Some(token) = &self.token {
request = self.client.post(&url).bearer_auth(token);
request = request.bearer_auth(token);
}

let request_payload = GenerateRequest {
model: model.to_string(),
messages,
};

let response = request.json(&request_payload).send()?.json()?;
Ok(response)
let response = request.json(&request_payload).send()?;

match response.status() {
StatusCode::OK => Ok(response.json()?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
)))),
}
}

fn health_check(&self) -> Result<bool, Box<dyn Error>> {
Expand Down Expand Up @@ -260,6 +359,26 @@ mod tests {
mock_without_auth.assert();
}

/// A 401 reply with a JSON error body must surface as
/// `GatewayError::Unauthorized` carrying the body's `error` message.
#[test]
fn test_unauthorized_error() {
    let mut server = Server::new();

    // Stub the models endpoint so that it rejects the request.
    let mock = server
        .mock("GET", "/llms")
        .with_status(401)
        .with_header("content-type", "application/json")
        .with_body(r#"{"error":"Invalid token"}"#)
        .create();

    let base_url = server.url();
    let client = InferenceGatewayClient::new(&base_url);
    let outcome = client.list_models();
    let error = outcome.unwrap_err();

    assert!(matches!(error, GatewayError::Unauthorized(_)));
    if let GatewayError::Unauthorized(message) = error {
        assert_eq!(message, "Invalid token");
    }

    // The stubbed endpoint must have been called.
    mock.assert();
}

#[test]
fn test_list_models() {
let mut server = Server::new();
Expand Down

0 comments on commit ab8c732

Please sign in to comment.