diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index d78df80112..c1831aac7b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -561,9 +561,9 @@ class ServerProtocolTestGenerator( private fun checkResponse(rustWriter: RustWriter, testCase: HttpResponseTestCase) { checkStatusCode(rustWriter, testCase.code) - checkHeaders(rustWriter, "&http_response.headers()", testCase.headers) - checkForbidHeaders(rustWriter, "&http_response.headers()", testCase.forbidHeaders) - checkRequiredHeaders(rustWriter, "&http_response.headers()", testCase.requireHeaders) + checkHeaders(rustWriter, "http_response.headers()", testCase.headers) + checkForbidHeaders(rustWriter, "http_response.headers()", testCase.forbidHeaders) + checkRequiredHeaders(rustWriter, "http_response.headers()", testCase.requireHeaders) // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to @@ -579,7 +579,7 @@ class ServerProtocolTestGenerator( private fun checkResponse(rustWriter: RustWriter, testCase: HttpMalformedResponseDefinition) { checkStatusCode(rustWriter, testCase.code) - checkHeaders(rustWriter, "&http_response.headers()", testCase.headers) + checkHeaders(rustWriter, "http_response.headers()", testCase.headers) // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to diff --git a/rust-runtime/aws-smithy-protocol-test/src/lib.rs b/rust-runtime/aws-smithy-protocol-test/src/lib.rs index 88f6ee4545..e1b2f6adfc 100644 --- a/rust-runtime/aws-smithy-protocol-test/src/lib.rs +++ b/rust-runtime/aws-smithy-protocol-test/src/lib.rs @@ -13,11 +13,12 @@ mod urlencoded; mod xml; +use crate::sealed::GetNormalizedHeader; use crate::xml::try_xml_equivalent; use assert_json_diff::assert_json_eq_no_panic; use aws_smithy_runtime_api::client::http::request::Headers; use aws_smithy_runtime_api::client::orchestrator::HttpRequest; -use http::Uri; +use http::{HeaderMap, Uri}; use pretty_assertions::Comparison; use std::collections::HashSet; use std::fmt::{self, Debug}; @@ -211,14 +212,46 @@ pub fn require_query_params( Ok(()) } +mod sealed { + pub trait GetNormalizedHeader { + fn get_header(&self, key: &str) -> Option; + } +} + +impl<'a> GetNormalizedHeader for &'a Headers { + fn get_header(&self, key: &str) -> Option { + if !self.contains_key(key) { + None + } else { + Some(self.get_all(key).collect::>().join(", ")) + } + } +} + +impl<'a> GetNormalizedHeader for &'a HeaderMap { + fn get_header(&self, key: &str) -> Option { + if !self.contains_key(key) { + None + } else { + Some( + self.get_all(key) + .iter() + .map(|value| std::str::from_utf8(value.as_bytes()).expect("invalid utf-8")) + .collect::>() + .join(", "), + ) + } + } +} + pub fn validate_headers<'a>( - actual_headers: &Headers, + actual_headers: impl GetNormalizedHeader, expected_headers: impl IntoIterator + 'a, impl AsRef + 'a)>, ) -> Result<(), ProtocolTestFailure> { for (key, expected_value) in expected_headers { let key = key.as_ref(); let expected_value = expected_value.as_ref(); - match normalized_header(actual_headers, key) { + match actual_headers.get_header(key) { None => { return Err(ProtocolTestFailure::MissingHeader { expected: key.to_string(), @@ -237,21 +270,13 @@ pub fn validate_headers<'a>( Ok(()) } -fn normalized_header(headers: &Headers, key: &str) -> Option { - if !headers.contains_key(key) { - None - } else { - Some(headers.get_all(key).collect::>().join(", ")) - } -} - pub fn forbid_headers( - headers: &Headers, + headers: impl GetNormalizedHeader, forbidden_headers: &[&str], ) -> Result<(), ProtocolTestFailure> { for key in forbidden_headers { // Protocol tests store header lists as comma-delimited - if let Some(value) = normalized_header(headers, key) { + if let Some(value) = headers.get_header(key) { return Err(ProtocolTestFailure::ForbiddenHeader { forbidden: key.to_string(), found: format!("{}: {}", key, value), @@ -262,12 +287,12 @@ pub fn forbid_headers( } pub fn require_headers( - headers: &Headers, + headers: impl GetNormalizedHeader, required_headers: &[&str], ) -> Result<(), ProtocolTestFailure> { for key in required_headers { // Protocol tests store header lists as comma-delimited - if normalized_header(headers, key).is_none() { + if headers.get_header(key).is_none() { return Err(ProtocolTestFailure::MissingHeader { expected: key.to_string(), }); @@ -442,10 +467,10 @@ mod tests { #[test] fn test_validate_headers() { let mut headers = Headers::new(); - headers.append("X-Foo", "foo"); - headers.append("X-Foo-List", "foo"); - headers.append("X-Foo-List", "bar"); - headers.append("X-Inline", "inline, other"); + headers.append("x-foo", "foo"); + headers.append("x-foo-list", "foo"); + headers.append("x-foo-list", "bar"); + headers.append("x-inline", "inline, other"); validate_headers(&headers, [("X-Foo", "foo")]).expect("header present"); validate_headers(&headers, [("X-Foo", "Foo")]).expect_err("case sensitive"); @@ -465,7 +490,7 @@ mod tests { #[test] fn test_forbidden_headers() { let mut headers = Headers::new(); - headers.append("X-Foo", "foo"); + headers.append("x-foo", "foo"); assert_eq!( forbid_headers(&headers, &["X-Foo"]).expect_err("should be error"), ProtocolTestFailure::ForbiddenHeader { @@ -479,7 +504,7 @@ mod tests { #[test] fn test_required_headers() { let mut headers = Headers::new(); - headers.append("X-Foo", "foo"); + headers.append("x-foo", "foo"); require_headers(&headers, &["X-Foo"]).expect("header present"); require_headers(&headers, &["X-Bar"]).expect_err("header not present"); } @@ -520,6 +545,12 @@ mod tests { .expect("inputs matched exactly") } + #[test] + fn test_validate_headers_http0x() { + let request = http::Request::builder().header("a", "b").body(()).unwrap(); + validate_headers(request.headers(), [("a", "b")]).unwrap() + } + #[test] fn test_float_equals() { let a = f64::NAN; diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/http/request.rs b/rust-runtime/aws-smithy-runtime-api/src/client/http/request.rs index 904a539c18..cc95daef20 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/http/request.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/http/request.rs @@ -582,9 +582,16 @@ impl Error for HttpError { fn header_name(name: impl AsHeaderComponent) -> Result { name.repr_as_http03x_header_name().or_else(|name| { - name.into_maybe_static().and_then(|cow| match cow { - Cow::Borrowed(staticc) => Ok(http0::HeaderName::from_static(staticc)), - Cow::Owned(s) => http0::HeaderName::try_from(s).map_err(HttpError::invalid_header_name), + name.into_maybe_static().and_then(|cow| { + if cow.chars().any(|c| c.is_uppercase()) { + return Err(HttpError::new("Header names must be all lower case")); + } + match cow { + Cow::Borrowed(staticc) => Ok(http0::HeaderName::from_static(staticc)), + Cow::Owned(s) => { + http0::HeaderName::try_from(s).map_err(HttpError::invalid_header_name) + } + } }) }) }