From 578416e18521e0ddecbc5daf1f49025d49841c85 Mon Sep 17 00:00:00 2001 From: James Mayclin Date: Tue, 18 Jun 2024 17:36:44 +0000 Subject: [PATCH 1/2] test(bindings/s2n-tls): refactor main unit tests --- bindings/rust/s2n-tls/src/testing/s2n_tls.rs | 300 ++++++------------- 1 file changed, 85 insertions(+), 215 deletions(-) diff --git a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs index 7cab316f35f..aa9e3f1993b 100644 --- a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs +++ b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs @@ -236,7 +236,7 @@ mod tests { callbacks::{ClientHelloCallback, ConnectionFuture, ConnectionFutureResult}, enums::ClientAuthType, error::ErrorType, - testing::{client_hello::*, s2n_tls::*, *}, + testing::{client_hello::*, *}, }; use alloc::sync::Arc; use core::sync::atomic::Ordering; @@ -246,13 +246,13 @@ mod tests { #[test] fn handshake_default() { let config = build_config(&security::DEFAULT).unwrap(); - establish_connection(config); + assert!(TestPair::handshake_with_config(&config).is_ok()); } #[test] fn handshake_default_tls13() { let config = build_config(&security::DEFAULT_TLS13).unwrap(); - establish_connection(config) + assert!(TestPair::handshake_with_config(&config).is_ok()); } #[test] @@ -314,6 +314,7 @@ mod tests { Ok(()) } + #[test] fn connnection_waker() { let config = build_config(&security::DEFAULT_TLS13).unwrap(); @@ -368,44 +369,20 @@ mod tests { config.build()? }; - let server = { - // create and configure a server connection - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - server.set_waker(Some(&waker))?; - Harness::new(server) - }; - - let client = { - // create a client connection - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; + let mut pair = TestPair::from_config(&config); + pair.server.set_waker(Some(&waker))?; + let s2n_err = pair.handshake().unwrap_err(); + // the underlying error should be the custom error the application provided + let app_err = s2n_err.application_error().unwrap(); + let io_err = app_err.downcast_ref::().unwrap(); + let _custom_err = io_err + .get_ref() + .unwrap() + .downcast_ref::() + .unwrap(); - let mut pair = Pair::new(server, client); - loop { - match pair.poll() { - Poll::Ready(result) => { - let err = result.expect_err("handshake should fail"); - - // the underlying error should be the custom error the application provided - let s2n_err = err.downcast_ref::().unwrap(); - let app_err = s2n_err.application_error().unwrap(); - let io_err = app_err.downcast_ref::().unwrap(); - let _custom_err = io_err - .get_ref() - .unwrap() - .downcast_ref::() - .unwrap(); - break; - } - Poll::Pending => continue, - } - } // assert that the future is async returned Poll::Pending once assert_eq!(wake_count, 1); - Ok(()) } @@ -422,24 +399,10 @@ mod tests { config.build()? }; - let server = { - // create and configure a server connection - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - server.set_waker(Some(&waker))?; - Harness::new(server) - }; - - let client = { - // create a client connection - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; - - let pair = Pair::new(server, client); + let mut pair = TestPair::from_config(&config); + pair.server.set_waker(Some(&waker))?; + pair.handshake()?; - poll_tls_pair(pair); // confirm that the callback returned Pending `require_pending_count` times assert_eq!(wake_count, require_pending_count); // confirm that the final invoked count is +1 more than `require_pending_count` @@ -450,6 +413,7 @@ mod tests { Ok(()) } + #[test] fn client_hello_callback_sync() -> Result<(), Error> { let (waker, wake_count) = new_count_waker(); @@ -491,25 +455,12 @@ mod tests { config.build()? }; - let server = { - // create and configure a server connection - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - server.set_waker(Some(&waker))?; - Harness::new(server) - }; - - let client = { - // create a client connection - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; - - let pair = Pair::new(server, client); + let mut pair = TestPair::from_config(&config); + pair.server.set_waker(Some(&waker))?; assert_eq!(callback.count(), 0); - poll_tls_pair(pair); + + pair.handshake()?; assert_eq!(callback.count(), 1); assert_eq!(wake_count, 0); Ok(()) @@ -538,7 +489,7 @@ mod tests { builder.load_pem(&fs::read(&cert)?, &fs::read(&key)?)?; builder.trust_location(Some(&cert), None)?; - establish_connection(builder.build()?); + TestPair::handshake_with_config(&builder.build()?)?; Ok(()) } @@ -576,13 +527,9 @@ mod tests { config.build()? }; - let mut pair = tls_pair(config); - pair.server - .0 - .connection_mut() - .set_waker(Some(&noop_waker()))?; - - poll_tls_pair(pair); + let mut pair = TestPair::from_config(&config); + pair.server.set_waker(Some(&noop_waker()))?; + pair.handshake()?; } Ok(()) } @@ -602,23 +549,17 @@ mod tests { }; // confirm that default connection establishment fails - let mut pair = tls_pair(reject_config.clone()); - assert!(poll_tls_pair_result(&mut pair).is_err()); + let mut pair = TestPair::from_config(&reject_config); + assert!(pair.handshake().is_err()); // confirm that overriding the verify_host_callback on connection causes // the handshake to succeed - pair = tls_pair(reject_config); - pair.server - .0 - .connection - .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {}) - .unwrap(); + let mut pair = TestPair::from_config(&reject_config); pair.client - .0 - .connection - .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {}) - .unwrap(); - assert!(poll_tls_pair_result(&mut pair).is_ok()); + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; + pair.server + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; + pair.handshake()?; Ok(()) } @@ -633,24 +574,10 @@ mod tests { config.build()? }; - let server = { - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - Harness::new(server) - }; - - let client = { - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; - - let pair = Pair::new(server, client); - let pair = poll_tls_pair(pair); - let server = pair.server.0.connection; - let client = pair.client.0.connection; + let mut pair = TestPair::from_config(&config); + pair.handshake()?; - for conn in [server, client] { + for conn in [pair.server, pair.client] { assert!(!conn.client_cert_used()); let cert = conn.client_cert_chain_bytes()?; assert!(cert.is_none()); @@ -673,28 +600,14 @@ mod tests { config.build()? }; - let server = { - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - Harness::new(server) - }; - - let client = { - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; - - let pair = Pair::new(server, client); - let pair = poll_tls_pair(pair); - let server = pair.server.0.connection; - let client = pair.client.0.connection; + let mut pair = TestPair::from_config(&config); + pair.handshake()?; - let cert = server.client_cert_chain_bytes()?; + let cert = pair.server.client_cert_chain_bytes()?; assert!(cert.is_some()); assert!(!cert.unwrap().is_empty()); - for conn in [server, client] { + for conn in [pair.server, pair.client] { assert!(conn.client_cert_used()); let sig_alg = conn.selected_client_signature_algorithm()?; assert!(sig_alg.is_some()); @@ -706,7 +619,7 @@ mod tests { } #[test] - fn system_certs_loaded_by_default() { + fn system_certs_loaded_by_default() -> Result<(), Error> { let keypair = CertKeyPair::default(); // Load the server certificate into the trust store by overriding the OpenSSL default @@ -714,20 +627,18 @@ mod tests { temp_env::with_var("SSL_CERT_FILE", Some(keypair.cert_path()), || { let mut builder = Builder::new(); builder - .load_pem(keypair.cert(), keypair.key()) - .unwrap() - .set_security_policy(&security::DEFAULT_TLS13) - .unwrap() - .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {}) - .unwrap(); + .load_pem(keypair.cert(), keypair.key())? + .set_security_policy(&security::DEFAULT_TLS13)? + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; let config = builder.build().unwrap(); - establish_connection(config); - }); + TestPair::handshake_with_config(&config)?; + Ok(()) + }) } #[test] - fn disable_loading_system_certs() { + fn disable_loading_system_certs() -> Result<(), Error> { let keypair = CertKeyPair::default(); // Load the server certificate into the trust store by overriding the OpenSSL default @@ -736,24 +647,19 @@ mod tests { // Test the Builder itself, and also the Builder produced by the Config builder() API. for mut builder in [Builder::new(), Config::builder()] { builder - .load_pem(keypair.cert(), keypair.key()) - .unwrap() - .set_security_policy(&security::DEFAULT_TLS13) - .unwrap() - .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {}) - .unwrap(); + .load_pem(keypair.cert(), keypair.key())? + .set_security_policy(&security::DEFAULT_TLS13)? + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; // Disable loading system certificates - builder.with_system_certs(false).unwrap(); + builder.with_system_certs(false)?; - let config = builder.build().unwrap(); + let config = builder.build()?; let mut config_with_system_certs = config.clone(); - let mut pair = tls_pair(config); - // System certificates should not be loaded into the trust store. The handshake // should fail since the certificate should not be trusted. - assert!(poll_tls_pair_result(&mut pair).is_err()); + assert!(TestPair::handshake_with_config(&config).is_err()); // The handshake should succeed after trusting the certificate. unsafe { @@ -761,9 +667,10 @@ mod tests { config_with_system_certs.as_mut_ptr(), ); } - establish_connection(config_with_system_certs); + TestPair::handshake_with_config(&config_with_system_certs)?; } - }); + Ok(()) + }) } #[test] @@ -776,24 +683,10 @@ mod tests { config.build()? }; - let server = { - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - Harness::new(server) - }; - - let client = { - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; - - let pair = Pair::new(server, client); - let pair = poll_tls_pair(pair); - let server = pair.server.0.connection; - let client = pair.client.0.connection; + let mut pair = TestPair::from_config(&config); + pair.handshake()?; - for conn in [server, client] { + for conn in [pair.server, pair.client] { let chain = conn.peer_cert_chain()?; assert_eq!(chain.len(), 1); for cert in chain.iter() { @@ -816,29 +709,15 @@ mod tests { config.build()? }; - let server = { - let mut server = crate::connection::Connection::new_server(); - server.set_config(config.clone())?; - Harness::new(server) - }; - - let client = { - let mut client = crate::connection::Connection::new_client(); - client.set_config(config)?; - Harness::new(client) - }; + let mut pair = TestPair::from_config(&config); // None before handshake... - assert!(server.connection.selected_cert().is_none()); - assert!(client.connection.selected_cert().is_none()); - - let pair = Pair::new(server, client); + assert!(pair.server.selected_cert().is_none()); + assert!(pair.client.selected_cert().is_none()); - let pair = poll_tls_pair(pair); - let server = pair.server.0.connection; - let client = pair.client.0.connection; + pair.handshake()?; - for conn in [&server, &client] { + for conn in [&pair.server, &pair.client] { let chain = conn.selected_cert().unwrap(); assert_eq!(chain.len(), 1); for cert in chain.iter() { @@ -851,14 +730,14 @@ mod tests { // Same config is used for both and we are doing mTLS, so both should select the same // certificate. assert_eq!( - server + pair.server .selected_cert() .unwrap() .iter() .next() .unwrap()? .der()?, - client + pair.client .selected_cert() .unwrap() .iter() @@ -874,12 +753,11 @@ mod tests { fn master_secret_success() -> Result<(), Error> { let policy = security::Policy::from_version("test_all_tls12")?; let config = config_builder(&policy)?.build()?; - let pair = poll_tls_pair(tls_pair(config)); - let server = pair.server.0.connection; - let client = pair.client.0.connection; + let mut pair = TestPair::from_config(&config); + pair.handshake()?; - let server_secret = server.master_secret()?; - let client_secret = client.master_secret()?; + let server_secret = pair.server.master_secret()?; + let client_secret = pair.client.master_secret()?; assert_eq!(server_secret, client_secret); Ok(()) @@ -888,16 +766,13 @@ mod tests { #[test] fn master_secret_failure() -> Result<(), Error> { // TLS1.3 does not support getting the master secret - let config = config_builder(&security::DEFAULT_TLS13)?.build()?; - let pair = poll_tls_pair(tls_pair(config)); - let server = pair.server.0.connection; - let client = pair.client.0.connection; - - let server_error = server.master_secret().unwrap_err(); - assert_eq!(server_error.kind(), ErrorType::UsageError); + let mut pair = TestPair::from_config(&build_config(&security::DEFAULT_TLS13)?); + pair.handshake()?; - let client_error = client.master_secret().unwrap_err(); - assert_eq!(client_error.kind(), ErrorType::UsageError); + for conn in [pair.client, pair.server] { + let err = conn.master_secret().unwrap_err(); + assert_eq!(err.kind(), ErrorType::UsageError); + } Ok(()) } @@ -912,26 +787,21 @@ mod tests { send_key_updates: 0, }; - let pair = tls_pair(build_config(&security::DEFAULT_TLS13)?); - let mut pair = poll_tls_pair(pair); + let mut pair = TestPair::from_config(&build_config(&security::DEFAULT_TLS13)?); + pair.handshake()?; // there haven't been any key updates at the start of the connection - let client_updates = pair.client.0.connection.as_ref().key_update_counts()?; - assert_eq!(client_updates, empty_key_updates); - let server_updates = pair.server.0.connection.as_ref().key_update_counts()?; - assert_eq!(server_updates, empty_key_updates); + assert_eq!(pair.client.key_update_counts()?, empty_key_updates); + assert_eq!(pair.server.key_update_counts()?, empty_key_updates); pair.server - .0 - .connection - .as_mut() .request_key_update(PeerKeyUpdate::KeyUpdateNotRequested)?; - assert!(pair.poll_send(Mode::Server, &[0]).is_ready()); + assert!(pair.server.poll_send(&[0]).is_ready()); // the server send key has been updated - let client_updates = pair.client.0.connection.as_ref().key_update_counts()?; + let client_updates = pair.client.key_update_counts()?; assert_eq!(client_updates, empty_key_updates); - let server_updates = pair.server.0.connection.as_ref().key_update_counts()?; + let server_updates = pair.server.key_update_counts()?; assert_eq!(server_updates.recv_key_updates, 0); assert_eq!(server_updates.send_key_updates, 1); From e2f7f2f887a9611d05aa32f515e2717d508c5b71 Mon Sep 17 00:00:00 2001 From: James Mayclin Date: Thu, 20 Jun 2024 16:41:31 -0700 Subject: [PATCH 2/2] Update bindings/rust/s2n-tls/src/testing/s2n_tls.rs Co-authored-by: Sam Clark <3758302+goatgoose@users.noreply.github.com> --- bindings/rust/s2n-tls/src/testing/s2n_tls.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs index aa9e3f1993b..02ad414c310 100644 --- a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs +++ b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs @@ -314,7 +314,6 @@ mod tests { Ok(()) } - #[test] fn connnection_waker() { let config = build_config(&security::DEFAULT_TLS13).unwrap();