diff --git a/auth_rust/src/main.rs b/auth_rust/src/main.rs index f77c51b..9639979 100644 --- a/auth_rust/src/main.rs +++ b/auth_rust/src/main.rs @@ -1,37 +1,43 @@ +use axum::extract::State; +use axum::http::{header, HeaderMap, StatusCode}; +use axum::response::IntoResponse; +use axum::routing::any; +use axum::{routing::get, Router}; +use redis::{AsyncCommands, Client}; use std::error::Error; use std::fmt::{Display, Formatter}; -use std::rc::Rc; -use axum::{Router, routing::{any, get}}; -use axum::http::{header, HeaderMap, StatusCode}; -use axum::response::{IntoResponse}; -use redis::{AsyncCommands, Client, Commands}; use tokio::net::TcpListener; async fn healthcheck() -> impl IntoResponse { (StatusCode::OK, "OK") } -async fn authenticate(headers: HeaderMap) -> impl IntoResponse { - if let Some(token) = headers.get("X-Pesto-Token") { - todo!(); - return (StatusCode::OK, [(header::CONTENT_TYPE, "application/json")], "{\"message\":\"OK\"}"); - } else { - return ( +async fn authenticate( + headers: HeaderMap, + State(auth_repo): State>, +) -> impl IntoResponse { + match headers.get("X-Pesto-Token") { + Some(token) => ( + StatusCode::OK, + [(header::CONTENT_TYPE, "application/json")], + r#"{"message":"OK}"#, + ), + None => ( StatusCode::UNAUTHORIZED, [(header::CONTENT_TYPE, "application/json")], - "{\"message\":\"Token must be supplied\"}", - ); + r#"{"message":"Token must be supplied"}"#, + ), } } enum AuthError { - TokenNotRegistered + TokenNotRegistered, } impl Display for AuthError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - AuthError::TokenNotRegistered => write!(f, "Token not registered") + AuthError::TokenNotRegistered => write!(f, "Token not registered"), } } } @@ -42,35 +48,29 @@ struct TokenValue { revoked: bool, } -trait AuthRepo { - fn acquire_token_value(&self, token: String) -> Result; - fn acquire_counter_limit(&self, token_value: TokenValue) -> Result; - fn increment_monthly_counter(&self, token_value: TokenValue, counter_limit: i64) -> Result<(), AuthError>; -} - -// Before you say anything, yes I don't have time to think of the correct naming -// We'll just do it Java-like. #[derive(Clone)] -struct AuthRepoImpl { - pub redis_client: Rc>, +struct AuthRepo { + pub redis_client: Command, } -impl AuthRepoImpl { - fn new(redis_client: Rc>) -> Self { - return AuthRepoImpl{ redis_client } +impl AuthRepo { + fn new(redis_client: Command) -> Self { + Self { redis_client } } -} -impl AuthRepo for AuthRepoImpl { - fn acquire_token_value(&self, token: String) -> Result { + fn acquire_token_value(&self, token: &str) -> Result { todo!() } - fn acquire_counter_limit(&self, token_value: TokenValue) -> Result { + fn acquire_counter_limit(&self, token_value: &TokenValue) -> Result { todo!() } - fn increment_monthly_counter(&self, token_value: TokenValue, counter_limit: i64) -> Result<(), AuthError> { + fn increment_monthly_counter( + &self, + token_value: &TokenValue, + counter_limit: i64, + ) -> Result<(), AuthError> { todo!() } } @@ -79,12 +79,12 @@ impl AuthRepo for AuthRepoImpl { async fn main() -> Result<(), Box> { let redis_client = Client::open("redis://@localhost:6739")?; let redis_async_connection = redis_client.get_multiplexed_async_connection().await?; - let rc_client = Rc::new(Box::new(redis_async_connection) as Box); + let auth_repo = AuthRepo::new(redis_async_connection); + let app = Router::new() .route("/healthz", get(healthcheck)) .route("/", any(authenticate)) - .with_state(rc_client); - + .with_state(auth_repo); let listener = TcpListener::bind("0.0.0.0:3000").await?; axum::serve(listener, app).await?;