Skip to content

Commit

Permalink
Move MaxmindDb::lookup to session creation (#968)
Browse files Browse the repository at this point in the history
  • Loading branch information
XAMPPRocky authored May 29, 2024
1 parent bbd0092 commit 037756a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 50 deletions.
25 changes: 6 additions & 19 deletions src/components/proxy/packet_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use tokio::sync::mpsc;
/// Packet received from local port
#[derive(Debug)]
struct DownstreamPacket {
asn_info: Option<crate::net::maxmind_db::IpNetEntry>,
contents: PoolBuffer,
received_at: UtcTimestamp,
source: SocketAddr,
Expand Down Expand Up @@ -130,17 +129,13 @@ impl DownstreamReceiveWorkerConfig {
source.set_ip(source.ip().to_canonical());
let packet = DownstreamPacket {
received_at: UtcTimestamp::now(),
asn_info: crate::net::maxmind_db::MaxmindDb::lookup(source.ip()),
contents,
source,
};

if let Some(last_received_at) = last_received_at {
crate::metrics::packet_jitter(
crate::metrics::READ,
packet.asn_info.as_ref(),
)
.set((packet.received_at - last_received_at).nanos());
crate::metrics::packet_jitter(crate::metrics::READ, None)
.set((packet.received_at - last_received_at).nanos());
}
last_received_at = Some(packet.received_at);

Expand Down Expand Up @@ -184,19 +179,13 @@ impl DownstreamReceiveWorkerConfig {
);

let timer = crate::metrics::processing_time(crate::metrics::READ).start_timer();
let asn_info = packet.asn_info.clone();
let asn_info = asn_info.as_ref();
match Self::process_downstream_received_packet(packet, config, sessions).await {
Ok(()) => {}
Err(error) => {
let discriminant = PipelineErrorDiscriminants::from(&error).to_string();
crate::metrics::errors_total(crate::metrics::READ, &discriminant, asn_info).inc();
crate::metrics::packets_dropped_total(
crate::metrics::READ,
&discriminant,
asn_info,
)
.inc();
crate::metrics::errors_total(crate::metrics::READ, &discriminant, None).inc();
crate::metrics::packets_dropped_total(crate::metrics::READ, &discriminant, None)
.inc();
let _ = error_sender.send(error);
}
}
Expand Down Expand Up @@ -241,9 +230,7 @@ impl DownstreamReceiveWorkerConfig {
dest: epa.to_socket_addr().await?,
};

sessions
.send(session_key, packet.asn_info.clone(), contents.clone())
.await?;
sessions.send(session_key, contents.clone()).await?;
}

Ok(())
Expand Down
59 changes: 28 additions & 31 deletions src/components/proxy/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ impl SessionPool {
async fn create_new_session_from_new_socket<'pool>(
self: &'pool Arc<Self>,
key: SessionKey,
asn_info: Option<IpNetEntry>,
) -> Result<UpstreamSender, super::PipelineError> {
) -> Result<(Option<IpNetEntry>, UpstreamSender), super::PipelineError> {
tracing::trace!(source=%key.source, dest=%key.dest, "creating new socket for session");
let raw_socket = crate::net::raw_socket_with_reuse(0)?;
let port = raw_socket
Expand Down Expand Up @@ -209,7 +208,7 @@ impl SessionPool {
initialised.await.map_err(|error| eyre::eyre!(error))??;

self.ports_to_sockets.write().await.insert(port, tx.clone());
self.create_session_from_existing_socket(key, tx, port, asn_info)
self.create_session_from_existing_socket(key, tx, port)
.await
}

Expand Down Expand Up @@ -268,13 +267,12 @@ impl SessionPool {
pub async fn get<'pool>(
self: &'pool Arc<Self>,
key @ SessionKey { dest, .. }: SessionKey,
asn_info: Option<IpNetEntry>,
) -> Result<UpstreamSender, super::PipelineError> {
) -> Result<(Option<IpNetEntry>, UpstreamSender), super::PipelineError> {
tracing::trace!(source=%key.source, dest=%key.dest, "SessionPool::get");
// If we already have a session for the key pairing, return that session.
if let Some(entry) = self.session_map.get(&key) {
tracing::trace!("returning existing session");
return Ok(entry.upstream_sender.clone());
return Ok((entry.asn_info.clone(), entry.upstream_sender.clone()));
}

// If there's a socket_set available, it means there are sockets
Expand All @@ -285,7 +283,7 @@ impl SessionPool {
let no_sockets = self.ports_to_sockets.read().await.is_empty();
return if no_sockets {
// Initial case where we have no allocated or reserved sockets.
self.create_new_session_from_new_socket(key, asn_info).await
self.create_new_session_from_new_socket(key).await
} else {
// Where we have no allocated sockets for a destination, assign
// the first available one.
Expand All @@ -301,7 +299,7 @@ impl SessionPool {
})
.map_err(super::PipelineError::Session)?;

self.create_session_from_existing_socket(key, sender, port, asn_info)
self.create_session_from_existing_socket(key, sender, port)
.await
};
};
Expand All @@ -326,11 +324,11 @@ impl SessionPool {
})
.map_err(super::PipelineError::Session)?
.insert(port);
self.create_session_from_existing_socket(key, socket, port, asn_info)
self.create_session_from_existing_socket(key, socket, port)
.await
} else {
drop(storage);
self.create_new_session_from_new_socket(key, asn_info).await
self.create_new_session_from_new_socket(key).await
}
}

Expand All @@ -340,8 +338,7 @@ impl SessionPool {
key: SessionKey,
upstream_sender: UpstreamSender,
socket_port: u16,
asn_info: Option<IpNetEntry>,
) -> Result<UpstreamSender, super::PipelineError> {
) -> Result<(Option<IpNetEntry>, UpstreamSender), super::PipelineError> {
tracing::trace!(source=%key.source, dest=%key.dest, "reusing socket for session");
let mut storage = self.storage.write().await;
storage
Expand All @@ -358,6 +355,8 @@ impl SessionPool {
.destination_to_sources
.insert((key.dest, socket_port), key.source);

let asn_info = crate::net::maxmind_db::MaxmindDb::lookup(key.source.ip());

if let Some(asn_info) = &asn_info {
storage
.sources_to_asn_info
Expand All @@ -370,12 +369,12 @@ impl SessionPool {
upstream_sender.clone(),
socket_port,
self.clone(),
asn_info,
asn_info.clone(),
)?;
tracing::trace!("inserting session into map");
self.session_map.insert(key, session);
tracing::trace!("session inserted");
Ok(upstream_sender)
Ok((asn_info, upstream_sender))
}

/// process_recv_packet processes a packet that is received by this session.
Expand Down Expand Up @@ -413,13 +412,13 @@ impl SessionPool {
pub async fn send(
self: &Arc<Self>,
key: SessionKey,
asn_info: Option<IpNetEntry>,
packet: FrozenPoolBuffer,
) -> Result<(), super::PipelineError> {
use tokio::sync::mpsc::error::TrySendError;

self.get(key, asn_info.clone())
.await?
let (asn_info, sender) = self.get(key).await?;

sender
.try_send((packet, asn_info, key.dest))
.map_err(|error| match error {
TrySendError::Closed(_) => super::PipelineError::ChannelClosed,
Expand Down Expand Up @@ -622,7 +621,7 @@ mod tests {
)
.into();

let _session = pool.get(key, None).await.unwrap();
let _session = pool.get(key).await.unwrap();

assert!(pool.drop_session(key).await);

Expand All @@ -643,8 +642,8 @@ mod tests {
)
.into();

let _session1 = pool.get(key1, None).await.unwrap();
let _session2 = pool.get(key2, None).await.unwrap();
let _session1 = pool.get(key1).await.unwrap();
let _session2 = pool.get(key2).await.unwrap();

assert!(pool.drop_session(key1).await);
assert!(!pool.has_no_allocated_sockets().await);
Expand All @@ -668,8 +667,8 @@ mod tests {
)
.into();

let _socket1 = pool.get(key1, None).await.unwrap();
let _socket2 = pool.get(key2, None).await.unwrap();
let _socket1 = pool.get(key1).await.unwrap();
let _socket2 = pool.get(key2).await.unwrap();
assert_ne!(
pool.session_map.get(&key1).unwrap().socket_port,
pool.session_map.get(&key2).unwrap().socket_port
Expand All @@ -693,8 +692,8 @@ mod tests {
)
.into();

let _socket1 = pool.get(key1, None).await.unwrap();
let _socket2 = pool.get(key2, None).await.unwrap();
let _socket1 = pool.get(key1).await.unwrap();
let _socket2 = pool.get(key2).await.unwrap();

assert_eq!(
pool.session_map.get(&key1).unwrap().socket_port,
Expand All @@ -716,13 +715,13 @@ mod tests {
)
.into();

let socket1 = pool.get(key1, None).await.unwrap();
let socket1 = pool.get(key1).await.unwrap();

let task = tokio::spawn(async move {
let _ = socket1;
});

let _socket2 = pool.get(key2, None).await.unwrap();
let _socket2 = pool.get(key2).await.unwrap();

task.await.unwrap();
}
Expand All @@ -741,13 +740,13 @@ mod tests {
)
.into();

let socket1 = pool.get(key1, None).await.unwrap();
let socket1 = pool.get(key1).await.unwrap();

let task = tokio::spawn(async move {
let _ = socket1;
});

let _socket2 = pool.get(key2, None).await.unwrap();
let _socket2 = pool.get(key2).await.unwrap();

task.await.unwrap();
}
Expand All @@ -767,9 +766,7 @@ mod tests {
let key: SessionKey = (source, dest).into();
let msg = b"helloworld";

pool.send(key, None, alloc_buffer(msg).freeze())
.await
.unwrap();
pool.send(key, alloc_buffer(msg).freeze()).await.unwrap();

let (data, _, _) = tokio::time::timeout(std::time::Duration::from_secs(1), receiver.recv())
.await
Expand Down

0 comments on commit 037756a

Please sign in to comment.