Skip to content

Commit

Permalink
Merge pull request #288 from crabnebula-dev/feat/add-cors-origin-api
Browse files Browse the repository at this point in the history
feat: add API to allow an origin to be allowed by CORS
  • Loading branch information
lucasfernog-crabnebula authored May 23, 2024
2 parents b34c109 + 90ef64e commit d9e9f6c
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 18 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/devtools-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ bytes = "1.5.0"
ringbuf = "0.4.0-rc.3"
async-stream = "0.3.5"
http = "0.2"
hyper = "0.14"
tower = "0.4"
1 change: 0 additions & 1 deletion crates/devtools-core/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ impl<T, const CAP: usize> EventBuf<T, CAP> {
}

/// Push an event into the buffer, overwriting the oldest event if the buffer is full.
// TODO does it really make sense to track the dropped events here?
pub fn push_overwrite(&mut self, item: T) {
if self.inner.push_overwrite(item).is_some() {
self.sent = self.sent.saturating_sub(1);
Expand Down
116 changes: 104 additions & 12 deletions crates/devtools-core/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,24 @@ use devtools_wire_format::sources::sources_server::SourcesServer;
use devtools_wire_format::tauri::tauri_server;
use devtools_wire_format::tauri::tauri_server::TauriServer;
use futures::{FutureExt, TryStreamExt};
use http::HeaderValue;
use hyper::Body;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tonic::body::BoxBody;
use tonic::codegen::http::Method;
use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
use tonic::codegen::BoxStream;
use tonic::{Request, Response, Status};
use tonic_health::pb::health_server::{Health, HealthServer};
use tonic_health::server::HealthReporter;
use tonic_health::ServingStatus;
use tower_http::cors::{AllowHeaders, CorsLayer};
use tower::Service;
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
use tower_layer::Layer;

/// Default maximum capacity for the channel of events sent from a
/// [`Server`] to each subscribed client.
Expand All @@ -28,15 +36,84 @@ use tower_http::cors::{AllowHeaders, CorsLayer};
const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 1024 * 4;

/// The `gRPC` server that exposes the instrumenting API
pub struct Server(
tonic::transport::server::Router<tower_layer::Stack<CorsLayer, tower_layer::Identity>>,
);
pub struct Server {
router: tonic::transport::server::Router<
tower_layer::Stack<DynamicCorsLayer, tower_layer::Identity>,
>,
handle: ServerHandle,
}

/// A handle to a server that is allowed to modify its properties (such as CORS allowed origins)
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct ServerHandle {
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
}

impl ServerHandle {
/// Allow the given origin in the instrumentation server CORS.
#[allow(clippy::missing_panics_doc)]
pub fn allow_origin(&self, origin: impl Into<AllowOrigin>) {
self.allowed_origins.lock().unwrap().push(origin.into());
}
}

struct InstrumentService {
tx: mpsc::Sender<Command>,
health_reporter: HealthReporter,
}

#[derive(Clone)]
struct DynamicCorsLayer {
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
}

impl<S> Layer<S> for DynamicCorsLayer {
type Service = DynamicCors<S>;

fn layer(&self, service: S) -> Self::Service {
DynamicCors {
inner: service,
allowed_origins: self.allowed_origins.clone(),
}
}
}

#[derive(Debug, Clone)]
struct DynamicCors<S> {
inner: S,
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
}

type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;

impl<S> Service<hyper::Request<Body>> for DynamicCors<S>
where
S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
let mut cors = CorsLayer::new()
// allow `GET` and `POST` when accessing the resource
.allow_methods([Method::GET, Method::POST])
.allow_headers(AllowHeaders::any());

for origin in &*self.allowed_origins.lock().unwrap() {
cors = cors.allow_origin(origin.clone());
}

Box::pin(cors.layer(self.inner.clone()).call(req))
}
}

impl Server {
#[allow(clippy::missing_panics_doc)]
pub fn new(
Expand All @@ -51,15 +128,22 @@ impl Server {
.set_serving::<InstrumentServer<InstrumentService>>()
.now_or_never();

let cors = CorsLayer::new()
// allow `GET` and `POST` when accessing the resource
.allow_methods([Method::GET, Method::POST])
.allow_headers(AllowHeaders::any())
.allow_origin(tower_http::cors::Any);
let allowed_origins =
Arc::new(Mutex::new(vec![
if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() {
AllowOrigin::from(tower_http::cors::Any)
} else {
HeaderValue::from_str("https://devtools.crabnebula.dev")
.unwrap()
.into()
},
]));

let router = tonic::transport::Server::builder()
.accept_http1(true)
.layer(cors)
.layer(DynamicCorsLayer {
allowed_origins: allowed_origins.clone(),
})
.add_service(tonic_web::enable(health_service))
.add_service(tonic_web::enable(InstrumentServer::new(
InstrumentService {
Expand All @@ -71,7 +155,15 @@ impl Server {
.add_service(tonic_web::enable(MetadataServer::new(metadata_server)))
.add_service(tonic_web::enable(SourcesServer::new(sources_server)));

Self(router)
Self {
router,
handle: ServerHandle { allowed_origins },
}
}

#[must_use]
pub fn handle(&self) -> ServerHandle {
self.handle.clone()
}

/// Consumes this [`Server`] and returns a future that will execute the server.
Expand All @@ -82,7 +174,7 @@ impl Server {
pub async fn run(self, addr: SocketAddr) -> crate::Result<()> {
tracing::info!("Listening on {}", addr);

self.0.serve(addr).await?;
self.router.serve(addr).await?;

Ok(())
}
Expand Down
13 changes: 8 additions & 5 deletions crates/devtools/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod server;
use devtools_core::aggregator::Aggregator;
use devtools_core::layer::Layer;
use devtools_core::server::wire::tauri::tauri_server::TauriServer;
use devtools_core::server::Server;
use devtools_core::server::{Server, ServerHandle};
use devtools_core::Command;
pub use devtools_core::Error;
use devtools_core::{Result, Shared};
Expand Down Expand Up @@ -52,6 +52,7 @@ mod ios {

pub struct Devtools {
pub connection: ConnectionInfo,
pub server_handle: ServerHandle,
}

fn init_plugin<R: Runtime>(
Expand All @@ -64,10 +65,6 @@ fn init_plugin<R: Runtime>(
.setup(move |app_handle, _api| {
let (mut health_reporter, health_service) = tonic_health::server::health_reporter();

app_handle.manage(Devtools {
connection: connection_info(&addr),
});

health_reporter
.set_serving::<TauriServer<server::TauriService<R>>>()
.now_or_never()
Expand All @@ -87,6 +84,12 @@ fn init_plugin<R: Runtime>(
app_handle: app_handle.clone(),
},
);
let server_handle = server.handle();

app_handle.manage(Devtools {
connection: connection_info(&addr),
server_handle,
});

#[cfg(not(target_os = "ios"))]
print_link(&addr);
Expand Down

0 comments on commit d9e9f6c

Please sign in to comment.