Skip to content

Commit

Permalink
Implemented REST APIs for sync calls rather than async events
Browse files Browse the repository at this point in the history
  • Loading branch information
john-sharratt committed Mar 3, 2024
1 parent a305c16 commit 0cf5685
Show file tree
Hide file tree
Showing 64 changed files with 3,895 additions and 1,071 deletions.
408 changes: 391 additions & 17 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pharos = { version = "0.5.3" }
mime_guess = { version = "2.0.4" }
wasm-bindgen = { version = "0.2.91" }
hyper-tungstenite = { version = "0.13.0" }
gloo-net = { version = "0.5.0" }
reqwest = { version = "0.11.24", features = ["json"] }
sha256 = { version = "1.5.0", default-features = false }
tokio = { version = "1.36.0", default-features = false, features = [
"rt-multi-thread",
Expand Down
8 changes: 3 additions & 5 deletions crates/backend/src/broadcast.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use futures_util::{stream::SplitSink, SinkExt};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use immutable_bank_model::header::LedgerMessage;
use immutable_bank_model::ledger::LedgerMessage;
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};

use crate::general_state::GeneralState;
Expand All @@ -24,10 +24,8 @@ impl GeneralState {
});
}

pub fn broadcast(&self, msg: LedgerMessage) {
tracing::warn!("Broadcast message: {:?}", msg);

let data = match bincode::serialize(&msg) {
pub fn broadcast(&self, msg: &LedgerMessage) {
let data = match bincode::serialize(msg) {
Ok(d) => d,
Err(err) => {
tracing::error!("failed to serialize entry to broadcast - {}", err);
Expand Down
56 changes: 33 additions & 23 deletions crates/backend/src/general_state.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, ops::Deref, sync::Arc, time::Duration};
use std::{ops::Deref, sync::Arc, time::Duration};

use immutable_bank_model::{header::LedgerMessage, ledger::Ledger, ledger_type::LedgerEntry};
use immutable_bank_model::{
ledger::{Ledger, LedgerForBank, LedgerMessage},
ledger_type::LedgerEntry,
};
use tokio::sync::{broadcast, Mutex, MutexGuard};

use crate::opts::Opts;
use crate::{opts::Opts, BROKER_SECRET};

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct GeneralStateInner {
pub existing_banks: HashSet<String>,
pub ledger: Ledger,
}

Expand Down Expand Up @@ -44,13 +46,22 @@ impl GeneralState {

let mut inner = GeneralStateInner::default();
for msg in msgs {
match &msg.entry {
LedgerEntry::NewBank(bank) | LedgerEntry::UpdateBank(bank) => {
inner.existing_banks.insert(bank.owner.clone());
}
_ => {}
}
inner.ledger.entries.insert(msg.header, msg.entry);
let ledger = match &msg.entry {
LedgerEntry::NewBank { bank_secret, .. } => inner
.ledger
.banks
.entry(msg.header.bank_id.clone())
.or_insert_with(|| LedgerForBank {
broker_secret: BROKER_SECRET.clone(),
bank_secret: bank_secret.clone(),
entries: Vec::new(),
}),
_ => match inner.ledger.banks.get_mut(&msg.header.bank_id) {
Some(l) => l,
None => continue,
},
};
ledger.entries.push(msg);
}

let state = GeneralState {
Expand All @@ -77,23 +88,22 @@ impl GeneralState {
// Copy the state
let msgs = {
let guard = self.lock().await;
tracing::info!(
"Saving general state to {:?} - entries.len={}",
opts.data_path,
guard.ledger.entries.len()
);
tokio::task::block_in_place(|| {
let entries = tokio::task::block_in_place(|| {
guard
.deref()
.ledger
.entries
.banks
.iter()
.map(|(h, e)| LedgerMessage {
header: h.clone(),
entry: e.clone(),
})
.flat_map(|b| b.1.entries.iter())
.cloned()
.collect::<Vec<_>>()
})
});
tracing::info!(
"Saving general state to {:?} - entries.len={}",
opts.data_path,
entries.len()
);
entries
};

// Determine the staging location
Expand Down
81 changes: 81 additions & 0 deletions crates/backend/src/handlers/get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use http::{HeaderName, StatusCode};
use http_body_util::Full;
use hyper::{body::Bytes, Request, Response};
use include_dir::{include_dir, Dir};

const HTML_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/../../crates/web/dist");

pub async fn get_handler(
req: Request<hyper::body::Incoming>,
) -> anyhow::Result<Response<Full<Bytes>>> {
// Get the path to the thing we loading and sanitize it
let mut path = req.uri().path();

// Special case and strip slash
if path.is_empty() || path == "/" {
path = "/index.html";
}
if path.starts_with("/") {
path = &path[1..];
}

// Sanitize
if path.contains("..") || path.starts_with("/") || path.starts_with("~") {
tracing::warn!("Access denied: path={}", path);
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Full::new(Bytes::from("Access denied")))?);
}

// Load the file
let file = match HTML_DIR.get_file(path) {
Some(file) => file,
None => {
tracing::debug!("Not found: path={}", path);
return Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("File not found")))?);
}
};

// Cache time
let cache_control = if file.path().ends_with(".mp3")
|| file.path().ends_with(".mp4")
|| file.path().ends_with(".wav")
|| file.path().ends_with(".wasm")
{
"Cache-Control: max-age=86400, stale-while-revalidate=86400"
} else {
"Cache-Control: max-age=30, stale-while-revalidate=86400"
};

// Write the response
let mut res = Response::new(Full::new(Bytes::from(file.contents())));

let meme = mime_guess::from_path(file.path()).first_or_octet_stream();
res.headers_mut().insert(
http::header::CONNECTION,
http::HeaderValue::from_str("Keep-Alive")?,
);
res.headers_mut().insert(
HeaderName::from_static("Keep-Alive"),
http::HeaderValue::from_str("timeout=2, max=100")?,
);
/*
res.headers_mut().insert(
http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
http::HeaderValue::from_str("*")?,
);
*/
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_str(&meme.to_string())?,
);
res.headers_mut().insert(
http::header::CACHE_CONTROL,
http::HeaderValue::from_str(cache_control)?,
);

tracing::debug!("Response: status={}, url={}", res.status(), req.uri());
Ok(res)
}
60 changes: 60 additions & 0 deletions crates/backend/src/handlers/http.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use http::{Method, StatusCode};
use http_body_util::Full;
use hyper::{body::Bytes, Request, Response};
use immutable_bank_model::bank_id::BankId;

use crate::{
general_state::GeneralState,
handlers::{get::get_handler, options::options_handler, post::post_handler, ws::handle_ws},
};

pub async fn http_handler(
mut req: Request<hyper::body::Incoming>,
state: GeneralState,
) -> anyhow::Result<Response<Full<Bytes>>> {
tracing::debug!("Request: method={}, url={}", req.method(), req.uri());

if req.method() == &Method::OPTIONS {
return options_handler(req).await;
}

if hyper_tungstenite::is_upgrade_request(&req) {
tracing::debug!("Request: upgrading to websocket");
let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?;

// Get the bank ID from the request query string
let bank_path = req.uri().path().to_lowercase();
let bank_id = match bank_path.split_once("/bank/") {
Some((_, bank_id)) => BankId::from(bank_id),
None => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Invalid Request")))?);
}
};

// Spawn a task to handle the websocket connection.
tokio::spawn(async move {
if let Err(e) = handle_ws(websocket, bank_id, state).await {
tracing::error!("Error in websocket connection: {e}");
}
});

// Return the response so the spawned future can continue.
return Ok(response);
}

// If its a GET request
if req.method() == Method::GET {
return get_handler(req).await;
}

// Maybe its a API call
if req.method() == Method::POST {
return post_handler(req, state).await;
}

Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Invalid Request")))?)
}
6 changes: 6 additions & 0 deletions crates/backend/src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

pub mod ws;
pub mod http;
pub mod get;
pub mod post;
pub mod options;
17 changes: 17 additions & 0 deletions crates/backend/src/handlers/options.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use http::StatusCode;
use http_body_util::Full;
use hyper::{body::Bytes, Request, Response};

pub async fn options_handler(
req: Request<hyper::body::Incoming>,
) -> anyhow::Result<Response<Full<Bytes>>> {
let res = Response::builder()
.status(StatusCode::OK)
.header("Access-Control-Allow-Origin", "*")
.header("Access-Control-Allow-Headers", "*")
.header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
.body(Full::default())?;

tracing::debug!("Response: status={}, url={}", res.status(), req.uri());
Ok(res)
}
84 changes: 84 additions & 0 deletions crates/backend/src/handlers/post.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use bytes::Bytes;
use http::{Request, Response, StatusCode};
use http_body_util::{BodyExt, Full};

use crate::{general_state::GeneralState, BROKER_SECRET};

// Sending more than a megabyte is prevented
pub const MAX_REQUEST_BODY: usize = 1024 * 1024;

pub async fn post_handler(
req: Request<hyper::body::Incoming>,
state: GeneralState,
) -> anyhow::Result<Response<Full<Bytes>>> {
let (parts, mut body) = req.into_parts();

// Read all the data to a particular limit and then fail
// (this is to prevent DDOS attacks)
let mut data: Vec<u8> = Vec::with_capacity(4906);
while let Some(frame) = body.frame().await {
if let Some(frame) = frame?.data_ref() {
if data.len() > frame.len() {
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Full::new(Bytes::from("DDOS protection")))?);
}
data.extend_from_slice(frame);
}
}

// pattern match for both the method and the path of the request
let res = match (parts.method, parts.uri.path()) {
(hyper::Method::POST, "/update-bank") => {
let req = serde_json::from_slice(&data)?;
let state_inner = state.clone();
let res = state
.inner
.lock()
.await
.ledger
.update_bank(req, move |msg| state_inner.broadcast(msg))?;
tracing::info!("UpdateBank-Response: {:?}", res);
serde_json::to_vec_pretty(&res)?
}
(hyper::Method::POST, "/new-bank") => {
let req = serde_json::from_slice(&data)?;
let state_inner = state.clone();
let res =
state
.inner
.lock()
.await
.ledger
.new_bank(&BROKER_SECRET, req, move |msg| state_inner.broadcast(msg))?;
tracing::info!("NewBank-Response: {:?}", res);
serde_json::to_vec_pretty(&res)?
}
(hyper::Method::POST, "/transfer") => {
let req = serde_json::from_slice(&data)?;
let state_inner = state.clone();
let res = state
.inner
.lock()
.await
.ledger
.transfer(req, move |msg| state_inner.broadcast(msg))?;
tracing::info!("Transfer-Response: {:?}", res);
serde_json::to_vec_pretty(&res)?
}
// Anything else handler
_ => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Invalid Request")))?)
}
};

Ok(Response::builder()
.status(StatusCode::OK)
.header("Access-Control-Allow-Origin", "*")
.header("Access-Control-Allow-Headers", "*")
.header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(res)))?)
}
Loading

0 comments on commit 0cf5685

Please sign in to comment.