Fix deadlock in auth cache (solves Tobira freezing problem) (#1141)
Fixes #1129

See the first commit's description for the technical details of how this was
caused. The gist: I used `DashMap` incorrectly, holding locks across await
points. This causes a deadlock if the timing is right and two specific tasks
are scheduled to run on the same thread. I could have fixed the code around
the await point, but since this is a very easy and subtle mistake to make, I
decided to switch to a different concurrent hashmap that deals better with
these scenarios.
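
To illustrate the anti-pattern (a minimal sketch with made-up keys and
timings, not the actual Tobira code): on a single-threaded runtime, one task
holding a `DashMap` guard across an await point is enough to deadlock against
another task calling the blocking `retain`.

```rust
use std::{sync::Arc, time::Duration};
use dashmap::DashMap;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let map = Arc::new(DashMap::new());
    map.insert("foo", 1);

    {
        let map = Arc::clone(&map);
        tokio::spawn(async move {
            // BUG: `get` returns a guard holding a shard lock, and this
            // task keeps that guard alive across the `await` below.
            let guard = map.get("foo").unwrap();
            tokio::time::sleep(Duration::from_secs(1)).await;
            println!("{}", *guard);
        });
    }

    // Let the task above acquire its guard, ...
    tokio::time::sleep(Duration::from_millis(100)).await;

    // ... then `retain` blocks the only worker thread while waiting for the
    // shard lock. The sleeping task is never polled again, so it never drops
    // its guard: deadlock.
    map.retain(|_, v| *v > 0);
}
```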

The second commit also fixes how the cache deals with a cache duration of 0
(which is supposed to disable the cache). See the commits for more details.

On the ETH test system, I deployed v2.6 plus this patch and verified that the
freeze no longer happens in the only situation where I could reliably
reproduce it before: starting Tobira and immediately making an authenticated
request. Since `cache_duration` was set to 0 there, the timing worked out most
of the time. With this patch, doing that no longer freezes Tobira (I tried
several times).

---

<details>
<summary>Some additional tests/details</summary>

The `scc` hashmap has no problem when a lock is held on the same thread from
which `retain` is called. I tested it like this:

```rust
use std::{sync::Arc, time::{Duration, Instant}};

use scc::HashMap;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let start = Instant::now();

    let map = HashMap::new();
    let _ = map.insert_async("foo", 4).await;
    let _ = map.insert_async("bar", 27).await;
    let map = Arc::new(map);

    {
        let map = Arc::clone(&map);
        tokio::spawn(async move {
            println!("--- {:.2?} calling entry", start.elapsed());
            let e = map.entry_async("foo").await;
            println!("--- {:.2?} acquired entry", start.elapsed());
            tokio::time::sleep(Duration::from_millis(3000)).await;
            *e.or_insert(6).get_mut() *= 2;
            println!("--- {:.2?} done with entry", start.elapsed());
        });
    }

    {
        let map = Arc::clone(&map);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(1500)).await;
            println!("--- {:.2?} calling entry 2", start.elapsed());
            let e = map.entry_async("foo").await;
            println!("--- {:.2?} acquired entry 2", start.elapsed());
            tokio::time::sleep(Duration::from_millis(3000)).await;
            *e.or_insert(6).get_mut() *= 2;
            println!("--- {:.2?} done with entry 2", start.elapsed());
        });
    }

    tokio::time::sleep(Duration::from_millis(1000)).await;
    println!("--- {:.2?} calling retain", start.elapsed());
    map.retain_async(|_, v| *v % 2 != 0).await;
    println!("--- {:.2?} done retain", start.elapsed());
}
```

This outputs:

```
--- 31.88µs calling entry
--- 38.91µs acquired entry
--- 1.00s calling retain
--- 1.50s calling entry 2
--- 3.00s done with entry
--- 3.00s acquired entry 2
--- 6.00s done with entry 2
--- 6.00s done retain
```

Of course, a single test doesn't guarantee that the library supports this, but
the docs also indicate that it can deal with these situations: async variants
are offered everywhere, and these kinds of situations regularly occur in async
code.

Edit: thinking about it more, I suppose the key feature here is that
`retain_async` can yield, whereas DashMap's `retain` could only block when it
couldn't make progress.

</details>
owi92 authored Mar 12, 2024
2 parents 39cdd95 + 057dda0 commit 4a8e304
Showing 5 changed files with 85 additions and 50 deletions.
21 changes: 7 additions & 14 deletions backend/Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion backend/Cargo.toml
```diff
@@ -27,7 +27,6 @@ chrono = { version = "0.4", default-features = false, features = ["serde", "std"
 clap = { version = "4.2.2", features = ["derive", "string"] }
 confique = { version = "0.2.0", default-features = false, features = ["toml"] }
 cookie = "0.18.0"
-dashmap = "5.5.3"
 deadpool = { version = "0.10.0", default-features = false, features = ["managed", "rt_tokio_1"] }
 deadpool-postgres = { version = "0.12.1", default-features = false, features = ["rt_tokio_1"] }
 elliptic-curve = { version = "0.13.4", features = ["jwk", "sec1"] }
@@ -65,6 +64,7 @@ ring = "0.17.8"
 rustls = "0.22.2"
 rustls-native-certs = "0.7.0"
 rustls-pemfile = "2.1.0"
+scc = "2.0.17"
 secrecy = { version = "0.8", features = ["serde"] }
 serde = { version = "1.0.192", features = ["derive"] }
 serde_json = "1"
```
80 changes: 52 additions & 28 deletions backend/src/auth/cache.rs
```diff
@@ -1,13 +1,13 @@
 use std::{collections::HashSet, hash::Hash, time::{Instant, Duration}};
 
-use dashmap::{DashMap, mapref::entry::Entry};
 use deadpool_postgres::Client;
 use hyper::HeaderMap;
 use prometheus_client::metrics::counter::Counter;
+use scc::{hash_map::Entry, HashMap};
 
-use crate::{config::Config, prelude::*};
+use crate::{auth::config::CallbackCacheDuration, config::Config, prelude::*};
 
-use super::{config::CallbackConfig, User};
+use super::User;
 
 pub struct Caches {
     pub(crate) user: UserCache,
@@ -24,14 +24,14 @@ impl Caches {
 
     /// Starts a daemon that regularly removes outdated entries from the cache.
     pub(crate) async fn maintainence_task(&self, config: &Config) -> ! {
-        fn cleanup<K: Eq + Hash, V>(
+        async fn cleanup<K: Eq + Hash, V>(
             now: Instant,
-            map: &DashMap<K, V>,
+            map: &HashMap<K, V>,
             cache_duration: Duration,
             mut timestamp: impl FnMut(&V) -> Instant,
         ) -> Option<Instant> {
             let mut out = None;
-            map.retain(|_, v| {
+            map.retain_async(|_, v| {
                 let instant = timestamp(v);
                 let is_outdated = now.saturating_duration_since(instant) > cache_duration;
                 if !is_outdated {
@@ -41,19 +41,34 @@ impl Caches {
                     };
                 }
                 !is_outdated
-            });
+            }).await;
             out.map(|out| out + cache_duration)
         }
 
-        let empty_wait_time = std::cmp::min(CACHE_DURATION, config.auth.callback.cache_duration);
+        let empty_wait_time = {
+            let mut out = CACHE_DURATION;
+            if let CallbackCacheDuration::Enabled(duration) = config.auth.callback.cache_duration {
+                out = std::cmp::min(duration, out);
+            }
+            out
+        };
         tokio::time::sleep(empty_wait_time).await;
 
         loop {
             let now = Instant::now();
-            let next_user_action =
-                cleanup(now, &self.user.0, CACHE_DURATION, |v| v.last_written_to_db);
-            let next_callback_action =
-                cleanup(now, &self.callback.map, config.auth.callback.cache_duration, |v| v.timestamp);
+            let next_user_action = cleanup(
+                now,
+                &self.user.0,
+                CACHE_DURATION,
+                |v| v.last_written_to_db,
+            ).await;
+            let next_callback_action = if let CallbackCacheDuration::Enabled(duration)
+                = config.auth.callback.cache_duration
+            {
+                cleanup(now, &self.callback.map, duration, |v| v.timestamp).await
+            } else {
+                None
+            };
 
             // We will wait until the next entry in the hashmap gets stale, but
             // at least 30s to not do cleanup too often. In case there are no
@@ -77,6 +92,7 @@ impl Caches {
 
 const CACHE_DURATION: Duration = Duration::from_secs(60 * 10);
 
+#[derive(Clone)]
 struct UserCacheEntry {
     display_name: String,
     email: Option<String>,
@@ -92,15 +108,15 @@ struct UserCacheEntry {
 /// This works fine in multi-node setups: each node just has its local cache and
 /// prevents some DB writes. But as this data is never used otherwise, we don't
 /// run into data inconsistency problems.
-pub(crate) struct UserCache(DashMap<String, UserCacheEntry>);
+pub(crate) struct UserCache(HashMap<String, UserCacheEntry>);
 
 impl UserCache {
     fn new() -> Self {
-        Self(DashMap::new())
+        Self(HashMap::new())
     }
 
     pub(crate) async fn upsert_user_info(&self, user: &super::User, db: &Client) {
-        match self.0.entry(user.username.clone()) {
+        match self.0.entry_async(user.username.clone()).await {
             Entry::Occupied(mut occupied) => {
                 let entry = occupied.get();
                 let needs_update = entry.last_written_to_db.elapsed() > CACHE_DURATION
@@ -119,7 +135,7 @@ impl UserCache {
             Entry::Vacant(vacant) => {
                 let res = Self::write_to_db(user, db).await;
                 if res.is_ok() {
-                    vacant.insert(UserCacheEntry {
+                    vacant.insert_entry(UserCacheEntry {
                         display_name: user.display_name.clone(),
                         email: user.email.clone(),
                         roles: user.roles.clone(),
@@ -167,7 +183,7 @@ impl UserCache {
 
 // ---------------------------------------------------------------------------
 
-#[derive(PartialEq, Eq)]
+#[derive(PartialEq, Eq, Clone)]
 struct AuthCallbackCacheKey(HeaderMap);
 
 impl Hash for AuthCallbackCacheKey {
@@ -191,6 +207,7 @@ impl Hash for AuthCallbackCacheKey {
     }
 }
 
+#[derive(Clone)]
 struct AuthCallbackCacheEntry {
     user: Option<User>,
     timestamp: Instant,
@@ -199,7 +216,7 @@ struct AuthCallbackCacheEntry {
 
 /// Cache for `auth-callback` calls.
 pub(crate) struct AuthCallbackCache {
-    map: DashMap<AuthCallbackCacheKey, AuthCallbackCacheEntry>,
+    map: HashMap<AuthCallbackCacheKey, AuthCallbackCacheEntry>,
     // Metrics
     hits: Counter,
     misses: Counter,
@@ -208,7 +225,7 @@ pub(crate) struct AuthCallbackCache {
 impl AuthCallbackCache {
     fn new() -> Self {
         Self {
-            map: DashMap::new(),
+            map: HashMap::new(),
             hits: Counter::default(),
             misses: Counter::default(),
         }
@@ -225,13 +242,18 @@ impl AuthCallbackCache {
         self.map.len()
     }
 
-    pub(super) fn get(&self, key: &HeaderMap, config: &CallbackConfig) -> Option<Option<User>> {
+    pub(super) async fn get(
+        &self,
+        key: &HeaderMap,
+        cache_duration: Duration,
+    ) -> Option<Option<User>> {
         // TODO: this `clone` should not be necessary. It can be removed with
         // `#[repr(transparent)]` and an `unsafe`, but I don't want to just
        // throw around `unsafe` here.
-        let out = self.map.get(&AuthCallbackCacheKey(key.clone()))
-            .filter(|e| e.timestamp.elapsed() < config.cache_duration)
-            .map(|e| e.user.clone());
+        let out = self.map.get_async(&AuthCallbackCacheKey(key.clone()))
+            .await
+            .filter(|e| e.get().timestamp.elapsed() < cache_duration)
+            .map(|e| e.get().user.clone());
 
         match out.is_some() {
             true => self.hits.inc(),
@@ -241,11 +263,13 @@ impl AuthCallbackCache {
         out
     }
 
-    pub(super) fn insert(&self, key: HeaderMap, user: Option<User>) {
-        self.map.insert(AuthCallbackCacheKey(key), AuthCallbackCacheEntry {
-            user,
-            timestamp: Instant::now(),
-        });
+    pub(super) async fn insert(&self, key: HeaderMap, user: Option<User>) {
+        self.map.entry_async(AuthCallbackCacheKey(key))
+            .await
+            .insert_entry(AuthCallbackCacheEntry {
+                user,
+                timestamp: Instant::now(),
+            });
     }
 }
 
```
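
For readers unfamiliar with `scc`: the pattern the new code uses instead of
DashMap's blocking entry API looks roughly like this in isolation (a minimal
sketch, assuming scc 2.x as pinned in `Cargo.toml`):

```rust
use scc::{hash_map::Entry, HashMap};

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let map: HashMap<String, u32> = HashMap::new();

    // `entry_async` awaits the bucket lock instead of blocking the thread,
    // so other tasks scheduled on this thread can still make progress.
    match map.entry_async("foo".to_string()).await {
        Entry::Occupied(mut occupied) => *occupied.get_mut() += 1,
        Entry::Vacant(vacant) => { vacant.insert_entry(1); }
    }

    // Non-exclusive read of the stored value.
    assert_eq!(map.read_async("foo", |_, v| *v).await, Some(1));
}
```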

22 changes: 20 additions & 2 deletions backend/src/auth/config.rs
```diff
@@ -295,10 +295,28 @@ pub(crate) struct CallbackConfig {
     /// For how long a callback's response is cached. The key of the cache is
     /// the set of headers forwarded to the callback. Set to 0 to disable
     /// caching.
-    #[config(default = "5min", deserialize_with = crate::config::deserialize_duration)]
-    pub(crate) cache_duration: Duration,
+    #[config(default = "5min", deserialize_with = CallbackCacheDuration::deserialize)]
+    pub(crate) cache_duration: CallbackCacheDuration,
 }
 
+#[derive(Debug, Clone)]
+pub(crate) enum CallbackCacheDuration {
+    Disabled,
+    Enabled(Duration),
+}
+
+impl CallbackCacheDuration {
+    fn deserialize<'de, D>(deserializer: D) -> Result<Self, D::Error>
+    where D: serde::Deserializer<'de>,
+    {
+        let duration = crate::config::deserialize_duration(deserializer)?;
+        if duration.is_zero() {
+            Ok(Self::Disabled)
+        } else {
+            Ok(Self::Enabled(duration))
+        }
+    }
+}
+
 #[derive(Debug, Clone, confique::Config)]
 pub(crate) struct RoleConfig {
```
10 changes: 5 additions & 5 deletions backend/src/auth/mod.rs
```diff
@@ -23,6 +23,7 @@ mod handlers
 mod session_id;
 mod jwt;
 
+use self::config::CallbackCacheDuration;
 pub(crate) use self::{
     cache::Caches,
     config::{AuthConfig, AuthSource, CallbackUri},
@@ -252,10 +253,9 @@ impl User {
 
         // Check cache.
         let mut header_copy = None;
-        if !ctx.config.auth.callback.cache_duration.is_zero() {
+        if let CallbackCacheDuration::Enabled(duration) = ctx.config.auth.callback.cache_duration {
             header_copy = Some(req.headers().clone());
-            let callback_config = &ctx.config.auth.callback;
-            if let Some(user) = ctx.auth_caches.callback.get(req.headers(), callback_config) {
+            if let Some(user) = ctx.auth_caches.callback.get(req.headers(), duration).await {
                 return Ok(user);
             }
         }
@@ -264,8 +264,8 @@ impl User {
         let out = Self::from_callback_impl(req, callback_url, ctx).await?;
 
         // Insert into cache
-        if !ctx.config.auth.callback.cache_duration.is_zero() {
-            ctx.auth_caches.callback.insert(header_copy.unwrap(), out.clone());
+        if let CallbackCacheDuration::Enabled(_) = ctx.config.auth.callback.cache_duration {
+            ctx.auth_caches.callback.insert(header_copy.unwrap(), out.clone()).await;
         }
 
         Ok(out)
```
