Skip to content

Commit

Permalink
Fix dead lock in SessionPool
Browse files Browse the repository at this point in the history
The problem was that in dashmap we were occasionally getting unlucky and
having sessions sometimes end up in the same shard, which would cause a
deadlock since one would be trying to insert into a shard while we still
hold a reference to one.

Fortunately this is a pretty simple fix of not returning a reference at
all and instead just returning the socket, since that was the only thing
that was used out of the session struct.
  • Loading branch information
XAMPPRocky committed Oct 18, 2023
1 parent cc2244b commit 2a4b520
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 67 deletions.
123 changes: 69 additions & 54 deletions src/proxy/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,10 @@ use crate::{
utils::{net::DualStackLocalSocket, Loggable},
};

use dashmap::DashMap;

pub(crate) mod metrics;

pub type SessionMap = crate::ttl_map::TtlMap<SessionKey, Session>;

type SessionRef<'pool> =
dashmap::mapref::one::Ref<'pool, SessionKey, crate::ttl_map::Value<Session>>;

/// A data structure that is responsible for holding sessions, and pooling
/// sockets between them. This means that we only provide new unique sockets
/// to new connections to the same gameserver, and we share sockets across
Expand All @@ -51,7 +46,7 @@ type SessionRef<'pool> =
/// send back to the original client.
#[derive(Debug)]
pub struct SessionPool {
ports_to_sockets: DashMap<u16, Arc<DualStackLocalSocket>>,
ports_to_sockets: RwLock<HashMap<u16, Arc<DualStackLocalSocket>>>,
storage: Arc<RwLock<SocketStorage>>,
session_map: SessionMap,
downstream_socket: Arc<DualStackLocalSocket>,
Expand Down Expand Up @@ -95,10 +90,14 @@ impl SessionPool {
self: &'pool Arc<Self>,
key: SessionKey,
asn_info: Option<IpNetEntry>,
) -> Result<SessionRef<'pool>, super::PipelineError> {
) -> Result<Arc<DualStackLocalSocket>, super::PipelineError> {
tracing::trace!(source=%key.source, dest=%key.dest, "creating new socket for session");
let socket = DualStackLocalSocket::new(0).map(Arc::new)?;
let port = socket.local_ipv4_addr().unwrap().port();
self.ports_to_sockets.insert(port, socket.clone());
self.ports_to_sockets
.write()
.await
.insert(port, socket.clone());

let upstream_socket = socket.clone();
let pool = self.clone();
Expand Down Expand Up @@ -181,51 +180,59 @@ impl SessionPool {
self: &'pool Arc<Self>,
key @ SessionKey { dest, .. }: SessionKey,
asn_info: Option<IpNetEntry>,
) -> Result<SessionRef<'pool>, super::PipelineError> {
) -> Result<Arc<DualStackLocalSocket>, super::PipelineError> {
tracing::trace!(source=%key.source, dest=%key.dest, "SessionPool::get");
// If we already have a session for the key pairing, return that session.
if let Some(entry) = self.session_map.get(&key) {
return Ok(entry);
tracing::trace!("returning existing session");
return Ok(entry.socket.clone());
}

// If there's a socket_set available, it means there are sockets
// allocated to the address that we want to avoid.
let storage = self.storage.read().await;
let Some(socket_set) = storage.destination_to_sockets.get(&dest) else {
drop(storage);
return if self.ports_to_sockets.is_empty() {
let no_sockets = self.ports_to_sockets.read().await.is_empty();
return if no_sockets {
// Initial case where we have no allocated or reserved sockets.
self.create_new_session_from_new_socket(key, asn_info).await
} else {
// Where we have no allocated sockets for a destination, assign
// the first available one.
let entry = self.ports_to_sockets.iter().next().unwrap();
let port = *entry.key();
let (port, socket) = self
.ports_to_sockets
.read()
.await
.iter()
.next()
.map(|(port, socket)| (*port, socket.clone()))
.unwrap();

self.create_session_from_existing_socket(key, entry.value().clone(), port, asn_info)
self.create_session_from_existing_socket(key, socket, port, asn_info)
.await
};
};

if let Some(entry) = self
let available_socket = self
.ports_to_sockets
.read()
.await
.iter()
.find(|entry| !socket_set.contains(entry.key()))
{
.find(|(port, _)| !socket_set.contains(port))
.map(|(port, socket)| (*port, socket.clone()));

if let Some((port, socket)) = available_socket {
drop(storage);
self.storage
.write()
.await
.destination_to_sockets
.get_mut(&dest)
.unwrap()
.insert(*entry.key());
self.create_session_from_existing_socket(
key,
entry.value().clone(),
*entry.key(),
asn_info,
)
.await
.insert(port);
self.create_session_from_existing_socket(key, socket, port, asn_info)
.await
} else {
drop(storage);
self.create_new_session_from_new_socket(key, asn_info).await
Expand All @@ -239,7 +246,8 @@ impl SessionPool {
upstream_socket: Arc<DualStackLocalSocket>,
socket_port: u16,
asn_info: Option<IpNetEntry>,
) -> Result<SessionRef<'session>, super::PipelineError> {
) -> Result<Arc<DualStackLocalSocket>, super::PipelineError> {
tracing::trace!(source=%key.source, dest=%key.dest, "reusing socket for session");
let mut storage = self.storage.write().await;
storage
.destination_to_sockets
Expand All @@ -262,14 +270,17 @@ impl SessionPool {
}

drop(storage);
let session = Session::new(key, upstream_socket, socket_port, self.clone(), asn_info)?;
Ok(match self.session_map.entry(key) {
crate::ttl_map::Entry::Occupied(mut entry) => {
entry.insert(session);
entry.into_ref()
}
crate::ttl_map::Entry::Vacant(v) => v.insert(session).downgrade(),
})
let session = Session::new(
key,
upstream_socket.clone(),
socket_port,
self.clone(),
asn_info,
)?;
tracing::trace!("inserting session into map");
self.session_map.insert(key, session);
tracing::trace!("session inserted");
Ok(upstream_socket)
}

/// process_recv_packet processes a packet that is received by this session.
Expand Down Expand Up @@ -309,7 +320,7 @@ impl SessionPool {
) -> Result<usize, super::PipelineError> {
self.get(key, asn_info)
.await?
.send(packet)
.send_to(packet, key.dest)
.await
.map_err(From::from)
}
Expand All @@ -326,8 +337,7 @@ impl SessionPool {

/// Forces removal of session to make testing quicker.
#[cfg(test)]
async fn drop_session(&self, key: SessionKey, session: SessionRef<'_>) -> bool {
drop(session);
async fn drop_session(&self, key: SessionKey) -> bool {
let is_removed = self.session_map.remove(key);
// Sleep because there's no async drop.
tokio::time::sleep(Duration::from_millis(100)).await;
Expand All @@ -343,6 +353,7 @@ impl SessionPool {
}: SessionKey,
port: u16,
) {
tracing::trace!("releasing socket");
let mut storage = self.storage.write().await;
let socket_set = storage.destination_to_sockets.get_mut(dest).unwrap();

Expand All @@ -366,6 +377,7 @@ impl SessionPool {
.destination_to_sources
.remove(&(*dest, port))
.is_some());
tracing::trace!("socket released");
}
}

Expand Down Expand Up @@ -421,11 +433,6 @@ impl Session {
Ok(s)
}

pub async fn send(&self, packet: &[u8]) -> std::io::Result<usize> {
tracing::trace!(dest=%self.key.dest, "sending packet upstream");
self.socket.send_to(packet, self.key.dest).await
}

fn active_session_metric(&self) -> prometheus::IntGauge {
metrics::active_sessions(self.asn_info.as_ref())
}
Expand Down Expand Up @@ -512,9 +519,9 @@ mod tests {
)
.into();

let session = pool.get(key, None).await.unwrap();
let _session = pool.get(key, None).await.unwrap();

assert!(pool.drop_session(key, session).await);
assert!(pool.drop_session(key).await);

assert!(pool.has_no_allocated_sockets().await);
}
Expand All @@ -533,12 +540,12 @@ mod tests {
)
.into();

let session1 = pool.get(key1, None).await.unwrap();
let session2 = pool.get(key2, None).await.unwrap();
let _session1 = pool.get(key1, None).await.unwrap();
let _session2 = pool.get(key2, None).await.unwrap();

assert!(pool.drop_session(key1, session1).await);
assert!(pool.drop_session(key1).await);
assert!(!pool.has_no_allocated_sockets().await);
assert!(pool.drop_session(key2, session2).await);
assert!(pool.drop_session(key2).await);

assert!(pool.has_no_allocated_sockets().await);
drop(pool);
Expand All @@ -558,10 +565,15 @@ mod tests {
)
.into();

let socket1 = pool.get(key1, None).await.unwrap();
let socket2 = pool.get(key2, None).await.unwrap();
let _socket1 = pool.get(key1, None).await.unwrap();
let _socket2 = pool.get(key2, None).await.unwrap();
assert_ne!(
pool.session_map.get(&key1).unwrap().socket_port,
pool.session_map.get(&key2).unwrap().socket_port
);

assert_ne!(socket1.socket_port, socket2.socket_port);
assert!(pool.drop_session(key1).await);
assert!(pool.drop_session(key2).await);
}

#[tokio::test]
Expand All @@ -578,10 +590,13 @@ mod tests {
)
.into();

let socket1 = pool.get(key1, None).await.unwrap();
let socket2 = pool.get(key2, None).await.unwrap();
let _socket1 = pool.get(key1, None).await.unwrap();
let _socket2 = pool.get(key2, None).await.unwrap();

assert_eq!(socket1.socket_port, socket2.socket_port);
assert_eq!(
pool.session_map.get(&key1).unwrap().socket_port,
pool.session_map.get(&key2).unwrap().socket_port
);
}

#[tokio::test]
Expand Down
13 changes: 0 additions & 13 deletions src/ttl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,19 +308,6 @@ impl<'a, K, V> OccupiedEntry<'a, K, Value<V>>
where
K: Eq + Hash,
{
/// Returns a reference to the entry's value.
/// The value will be reset to expire at the configured TTL after the time of retrieval.
pub fn into_ref(self) -> Ref<'a, K, Value<V>> {
match self.inner {
DashMapEntry::Occupied(entry) => {
let value = entry.into_ref();
value.value().update_expiration(self.ttl);
value.downgrade()
}
_ => unreachable!("BUG: entry type should be occupied"),
}
}

/// Returns a reference to the entry's value.
/// The value will be reset to expire at the configured TTL after the time of retrieval.
pub fn get(&self) -> &Value<V> {
Expand Down

0 comments on commit 2a4b520

Please sign in to comment.