Skip to content

Commit

Permalink
FIX jwt authentication + redirect
Browse files Browse the repository at this point in the history
  • Loading branch information
synoet committed May 10, 2024
1 parent 9896108 commit b684bef
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 38 deletions.
1 change: 1 addition & 0 deletions k8s/debug/provision.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ helm upgrade --install traefik traefik/traefik \
--set 'ports.web.port=80' \
--set 'ports.websecure.port=443' \
--create-namespace \
--version 27.0.2 \
--namespace traefik

# Create the anubis namespace
Expand Down
2 changes: 0 additions & 2 deletions theia/proxy-rs/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ WORKDIR /app

COPY --from=builder /usr/src/app/target/release/proxy-rs /app/proxy-rs

EXPOSE 5000

USER app

CMD ["./proxy-rs"]
74 changes: 38 additions & 36 deletions theia/proxy-rs/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,19 @@ mod ws;

use anyhow::Result;
use axum::{
body::Body,
extract::{Path, Request, WebSocketUpgrade},
extract::{Path, Query, Request, WebSocketUpgrade},
http::StatusCode,
response::{IntoResponse, Redirect, Response},
response::{IntoResponse, Redirect},
routing::get,
Extension, Router,
};
use axum_extra::extract::cookie::{Cookie, CookieJar};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlPoolOptions, prelude::FromRow, MySqlPool};
use std::time::Duration;
use std::{env::var, sync::Arc};
use tokio::net::TcpStream;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;

Expand All @@ -31,7 +27,12 @@ const JWT_EXPIRATION: usize = 6 * 60 * 60;

// Lazy static evaluation of environment variables
lazy_static! {
static ref DEBUG: bool = var("DEBUG").unwrap_or("false".to_string()).parse().unwrap();
static ref IS_DEBUG: bool = {
match var("DEBUG").unwrap_or("false".to_string()).parse() {
Ok(val) => val,
Err(_) => false,
}
};
static ref SECRET_KEY: String = {
match var("SECRET_KEY") {
Ok(key) => key,
Expand All @@ -52,11 +53,12 @@ lazy_static! {
struct Claims {
exp: usize,
session_id: String,
#[serde(rename = "netid")]
net_id: String,
}

fn authenticate_jwt(token: &str) -> Result<Claims> {
let key: DecodingKey = DecodingKey::from_secret(SECRET_KEY.as_bytes());
let key: DecodingKey = DecodingKey::from_secret((*SECRET_KEY).as_bytes());
let validation = Validation::new(Algorithm::HS256);
let decoded = decode::<Claims>(token, &key, &validation)?;

Expand Down Expand Up @@ -96,6 +98,7 @@ async fn update_last_proxy_time(session_id: &str, pool: &MySqlPool) -> Result<()
}

async fn ping(jar: CookieJar, Extension(pool): Extension<Arc<MySqlPool>>) -> (StatusCode, String) {
tracing::info!("Ping");
match jar.get("ide") {
Some(cookie) => match authenticate_jwt(cookie.value()) {
Ok(claims) => {
Expand All @@ -111,7 +114,12 @@ async fn ping(jar: CookieJar, Extension(pool): Extension<Arc<MySqlPool>>) -> (St
(StatusCode::OK, "pong".to_string())
}

async fn initialize(jar: CookieJar) -> impl IntoResponse {
#[derive(Deserialize)]
struct InitializeQueryParams {
token: String,
}

async fn initialize(params: Query<InitializeQueryParams>, jar: CookieJar) -> impl IntoResponse {
let failed_response = |_reason: &str| {
(
StatusCode::PERMANENT_REDIRECT,
Expand All @@ -120,15 +128,11 @@ async fn initialize(jar: CookieJar) -> impl IntoResponse {
)
};

let token = match jar.get("ide") {
Some(cookie) => match authenticate_jwt(cookie.value()) {
Ok(claims) => claims,
Err(_) => {
return failed_response("Invalid token");
}
},
None => {
return failed_response("No token provided");
let token = match authenticate_jwt(&params.token) {
Ok(claims) => claims,
Err(err) => {
tracing::error!("failed to authenticate jwt: {}", err);
return failed_response("Invalid token");
}
};

Expand All @@ -138,10 +142,15 @@ async fn initialize(jar: CookieJar) -> impl IntoResponse {

let new_jar = jar.add(ide_cookie);

let _domain = match *IS_DEBUG {
true => "localhost".to_string(),
false => "anubis-lms.io".to_string(),
};

(
StatusCode::PERMANENT_REDIRECT,
new_jar,
Redirect::to("https://anubis-lms.io/"),
Redirect::to("/ide/"),
)
}

Expand Down Expand Up @@ -175,17 +184,16 @@ async fn get_cluster_address(pool: &MySqlPool, session_id: &str) -> Result<Strin
}

async fn handle(
Path(path): Path<String>,
ws: WebSocketUpgrade,
Extension(pool): Extension<Arc<MySqlPool>>,
jar: CookieJar,
req: Request,
_req: Request,
) -> impl IntoResponse {
let port = path.parse::<u64>().unwrap();
// let port = path.parse::<u64>().unwrap();

if port > MAX_PROXY_PORT {
return (StatusCode::BAD_REQUEST, "Invalid port".to_string());
}
// if port > MAX_PROXY_PORT {
// return (StatusCode::BAD_REQUEST, "Invalid port".to_string());
// }

let token = match jar.get("ide") {
Some(cookie) => match authenticate_jwt(cookie.value()) {
Expand All @@ -210,13 +218,13 @@ async fn handle(
})
.unwrap();

let host = format!("ws://{}:{}", cluster_address, port);
let host = format!("ws://{}:{}", cluster_address, MAX_PROXY_PORT);

let _result = ws.on_upgrade(move |socket| ws::forward(host, socket));

ws.on_upgrade(move |socket| ws::forward(host, socket));
(StatusCode::OK, "authorized".to_string())
}


#[tokio::main]
async fn main() {
tracing_subscriber::fmt()
Expand All @@ -229,21 +237,14 @@ async fn main() {
.acquire_timeout(Duration::from_secs(5))
.connect(&DB_URL)
.await
.map_err(|e| {
tracing::error!("Failed to connect to database: {}", e);
panic!("Failed to connect to database");
})
.map(|_| {
tracing::info!("Connected to database");
})
.unwrap();

let pool = Arc::new(pool);

let app = Router::new()
.route("/ping", get(ping))
.route("/initialize", get(initialize))
.route("/:port", get(handle))
.route("/", get(handle))
.layer(Extension(pool))
.layer(
TraceLayer::new_for_http()
Expand All @@ -259,5 +260,6 @@ async fn main() {
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", PROXY_SERVER_PORT))
.await
.unwrap();

axum::serve(listener, app).await.unwrap();
}

0 comments on commit b684bef

Please sign in to comment.