diff --git a/examples/src/tls_client_auth/server.rs b/examples/src/tls_client_auth/server.rs index 01e61f255..e7e25baf8 100644 --- a/examples/src/tls_client_auth/server.rs +++ b/examples/src/tls_client_auth/server.rs @@ -17,6 +17,10 @@ pub struct EchoServer; #[tonic::async_trait] impl pb::echo_server::Echo for EchoServer { async fn unary_echo(&self, request: Request) -> EchoResult { + if let Some(certs) = request.peer_certs() { + println!("Got {} peer certs!", certs.len()); + } + let message = request.into_inner().message; Ok(Response::new(EchoResponse { message })) } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index a4c2bde0d..b0b774d73 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -33,8 +33,8 @@ transport = [ "tower-load", "tracing-futures", ] -tls = ["tokio-rustls"] -tls-roots = ["rustls-native-certs"] +tls = ["transport", "tokio-rustls"] +tls-roots = ["tls", "rustls-native-certs"] # [[bench]] # name = "bench_main" diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 489c2a228..f1658b83b 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,7 +1,11 @@ use crate::metadata::MetadataMap; +#[cfg(feature = "transport")] +use crate::transport::Certificate; use futures_core::Stream; use http::Extensions; use std::net::SocketAddr; +#[cfg(feature = "transport")] +use std::sync::Arc; /// A gRPC request and metadata from an RPC call. #[derive(Debug)] @@ -14,6 +18,8 @@ pub struct Request { #[derive(Clone)] pub(crate) struct ConnectionInfo { pub(crate) remote_addr: Option, + #[cfg(feature = "transport")] + pub(crate) peer_certs: Option>>, } /// Trait implemented by RPC request types. @@ -188,6 +194,18 @@ impl Request { self.get::()?.remote_addr } + /// Get the peer certificates of the connected client. + /// + /// This is used to fetch the certificates from the TLS session + /// and is mostly used for mTLS. This currently only returns + /// `Some` on the server side of the `transport` server with + /// TLS enabled connections. + #[cfg(feature = "transport")] + #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] + pub fn peer_certs(&self) -> Option>> { + self.get::()?.peer_certs.clone() + } + pub(crate) fn get(&self) -> Option<&I> { self.extensions.get::() } diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index f50df1075..05f2ec0ca 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,7 +1,8 @@ +use crate::transport::Certificate; use hyper::server::conn::AddrStream; use std::net::SocketAddr; #[cfg(feature = "tls")] -use tokio_rustls::TlsStream; +use tokio_rustls::{rustls::Session, server::TlsStream}; /// Trait that connected IO resources implement. /// @@ -13,6 +14,11 @@ pub trait Connected { fn remote_addr(&self) -> Option { None } + + /// Return the set of connected peer TLS certificates. + fn peer_certs(&self) -> Option> { + None + } } impl Connected for AddrStream { @@ -27,4 +33,18 @@ impl Connected for TlsStream { let (inner, _) = self.get_ref(); inner.remote_addr() } + + fn peer_certs(&self) -> Option> { + let (_, session) = self.get_ref(); + + if let Some(certs) = session.get_peer_certificates() { + let certs = certs + .into_iter() + .map(|c| Certificate::from_pem(c.0)) + .collect(); + Some(certs) + } else { + None + } + } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 9158545cc..e03e03774 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -505,6 +505,7 @@ where fn call(&mut self, io: &ServerIo) -> Self::Future { let conn_info = crate::request::ConnectionInfo { remote_addr: io.remote_addr(), + peer_certs: io.peer_certs().map(Arc::new), }; let interceptor = self.interceptor.clone(); diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index a72bc1d4f..98961507d 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,4 +1,4 @@ -use crate::transport::server::Connected; +use crate::transport::{server::Connected, Certificate}; use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; use std::net::SocketAddr; @@ -71,8 +71,11 @@ impl ServerIo { impl Connected for ServerIo { fn remote_addr(&self) -> Option { - let io = &*self.0; - io.remote_addr() + (&*self.0).remote_addr() + } + + fn peer_certs(&self) -> Option> { + (&self.0).peer_certs() } } diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index f6b79868b..e6cd2d8b4 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -157,17 +157,17 @@ impl TlsAcceptor { }) } - pub(crate) async fn accept(&self, io: IO) -> Result + pub(crate) async fn accept( + &self, + io: IO, + ) -> Result, crate::Error> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, { - let io = { - let acceptor = RustlsAcceptor::from(self.inner.clone()); - let tls = acceptor.accept(io).await?; - BoxedIo::new(tls) - }; + let acceptor = RustlsAcceptor::from(self.inner.clone()); + let tls = acceptor.accept(io).await?; - Ok(io) + Ok(tls) } }