Skip to content

Commit

Permalink
refactor(oauth2): support oauth2 state inmemory managing (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
eve0415 authored Sep 14, 2024
1 parent 0d10447 commit f3d2ea1
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 82 deletions.
15 changes: 5 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ board = { path = "backend/board" }
game = { path = "backend/game" }
oauth = { path = "backend/oauth" }

anyhow = "1.0.88"
log = "0.4.22"
rand = "0.8.5"
redis = { version = "0.26.1", features = ["aio", "r2d2", "ahash", "tokio-comp", "connection-manager"] }
redis = { version = "0.27.0", features = ["aio", "r2d2", "ahash", "tokio-comp", "connection-manager","sentinel"] }
serde = { version = "1.0.210", features = ["derive"] }
serenity = { version = "0.12.2", default-features = false }
thiserror = "1.0.63"
tokio = { version = "1.40.0", features = ["rt-multi-thread", "macros", "rt"] }
1 change: 0 additions & 1 deletion backend/game/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ edition.workspace = true
[dependencies]
board.workspace = true

anyhow.workspace = true
serde.workspace = true
serenity.workspace = true
thiserror.workspace = true
3 changes: 2 additions & 1 deletion backend/oauth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ version.workspace = true
edition.workspace = true

[dependencies]
anyhow.workspace = true
log.workspace = true
rand.workspace = true
redis.workspace = true
serde.workspace = true
serenity.workspace = true
thiserror.workspace = true
tokio.workspace = true

async-trait = "0.1.82"
base64 = "0.22.1"
digest = "0.10.7"
reqwest = { version = "0.12.7", features = ["rustls-tls-native-roots", "multipart", "json"] }
Expand Down
18 changes: 13 additions & 5 deletions backend/oauth/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use redis::RedisError;
use std::error::Error;
use url::ParseError;

#[derive(thiserror::Error, Debug)]
pub enum Error {
pub enum OAuth2Error {
#[error("Redis connection lost")]
RedisConnectionLost,

Expand All @@ -9,8 +13,12 @@ pub enum Error {
#[error("Not a member in the guild")]
NotMember,

#[error("Unknown error: {0}")]
Unknown(#[source] anyhow::Error),
}
#[error(transparent)]
RedisError(#[from] RedisError),

pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
#[error(transparent)]
InternalError(#[from] ParseError),

#[error(transparent)]
Unknown(#[from] Box<dyn Error + Sync + Send>),
}
86 changes: 34 additions & 52 deletions backend/oauth/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
pub mod error;
pub mod security;

use crate::error::OAuth2Error;
use crate::security::SecurityManager;
use base64::Engine;
use rand::random;
use redis::{AsyncCommands, Client as RedisClient};
use reqwest::ClientBuilder;
use reqwest::{Client as HttpClient, StatusCode};
use serde::{Deserialize, Serialize};
use serenity::all::UserId;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io::{Error, ErrorKind};
use std::sync::Arc;
use tokio::sync::Mutex;
use url::Url;

const AUTHORIZATION_URL: &str = "https://discord.com/oauth2/authorize";
Expand All @@ -17,52 +22,42 @@ const DISCORD_CDN_URL: &str = "https://cdn.discordapp.com";
const RESPONSE_TYPE: &str = "code";
const SCOPE: &str = "identify guilds.members.read";
const CODE_CHALLENGE_METHOD: &str = "S256";
const STATE_LIFETIME: u64 = 300;
const GRANT_TYPE: &str = "authorization_code";
const GUILD_ID: &str = "1176516474102353950";

#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct DiscordOAuth {
id: String,
secret: String,
redirect_url: String,
redis_client: RedisClient,
http_client: HttpClient,
security_manager: Arc<Mutex<dyn SecurityManager>>,
}

impl DiscordOAuth {
pub fn new(
id: String,
secret: String,
redirect_url: String,
redis_client: RedisClient,
) -> Self {
DiscordOAuth {
security_manager: Arc<Mutex<dyn SecurityManager>>,
) -> Box<Self> {
Box::new(DiscordOAuth {
id,
secret,
redirect_url,
redis_client,
http_client: ClientBuilder::new().https_only(true).build().unwrap(),
}
security_manager,
})
}

pub async fn generate_authorization_url(self) -> error::Result<Url> {
pub async fn generate_authorization_url(self) -> Result<Url, OAuth2Error> {
let (state, code_verifier, code_challenge) = generate_state_and_code_challenge();

redis::cmd("HSETEX")
.arg("oauth")
.arg(STATE_LIFETIME)
.arg(state.clone())
.arg(code_verifier)
.exec_async(
&mut self
.redis_client
.get_multiplexed_tokio_connection()
.await
.map_err(|_| error::Error::RedisConnectionLost)?,
)
self.security_manager
.lock()
.await
.map_err(|e| error::Error::Unknown(e.into()))?;
.save_state(state.clone(), code_verifier.clone())
.await?;

let url = Url::parse_with_params(
AUTHORIZATION_URL,
Expand All @@ -77,30 +72,18 @@ impl DiscordOAuth {
("prompt", "none".to_string()),
],
)
.map_err(|e| error::Error::Unknown(e.into()))?;
.map_err(OAuth2Error::InternalError)?;

Ok(url)
}

pub async fn get_user(self, code: String, state: String) -> error::Result<User> {
let mut conn = self
.redis_client
.get_multiplexed_tokio_connection()
pub async fn get_user(self, code: String, state: String) -> Result<User, OAuth2Error> {
let code_verifier = self
.security_manager
.lock()
.await
.map_err(|_| error::Error::RedisConnectionLost)?;
let code_verifier: String = redis::cmd("HGET")
.arg("oauth")
.arg(state.to_owned())
.query_async(&mut conn)
.await
.map_err(|_| error::Error::InvalidState {
state: state.to_owned(),
})?;

let _: () = conn
.hdel("oauth", state.to_owned())
.await
.map_err(|e| error::Error::Unknown(e.into()))?;
.verify_state(&state)
.await?;

let params = HashMap::from([
("client_id", self.id.to_owned()),
Expand All @@ -121,21 +104,20 @@ impl DiscordOAuth {
.map_err(|e| {
log::error!("{:?}", e);

error::Error::Unknown(e.into())
OAuth2Error::Unknown(Box::new(e))
})?;

if response.status() != StatusCode::OK {
log::error!("Failed to get access token: {:?}", response.text().await);

return Err(error::Error::Unknown(anyhow::anyhow!(
"Failed to get access token"
)));
return Err(OAuth2Error::Unknown(Box::new(Error::new(
ErrorKind::UnexpectedEof,
response.text().await.unwrap(),
))));
}

let res = response.json::<AccessTokenResponse>().await.map_err(|e| {
log::error!("{:?}", e);

error::Error::Unknown(e.into())
OAuth2Error::Unknown(Box::new(e))
})?;

let response = self
Expand All @@ -146,18 +128,18 @@ impl DiscordOAuth {
.bearer_auth(res.access_token)
.send()
.await
.map_err(|e| error::Error::Unknown(e.into()))?;
.map_err(|e| OAuth2Error::Unknown(Box::new(e)))?;

if response.status() != StatusCode::OK {
log::error!("Failed to get guild member: {:?}", response.text().await);

return Err(error::Error::NotMember);
return Err(OAuth2Error::NotMember);
}

let member = response
.json::<serenity::model::guild::Member>()
.await
.map_err(|e| error::Error::Unknown(e.into()))?;
.map_err(|e| OAuth2Error::Unknown(Box::new(e)))?;

Ok(User {
id: member.user.id,
Expand Down
87 changes: 87 additions & 0 deletions backend/oauth/src/security/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use crate::error::OAuth2Error;
use crate::security::{SecurityManager, STATE_LIFETIME};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Clone, Default)]
pub struct InMemorySecurityManager {
challenges: Arc<Mutex<HashMap<String, String>>>,
}

#[async_trait]
impl SecurityManager for InMemorySecurityManager {
async fn save_state(
&mut self,
state: String,
code_verifier: String,
) -> Result<(), OAuth2Error> {
let challenges = Arc::clone(&self.challenges);

{
let mut lock = challenges.lock().unwrap();
lock.insert(state.clone(), code_verifier);
}

// Remove the challenge after 5 minutes
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_secs(STATE_LIFETIME)).await;
let mut lock = challenges.lock().unwrap();
lock.remove(&state);
});

Ok(())
}

async fn verify_state(&mut self, state: &str) -> Result<String, OAuth2Error> {
let mut lock = self.challenges.lock().unwrap();
match lock.remove(state) {
Some(code_verifier) => Ok(code_verifier),
None => Err(OAuth2Error::InvalidState {
state: state.to_owned(),
}),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_save_state() {
let mut manager = InMemorySecurityManager::default();
let state = "state".to_string();
let code_verifier = "code_verifier".to_string();

manager
.save_state(state.clone(), code_verifier.clone())
.await
.unwrap();

let challenges = manager.challenges.lock().unwrap();
assert_eq!(challenges.get(&state), Some(&code_verifier));
}

#[tokio::test]
async fn test_verify_state() {
let mut manager = InMemorySecurityManager::default();
let state = "state".to_string();
let code_verifier = "code_verifier".to_string();

manager
.save_state(state.clone(), code_verifier.clone())
.await
.unwrap();

assert_eq!(manager.verify_state(&state).await.unwrap(), code_verifier);
}

#[tokio::test]
async fn test_verify_state_invalid() {
let mut manager = InMemorySecurityManager::default();
let state = "state".to_string();

assert!(manager.verify_state(&state).await.is_err())
}
}
Loading

0 comments on commit f3d2ea1

Please sign in to comment.