diff --git a/Cargo.lock b/Cargo.lock index 1450fb9a..14127b95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -651,6 +651,7 @@ dependencies = [ "devtools-wire-format", "futures", "http 0.2.12", + "hyper 0.14.28", "prost-types", "ringbuf", "thiserror", @@ -659,6 +660,7 @@ dependencies = [ "tonic", "tonic-health", "tonic-web", + "tower", "tower-http", "tower-layer", "tracing", diff --git a/crates/devtools-core/Cargo.toml b/crates/devtools-core/Cargo.toml index 31c77a91..9194ab05 100644 --- a/crates/devtools-core/Cargo.toml +++ b/crates/devtools-core/Cargo.toml @@ -29,3 +29,5 @@ bytes = "1.5.0" ringbuf = "0.4.0-rc.3" async-stream = "0.3.5" http = "0.2" +hyper = "0.14" +tower = "0.4" diff --git a/crates/devtools-core/src/aggregator.rs b/crates/devtools-core/src/aggregator.rs index 452837f5..41e4eee5 100644 --- a/crates/devtools-core/src/aggregator.rs +++ b/crates/devtools-core/src/aggregator.rs @@ -321,7 +321,6 @@ impl EventBuf { } /// Push an event into the buffer, overwriting the oldest event if the buffer is full. - // TODO does it really make sense to track the dropped events here? pub fn push_overwrite(&mut self, item: T) { if self.inner.push_overwrite(item).is_some() { self.sent = self.sent.saturating_sub(1); diff --git a/crates/devtools-core/src/server.rs b/crates/devtools-core/src/server.rs index 62665393..1aff727f 100644 --- a/crates/devtools-core/src/server.rs +++ b/crates/devtools-core/src/server.rs @@ -9,8 +9,14 @@ use devtools_wire_format::sources::sources_server::SourcesServer; use devtools_wire_format::tauri::tauri_server; use devtools_wire_format::tauri::tauri_server::TauriServer; use futures::{FutureExt, TryStreamExt}; +use http::HeaderValue; +use hyper::Body; use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; use tokio::sync::mpsc; +use tonic::body::BoxBody; use tonic::codegen::http::Method; use tonic::codegen::tokio_stream::wrappers::ReceiverStream; use tonic::codegen::BoxStream; @@ -18,7 +24,9 @@ use tonic::{Request, Response, Status}; use tonic_health::pb::health_server::{Health, HealthServer}; use tonic_health::server::HealthReporter; use tonic_health::ServingStatus; -use tower_http::cors::{AllowHeaders, CorsLayer}; +use tower::Service; +use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer}; +use tower_layer::Layer; /// Default maximum capacity for the channel of events sent from a /// [`Server`] to each subscribed client. @@ -28,15 +36,84 @@ use tower_http::cors::{AllowHeaders, CorsLayer}; const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 1024 * 4; /// The `gRPC` server that exposes the instrumenting API -pub struct Server( - tonic::transport::server::Router>, -); +pub struct Server { + router: tonic::transport::server::Router< + tower_layer::Stack, + >, + handle: ServerHandle, +} + +/// A handle to a server that is allowed to modify its properties (such as CORS allowed origins) +#[allow(clippy::module_name_repetitions)] +#[derive(Clone)] +pub struct ServerHandle { + allowed_origins: Arc>>, +} + +impl ServerHandle { + /// Allow the given origin in the instrumentation server CORS. + #[allow(clippy::missing_panics_doc)] + pub fn allow_origin(&self, origin: impl Into) { + self.allowed_origins.lock().unwrap().push(origin.into()); + } +} struct InstrumentService { tx: mpsc::Sender, health_reporter: HealthReporter, } +#[derive(Clone)] +struct DynamicCorsLayer { + allowed_origins: Arc>>, +} + +impl Layer for DynamicCorsLayer { + type Service = DynamicCors; + + fn layer(&self, service: S) -> Self::Service { + DynamicCors { + inner: service, + allowed_origins: self.allowed_origins.clone(), + } + } +} + +#[derive(Debug, Clone)] +struct DynamicCors { + inner: S, + allowed_origins: Arc>>, +} + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +impl Service> for DynamicCors +where + S: Service, Response = hyper::Response> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: hyper::Request) -> Self::Future { + let mut cors = CorsLayer::new() + // allow `GET` and `POST` when accessing the resource + .allow_methods([Method::GET, Method::POST]) + .allow_headers(AllowHeaders::any()); + + for origin in &*self.allowed_origins.lock().unwrap() { + cors = cors.allow_origin(origin.clone()); + } + + Box::pin(cors.layer(self.inner.clone()).call(req)) + } +} + impl Server { #[allow(clippy::missing_panics_doc)] pub fn new( @@ -51,15 +128,22 @@ impl Server { .set_serving::>() .now_or_never(); - let cors = CorsLayer::new() - // allow `GET` and `POST` when accessing the resource - .allow_methods([Method::GET, Method::POST]) - .allow_headers(AllowHeaders::any()) - .allow_origin(tower_http::cors::Any); + let allowed_origins = + Arc::new(Mutex::new(vec![ + if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() { + AllowOrigin::from(tower_http::cors::Any) + } else { + HeaderValue::from_str("https://devtools.crabnebula.dev") + .unwrap() + .into() + }, + ])); let router = tonic::transport::Server::builder() .accept_http1(true) - .layer(cors) + .layer(DynamicCorsLayer { + allowed_origins: allowed_origins.clone(), + }) .add_service(tonic_web::enable(health_service)) .add_service(tonic_web::enable(InstrumentServer::new( InstrumentService { @@ -71,7 +155,15 @@ impl Server { .add_service(tonic_web::enable(MetadataServer::new(metadata_server))) .add_service(tonic_web::enable(SourcesServer::new(sources_server))); - Self(router) + Self { + router, + handle: ServerHandle { allowed_origins }, + } + } + + #[must_use] + pub fn handle(&self) -> ServerHandle { + self.handle.clone() } /// Consumes this [`Server`] and returns a future that will execute the server. @@ -82,7 +174,7 @@ impl Server { pub async fn run(self, addr: SocketAddr) -> crate::Result<()> { tracing::info!("Listening on {}", addr); - self.0.serve(addr).await?; + self.router.serve(addr).await?; Ok(()) } diff --git a/crates/devtools/src/lib.rs b/crates/devtools/src/lib.rs index 0a7b84b4..c3045cf8 100644 --- a/crates/devtools/src/lib.rs +++ b/crates/devtools/src/lib.rs @@ -3,7 +3,7 @@ mod server; use devtools_core::aggregator::Aggregator; use devtools_core::layer::Layer; use devtools_core::server::wire::tauri::tauri_server::TauriServer; -use devtools_core::server::Server; +use devtools_core::server::{Server, ServerHandle}; use devtools_core::Command; pub use devtools_core::Error; use devtools_core::{Result, Shared}; @@ -52,6 +52,7 @@ mod ios { pub struct Devtools { pub connection: ConnectionInfo, + pub server_handle: ServerHandle, } fn init_plugin( @@ -64,10 +65,6 @@ fn init_plugin( .setup(move |app_handle, _api| { let (mut health_reporter, health_service) = tonic_health::server::health_reporter(); - app_handle.manage(Devtools { - connection: connection_info(&addr), - }); - health_reporter .set_serving::>>() .now_or_never() @@ -87,6 +84,12 @@ fn init_plugin( app_handle: app_handle.clone(), }, ); + let server_handle = server.handle(); + + app_handle.manage(Devtools { + connection: connection_info(&addr), + server_handle, + }); #[cfg(not(target_os = "ios"))] print_link(&addr);