Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement initial lazy caching credentials provider #578

Merged
merged 12 commits into from
Jul 20, 2021
Merged
12 changes: 8 additions & 4 deletions aws/rust-runtime/aws-auth/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
[package]
name = "aws-auth"
version = "0.1.0"
authors = ["Russell Cohen <rcoh@amazon.com>"]
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Russell Cohen <rcoh@amazon.com>"]
license = "Apache-2.0"
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
smithy-http = { path = "../../../rust-runtime/smithy-http" }
tokio = { version = "1", features = ["sync"] }
tracing = "0.1.25"
zeroize = "1.2.0"

[dev-dependencies]
http = "0.2.3"
tokio = { version = "1.0", features = ["rt", "macros"] }
async-trait = "0.1.50"
env_logger = "*"
http = "0.2.3"
test-env-log = { version = "0.2.7", features = ["trace"] }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "test-util"] }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
5 changes: 5 additions & 0 deletions aws/rust-runtime/aws-auth/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use zeroize::Zeroizing;
#[derive(Clone)]
pub struct Credentials(Arc<Inner>);

#[derive(Clone)]
struct Inner {
access_key_id: Zeroizing<String>,
secret_access_key: Zeroizing<String>,
Expand Down Expand Up @@ -89,6 +90,10 @@ impl Credentials {
self.0.expires_after
}

pub fn expiry_mut(&mut self) -> &mut Option<SystemTime> {
&mut Arc::make_mut(&mut self.0).expires_after
}

pub fn session_token(&self) -> Option<&str> {
self.0.session_token.as_deref()
}
Expand Down
3 changes: 3 additions & 0 deletions aws/rust-runtime/aws-auth/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
* SPDX-License-Identifier: Apache-2.0.
*/

mod cache;
pub mod env;
pub mod lazy_caching;
mod time;

use crate::Credentials;
use smithy_http::property_bag::PropertyBag;
Expand Down
133 changes: 133 additions & 0 deletions aws/rust-runtime/aws-auth/src/provider/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

use crate::provider::CredentialsResult;
use crate::Credentials;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{OnceCell, RwLock};

#[derive(Clone)]
pub(super) struct Cache {
/// Amount of time before the actual credential expiration time
/// where credentials are considered expired.
buffer_time: Duration,
value: Arc<RwLock<OnceCell<Credentials>>>,
}

impl Cache {
pub fn new(buffer_time: Duration) -> Cache {
Cache {
buffer_time,
value: Arc::new(RwLock::new(OnceCell::new())),
}
}

#[cfg(test)]
async fn get(&self) -> Option<Credentials> {
self.value.read().await.get().cloned()
}

/// Attempts to refresh the cached credentials with the given async future.
/// If multiple threads attempt to refresh at the same time, one of them will win,
/// and the others will await that thread's result rather than multiple refreshes occurring.
/// The function given to acquire a credentials future, `f`, will not be called
/// if another thread is chosen to load the credentials.
pub async fn get_or_load<F, Fut>(&self, f: F) -> CredentialsResult
where
F: FnOnce() -> Fut,
Fut: Future<Output = CredentialsResult>,
{
let lock = self.value.read().await;
let future = lock.get_or_try_init(f);
future.await.map(|credentials| credentials.clone())
}

/// If the credentials are expired, clears the cache. Otherwise, yields the current credentials value.
pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<Credentials> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll probably need a way to also explicitly expire credentials.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the use-case for this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

credentials can become invalidated (eg. manually invalidated session credentials)

// Short-circuit if the credential is not expired
if let Some(credentials) = self.value.read().await.get() {
if !expired(credentials, self.buffer_time, now) {
return Some(credentials.clone());
}
}

// Acquire a write lock to clear the cache, but then once the lock is acquired,
// check again that the credential is not already cleared. If it has been cleared,
// then another thread is refreshing the cache by the time the write lock was acquired.
let mut lock = self.value.write().await;
if let Some(credentials) = lock.get() {
// Also check that we're clearing the expired credentials and not credentials
// that have been refreshed by another thread.
if expired(credentials, self.buffer_time, now) {
*lock = OnceCell::new();
}
}
None
}
}

fn expired(credentials: &Credentials, buffer_time: Duration, now: SystemTime) -> bool {
credentials
.expiry()
.map(|expiration| now >= (expiration - buffer_time))
.expect("Cached credentials don't have an expiration time. This is a bug in aws-auth.")
}

#[cfg(test)]
mod tests {
use super::{expired, Cache};
use crate::Credentials;
use std::time::{Duration, SystemTime};

fn credentials(expired_secs: u64) -> Credentials {
Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
}

fn epoch_secs(secs: u64) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
}

#[test]
fn expired_check() {
let creds = credentials(100);
assert!(expired(&creds, Duration::from_secs(10), epoch_secs(1000)));
assert!(expired(&creds, Duration::from_secs(10), epoch_secs(90)));
assert!(!expired(&creds, Duration::from_secs(10), epoch_secs(10)));
}

#[test_env_log::test(tokio::test)]
async fn cache_clears_if_expired_only() {
let cache = Cache::new(Duration::from_secs(10));
assert!(cache
.yield_or_clear_if_expired(epoch_secs(100))
.await
.is_none());

cache
.get_or_load(|| async { Ok(credentials(100)) })
.await
.unwrap();
assert_eq!(Some(epoch_secs(100)), cache.get().await.unwrap().expiry());

// It should not clear the credentials if they're not expired
assert_eq!(
Some(epoch_secs(100)),
cache
.yield_or_clear_if_expired(epoch_secs(10))
.await
.unwrap()
.expiry()
);

// It should clear the credentials if they're expired
assert!(cache
.yield_or_clear_if_expired(epoch_secs(500))
.await
.is_none());
assert!(cache.get().await.is_none());
}
}
Loading