From 53637e79ef499088007b74e702fb864748622bce Mon Sep 17 00:00:00 2001 From: David Calavera Date: Mon, 20 Nov 2023 10:42:07 -0800 Subject: [PATCH] Remove function config allocations per invocation. (#732) Every invocation clones the function config. This allocates memory in the heap for no reason. This change removes those extra allocations by wrapping the config into an Arc and sharing that between invocations. Signed-off-by: David Calavera --- Cargo.toml | 1 + lambda-runtime/Cargo.toml | 2 +- lambda-runtime/src/lib.rs | 28 ++++++++----- lambda-runtime/src/types.rs | 81 ++++++++++++++++++++++--------------- 4 files changed, 69 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 48bcd5db..16f57a7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "lambda-http", "lambda-integration-tests", diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml index 9fb8eb8b..335b5482 100644 --- a/lambda-runtime/Cargo.toml +++ b/lambda-runtime/Cargo.toml @@ -32,7 +32,7 @@ hyper = { version = "0.14.20", features = [ "server", ] } futures = "0.3" -serde = { version = "1", features = ["derive"] } +serde = { version = "1", features = ["derive", "rc"] } serde_json = "^1" bytes = "1.0" http = "0.2" diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index 18b1066e..5404fb96 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -23,6 +23,7 @@ use std::{ future::Future, marker::PhantomData, panic, + sync::Arc, }; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::{Stream, StreamExt}; @@ -58,6 +59,8 @@ pub struct Config { pub log_group: String, } +type RefConfig = Arc; + impl Config { /// Attempts to read configuration from environment variables. pub fn from_env() -> Result { @@ -86,7 +89,7 @@ where struct Runtime = HttpConnector> { client: Client, - config: Config, + config: RefConfig, } impl Runtime @@ -127,8 +130,7 @@ where continue; } - let ctx: Context = Context::try_from(parts.headers)?; - let ctx: Context = ctx.with_config(&self.config); + let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?; let request_id = &ctx.request_id.clone(); let request_span = match &ctx.xray_trace_id { @@ -263,7 +265,10 @@ where trace!("Loading config from env"); let config = Config::from_env()?; let client = Client::builder().build().expect("Unable to create a runtime client"); - let runtime = Runtime { client, config }; + let runtime = Runtime { + client, + config: Arc::new(config), + }; let client = &runtime.client; let incoming = incoming(client); @@ -294,7 +299,7 @@ mod endpoint_tests { }, simulated, types::Diagnostic, - Error, Runtime, + Config, Error, Runtime, }; use futures::future::BoxFuture; use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri}; @@ -302,7 +307,7 @@ mod endpoint_tests { use lambda_runtime_api_client::Client; use serde_json::json; use simulated::DuplexStreamWrapper; - use std::{convert::TryFrom, env, marker::PhantomData}; + use std::{convert::TryFrom, env, marker::PhantomData, sync::Arc}; use tokio::{ io::{self, AsyncRead, AsyncWrite}, select, @@ -531,9 +536,12 @@ mod endpoint_tests { if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() { env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log"); } - let config = crate::Config::from_env().expect("Failed to read env vars"); + let config = Config::from_env().expect("Failed to read env vars"); - let runtime = Runtime { client, config }; + let runtime = Runtime { + client, + config: Arc::new(config), + }; let client = &runtime.client; let incoming = incoming(client).take(1); runtime.run(incoming, f).await?; @@ -568,13 +576,13 @@ mod endpoint_tests { let f = crate::service_fn(func); - let config = crate::Config { + let config = Arc::new(Config { function_name: "test_fn".to_string(), memory: 128, version: "1".to_string(), log_stream: "test_stream".to_string(), log_group: "test_log".to_string(), - }; + }); let runtime = Runtime { client, config }; let client = &runtime.client; diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index 2f0287ee..a252475b 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -1,4 +1,4 @@ -use crate::{Config, Error}; +use crate::{Error, RefConfig}; use base64::prelude::*; use bytes::Bytes; use http::{HeaderMap, HeaderValue, StatusCode}; @@ -97,7 +97,7 @@ pub struct CognitoIdentity { /// are populated using the [Lambda environment variables](https://docs.aws.amazon.com/lambda/latest/dg/current-supported-versions.html) /// and [the headers returned by the poll request to the Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-next). #[non_exhaustive] -#[derive(Clone, Debug, Eq, PartialEq, Default, Serialize, Deserialize)] +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] pub struct Context { /// The AWS request ID generated by the Lambda service. pub request_id: String, @@ -117,12 +117,14 @@ pub struct Context { /// Lambda function configuration from the local environment variables. /// Includes information such as the function name, memory allocation, /// version, and log streams. - pub env_config: Config, + pub env_config: RefConfig, } -impl TryFrom for Context { +impl TryFrom<(RefConfig, HeaderMap)> for Context { type Error = Error; - fn try_from(headers: HeaderMap) -> Result { + fn try_from(data: (RefConfig, HeaderMap)) -> Result { + let env_config = data.0; + let headers = data.1; let client_context: Option = if let Some(value) = headers.get("lambda-runtime-client-context") { serde_json::from_str(value.to_str()?)? } else { @@ -158,13 +160,20 @@ impl TryFrom for Context { .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()), client_context, identity, - ..Default::default() + env_config, }; Ok(ctx) } } +impl Context { + /// The execution deadline for the current invocation. + pub fn deadline(&self) -> SystemTime { + SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline) + } +} + /// Incoming Lambda request containing the event payload and context. #[derive(Clone, Debug)] pub struct LambdaEvent { @@ -273,6 +282,8 @@ where #[cfg(test)] mod test { use super::*; + use crate::Config; + use std::sync::Arc; #[test] fn round_trip_lambda_error() { @@ -292,6 +303,8 @@ mod test { #[test] fn context_with_expected_values_and_types_resolves() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); @@ -300,16 +313,18 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_ok()); } #[test] fn context_with_certain_missing_headers_still_resolves() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_ok()); } @@ -338,7 +353,9 @@ mod test { "lambda-runtime-client-context", HeaderValue::from_str(&client_context_str).unwrap(), ); - let tried = Context::try_from(headers); + + let config = Arc::new(Config::default()); + let tried = Context::try_from((config, headers)); assert!(tried.is_ok()); let tried = tried.unwrap(); assert!(tried.client_context.is_some()); @@ -347,17 +364,20 @@ mod test { #[test] fn context_with_empty_client_context_resolves() { + let config = Arc::new(Config::default()); let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}")); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_ok()); assert!(tried.unwrap().client_context.is_some()); } #[test] fn context_with_identity_resolves() { + let config = Arc::new(Config::default()); + let cognito_identity = CognitoIdentity { identity_id: String::new(), identity_pool_id: String::new(), @@ -370,7 +390,7 @@ mod test { "lambda-runtime-cognito-identity", HeaderValue::from_str(&cognito_identity_str).unwrap(), ); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_ok()); let tried = tried.unwrap(); assert!(tried.identity.is_some()); @@ -379,6 +399,8 @@ mod test { #[test] fn context_with_bad_deadline_type_is_err() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert( @@ -390,12 +412,14 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_err()); } #[test] fn context_with_bad_client_context_is_err() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); @@ -403,22 +427,26 @@ mod test { "lambda-runtime-client-context", HeaderValue::from_static("BAD-Type,not JSON"), ); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_err()); } #[test] fn context_with_empty_identity_is_err() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}")); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_err()); } #[test] fn context_with_bad_identity_is_err() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); @@ -426,7 +454,7 @@ mod test { "lambda-runtime-cognito-identity", HeaderValue::from_static("BAD-Type,not JSON"), ); - let tried = Context::try_from(headers); + let tried = Context::try_from((config, headers)); assert!(tried.is_err()); } @@ -434,6 +462,8 @@ mod test { #[should_panic] #[allow(unused_must_use)] fn context_with_missing_request_id_should_panic() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert( @@ -441,13 +471,15 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - Context::try_from(headers); + Context::try_from((config, headers)); } #[test] #[should_panic] #[allow(unused_must_use)] fn context_with_missing_deadline_should_panic() { + let config = Arc::new(Config::default()); + let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert( @@ -455,21 +487,6 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - Context::try_from(headers); - } -} - -impl Context { - /// Add environment details to the context by setting `env_config`. - pub fn with_config(self, config: &Config) -> Self { - Self { - env_config: config.clone(), - ..self - } - } - - /// The execution deadline for the current invocation. - pub fn deadline(&self) -> SystemTime { - SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline) + Context::try_from((config, headers)); } }