-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathquic_async_client_hello_callback_server.rs
163 lines (139 loc) · 5.49 KB
/
quic_async_client_hello_callback_server.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use moka::sync::Cache;
use rand::{distributions::WeightedIndex, prelude::*};
use s2n_quic::{
provider::tls::s2n_tls::{
callbacks::{ClientHelloCallback, ConfigResolver, ConnectionFuture},
config::Config,
connection::Connection,
error::Error as S2nError,
},
Server,
};
use std::{error::Error, fmt::Display, pin::Pin, sync::Arc, time::Duration};
use tokio::{fs, sync::OnceCell};
/// Absolute path to the PEM-encoded demo certificate from `s2n-quic-core`.
///
/// NOTE: this certificate is to be used for demonstration purposes only!
pub static CERT_PEM_PATH: &str =
    concat!(env!("CARGO_MANIFEST_DIR"), "/../../quic/s2n-quic-core/certs/cert.pem");

/// Absolute path to the PEM-encoded private key matching [`CERT_PEM_PATH`].
///
/// NOTE: this key is to be used for demonstration purposes only!
pub static KEY_PEM_PATH: &str =
    concat!(env!("CARGO_MANIFEST_DIR"), "/../../quic/s2n-quic-core/certs/key.pem");

/// The server name (SNI) sent by a client in its ClientHello; used as the cache key.
type Sni = String;
/// A cache of s2n-tls [`Config`]s keyed by SNI (server name indication).
///
/// The SNI tells the server which hostname the client is attempting to
/// connect to. Some deployments need a different [`Config`] (for example, a
/// different certificate) for each hostname.
///
/// `ConfigCache` implements [`ClientHelloCallback`]. When a ClientHello
/// arrives for an SNI whose [`Config`] is not cached yet, the certificate is
/// loaded asynchronously. The resulting [`Config`] is then reused for later
/// handshakes with the same SNI.
///
/// Each entry is an `Arc<OnceCell<Config>>` so the same cell can be shared
/// with the asynchronous future that initializes it.
struct ConfigCache {
    cache: Cache<Sni, Arc<OnceCell<Config>>>,
}

impl ConfigCache {
    /// Maximum number of unique SNIs whose [`Config`] entries the cache holds.
    const MAX_CACHED_SNIS: u64 = 100;

    /// Creates an empty cache sized for up to [`Self::MAX_CACHED_SNIS`]
    /// unique SNIs.
    fn new() -> Self {
        Self {
            cache: Cache::new(Self::MAX_CACHED_SNIS),
        }
    }
}
/// Resolves the s2n-tls [`Config`] for a connection from its SNI.
///
/// Each SNI maps to a shared `Arc<OnceCell<Config>>` in the cache.
/// - An already-initialized cell is applied to the connection synchronously.
/// - An empty cell is filled by an asynchronous future handed back to s2n-tls.
impl ClientHelloCallback for ConfigCache {
/// Invoked for an incoming ClientHello (per the `ClientHelloCallback` trait).
///
/// Returns `Ok(None)` when a cached `Config` was applied to `connection`
/// directly. Returns `Ok(Some(future))` when the `Config` still has to be
/// loaded. That future is wrapped in a `ConfigResolver`, which presumably
/// installs the resolved `Config` on the connection once it completes;
/// confirm against the s2n-tls docs.
///
/// # Errors
///
/// Returns an application error wrapping `CustomError` if the ClientHello
/// carries no SNI. Also propagates any error from `Connection::set_config`.
fn on_client_hello(
&self,
connection: &mut Connection,
) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, S2nError> {
// The SNI is required: it is the key used to look up the Config.
let sni = connection
.server_name()
.ok_or_else(|| S2nError::application(Box::new(CustomError)))?
.to_string();
// Fetch this SNI's cell, inserting an empty one the first time the SNI is
// seen. The Arc lets the `async move` future below own a handle to it.
let once_cell_config = self
.cache
.get_with(sni.clone(), || Arc::new(OnceCell::new()));
// Fast path: the Config for this SNI has already been built.
if let Some(config) = once_cell_config.get() {
eprintln!("Config already cached for SNI: {}", sni);
connection.set_config(config.clone())?;
// return `None` if the Config is already in the cache
return Ok(None);
}
// simulate failure 75% of times and success 25% of the times
// (`true` = "network call failed" has weight 3, `false` = success has weight 1)
let choices = [true, false];
let weights = [3, 1];
// Cannot fail: the weights are constant, non-negative and not all zero.
let dist = WeightedIndex::new(weights).unwrap();
let fut = async move {
// Per tokio's `OnceCell::get_or_try_init`, a task that finds the cell
// being initialized by another task waits for that result instead of
// running the initializer again. If the initializer returns `Err`, the
// cell stays empty, so a later ClientHello for this SNI retries.
let fut = once_cell_config.get_or_try_init(|| async {
// The thread-local RNG is only borrowed for this statement, so it is
// never held across an `.await`.
let simulated_network_call_failed = choices[dist.sample(&mut thread_rng())];
if simulated_network_call_failed {
eprintln!("simulated network call failed");
return Err(S2nError::application(Box::new(CustomError)));
}
eprintln!("resolving certificate for SNI: {}", sni);
// load the cert and key file asynchronously.
let (cert, key) = {
// the SNI can be used to load the appropriate cert file
// (this demo ignores it and always reads the same two files)
let _sni = sni;
let cert = fs::read_to_string(CERT_PEM_PATH)
.await
.map_err(|_| S2nError::application(Box::new(CustomError)))?;
let key = fs::read_to_string(KEY_PEM_PATH)
.await
.map_err(|_| S2nError::application(Box::new(CustomError)))?;
(cert, key)
};
// sleep(async tokio task which doesn't block thread) to mimic delay
tokio::time::sleep(Duration::from_secs(3)).await;
// Build an s2n-quic TLS server from the PEM cert/key and convert it
// into the `Config` type stored in the cell.
let config = s2n_quic::provider::tls::s2n_tls::Server::builder()
.with_certificate(cert, key)?
.build()?
.into();
Ok(config)
});
// `get_or_try_init` yields `&Config`; clone it so the future resolves to
// an owned `Config`.
fut.await.cloned()
};
// return `Some(ConnectionFuture)` if the Config wasn't found in the
// cache and we need to load it asynchronously
Ok(Some(Box::pin(ConfigResolver::new(fut))))
}
}
/// Runs a QUIC echo server on `127.0.0.1:4433`.
///
/// The TLS configuration is resolved per SNI by [`ConfigCache`]. Each
/// accepted connection is handled on its own task. Each bidirectional stream
/// on that connection gets its own task that echoes received data back to
/// the peer.
///
/// # Errors
///
/// Returns an error if the TLS provider or the QUIC server cannot be built
/// or started.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let tls = s2n_quic::provider::tls::s2n_tls::Server::builder()
        .with_client_hello_handler(ConfigCache::new())?
        .build()?;

    let mut server = Server::builder()
        .with_tls(tls)?
        .with_io("127.0.0.1:4433")?
        .start()?;

    while let Some(mut connection) = server.accept().await {
        // spawn a new task for the connection
        tokio::spawn(async move {
            eprintln!("Connection accepted from {:?}", connection.remote_addr());

            while let Ok(Some(mut stream)) = connection.accept_bidirectional_stream().await {
                // spawn a new task for the stream
                tokio::spawn(async move {
                    eprintln!("Stream opened from {:?}", stream.connection().remote_addr());

                    // echo any data back to the stream
                    while let Ok(Some(data)) = stream.receive().await {
                        // The peer may reset or close the stream at any time. Treat a
                        // failed send as the end of this stream rather than panicking
                        // the task.
                        if let Err(err) = stream.send(data).await {
                            eprintln!("failed to echo data back to the stream: {:?}", err);
                            break;
                        }
                    }
                });
            }
        });
    }

    Ok(())
}
/// Catch-all error used by this example whenever a step fails: a missing
/// SNI, a simulated network failure, or an unreadable PEM file.
#[derive(Debug)]
struct CustomError;

impl Display for CustomError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("custom error")
    }
}

/// Lets `CustomError` be boxed and wrapped by `S2nError::application`.
impl Error for CustomError {}