Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support customized Cid generation #851

Merged
merged 19 commits into from
Oct 3, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions quinn-proto/src/cid_generator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use rand::RngCore;

use crate::shared::ConnectionId;
use crate::MAX_CID_SIZE;

/// Generates connection IDs for incoming connections
pub trait ConnectionIdGenerator: Send {
/// Generates a new CID
///
/// Connection IDs MUST NOT contain any information that can be used by
/// an external observer (that is, one that does not cooperate with the
/// issuer) to correlate them with other connection IDs for the same
/// connection.
fn generate_cid(&mut self) -> ConnectionId;
/// Performs any validation if it is needed (e.g. HMAC, etc)
///
/// Apply validation check on those CIDs that may still exist in hash table
/// but considered invalid by application-layer logic.
/// e.g. we may want to limit the amount of time for which a CID is valid
/// in order to reduce the number of valid IDs that could be accumulated
/// by an attacker.
fn validate_cid(&mut self, _cid: &ConnectionId) -> bool {
true
}
/// Returns the length of a CID for cononections created by this generator
fn cid_len(&self) -> usize;
}

/// Generates purely random connection IDs of a certain length
#[derive(Debug, Clone, Copy)]
pub struct RandomConnectionIdGenerator {
cid_len: usize,
}
impl Default for RandomConnectionIdGenerator {
fn default() -> Self {
Self { cid_len: 8 }
}
}
impl RandomConnectionIdGenerator {
/// Initialize Random CID generator with a fixed CID length (which must be less or equal to MAX_CID_SIZE)
pub fn new(cid_len: usize) -> Self {
debug_assert!(cid_len <= MAX_CID_SIZE);
Self { cid_len }
}
}
impl ConnectionIdGenerator for RandomConnectionIdGenerator {
fn generate_cid(&mut self) -> ConnectionId {
let mut bytes_arr = [0; MAX_CID_SIZE];
rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]);

ConnectionId::new(&bytes_arr[..self.cid_len])
}

/// Provide the length of dst_cid in short header packet
fn cid_len(&self) -> usize {
self.cid_len
}
}
44 changes: 28 additions & 16 deletions quinn-proto/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use rand::RngCore;
#[cfg(feature = "rustls")]
use crate::crypto::types::{Certificate, CertificateChain, PrivateKey};
use crate::{
cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
congestion,
crypto::{self, ClientConfig as _, HmacKey as _, ServerConfig as _},
VarInt, MAX_CID_SIZE,
VarInt,
};

/// Parameters governing the core QUIC state machine
Expand Down Expand Up @@ -297,9 +298,13 @@ pub struct EndpointConfig<S>
where
S: crypto::Session,
{
pub(crate) local_cid_len: usize,
pub(crate) reset_key: Arc<S::HmacKey>,
pub(crate) max_udp_payload_size: u64,
/// CID generator factory
///
/// Create a cid generator for local cid in Endpoint struct
pub(crate) connection_id_generator_factory:
Arc<dyn Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync>,
}

impl<S> EndpointConfig<S>
Expand All @@ -308,25 +313,32 @@ where
{
/// Create a default config with a particular `reset_key`
pub fn new(reset_key: S::HmacKey) -> Self {
let cid_factory: fn() -> Box<dyn ConnectionIdGenerator> =
|| Box::new(RandomConnectionIdGenerator::default());
Self {
local_cid_len: 8,
reset_key: Arc::new(reset_key),
max_udp_payload_size: MAX_UDP_PAYLOAD_SIZE,
connection_id_generator_factory: Arc::new(cid_factory),
}
}

/// Length of connection IDs for the endpoint.
/// Supply a custom connection ID generator factory
///
/// This must be no greater than 20. If zero, incoming packets are mapped to connections only by
/// their source address. Otherwise, the connection ID field is used alone, allowing for source
/// address to change and for multiple connections from a single address. When local_cid_len >
/// 0, at most 3/4 * 2^(local_cid_len * 8) simultaneous connections can be supported.
pub fn local_cid_len(&mut self, value: usize) -> Result<&mut Self, ConfigError> {
if value > MAX_CID_SIZE {
return Err(ConfigError::OutOfBounds);
}
self.local_cid_len = value;
Ok(self)
/// Called once by each `Endpoint` constructed from this configuration to obtain the CID generator which will
/// be used to generate the CIDs used for incoming packets on all connections involving that `Endpoint`. A
/// custom CID generator allows applications to embed information in local connection IDs, e.g. to support
/// stateless packet-level load balancers.
///
///
/// EndpointConfig::new() applies a default random CID generator factory.
/// This functions accepts any customized CID generator to reset CID generator factory that
/// implements ConnectionIdGenerator trait
pub fn cid_generator<F: Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync + 'static>(
&mut self,
factory: F,
) -> &mut Self {
self.connection_id_generator_factory = Arc::new(factory);
self
}

/// Private key used to send authenticated connection resets to peers who were
Expand Down Expand Up @@ -359,9 +371,9 @@ where
impl<S: crypto::Session> fmt::Debug for EndpointConfig<S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("EndpointConfig")
.field("local_cid_len", &self.local_cid_len)
.field("reset_key", &"[ elided ]")
.field("max_udp_payload_size", &self.max_udp_payload_size)
.field("cid_generator_factory", &"[ elided ]")
.finish()
}
}
Expand All @@ -380,9 +392,9 @@ impl<S: crypto::Session> Default for EndpointConfig<S> {
impl<S: crypto::Session> Clone for EndpointConfig<S> {
fn clone(&self) -> Self {
Self {
local_cid_len: self.local_cid_len,
reset_key: self.reset_key.clone(),
max_udp_payload_size: self.max_udp_payload_size,
connection_id_generator_factory: self.connection_id_generator_factory.clone(),
}
}
}
Expand Down
15 changes: 8 additions & 7 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tracing::{debug, error, trace, trace_span, warn};
use crate::{
cid_queue::CidQueue,
coding::BufMutExt,
config::{EndpointConfig, ServerConfig, TransportConfig},
config::{ServerConfig, TransportConfig},
crypto::{self, HeaderKey, KeyPair, Keys, PacketKey},
frame,
frame::{Close, Datagram, FrameStruct},
Expand Down Expand Up @@ -58,7 +58,6 @@ pub struct Connection<S>
where
S: crypto::Session,
{
endpoint_config: Arc<EndpointConfig<S>>,
server_config: Option<Arc<ServerConfig<S>>>,
config: Arc<TransportConfig>,
rng: StdRng,
Expand All @@ -76,6 +75,8 @@ where
/// Exactly one prior to `self.rem_cids.offset` except during processing of certain
/// NEW_CONNECTION_ID frames.
rem_cid_seq: u64,
/// cid length used to decode short packet
local_cid_len: usize,
path: PathData,
prev_path: Option<PathData>,
state: State,
Expand Down Expand Up @@ -163,7 +164,6 @@ where
S: crypto::Session,
{
pub(crate) fn new(
endpoint_config: Arc<EndpointConfig<S>>,
server_config: Option<Arc<ServerConfig<S>>>,
config: Arc<TransportConfig>,
init_cid: ConnectionId,
Expand All @@ -172,6 +172,7 @@ where
remote: SocketAddr,
crypto: S,
now: Instant,
local_cid_len: usize,
) -> Self {
let side = if server_config.is_some() {
Side::Server
Expand All @@ -192,13 +193,13 @@ where
.as_ref()
.map_or(false, |c| c.use_stateless_retry);
let mut this = Self {
endpoint_config,
server_config,
crypto,
handshake_cid: loc_cid,
rem_cid,
rem_handshake_cid: rem_cid,
rem_cid_seq: 0,
local_cid_len,
path: PathData::new(
remote,
config.initial_rtt,
Expand Down Expand Up @@ -1641,7 +1642,7 @@ where
self.total_recvd = self.total_recvd.wrapping_add(data.len() as u64);
let mut remaining = Some(data);
while let Some(data) = remaining {
match PartialDecode::new(data, self.endpoint_config.local_cid_len) {
match PartialDecode::new(data, self.local_cid_len) {
Ok((partial_decode, rest)) => {
remaining = rest;
self.handle_decode(now, remote, ecn, partial_decode);
Expand Down Expand Up @@ -2268,7 +2269,7 @@ where
self.streams.received_stop_sending(id, error_code);
}
Frame::RetireConnectionId { sequence } => {
if self.endpoint_config.local_cid_len == 0 {
if self.local_cid_len == 0 {
return Err(TransportError::PROTOCOL_VIOLATION(
"RETIRE_CONNECTION_ID when CIDs aren't in use",
));
Expand Down Expand Up @@ -2474,7 +2475,7 @@ where

/// Issue an initial set of connection IDs to the peer
fn issue_cids(&mut self) {
if self.endpoint_config.local_cid_len == 0 {
if self.local_cid_len == 0 {
return;
}

Expand Down
Loading