diff --git a/tests/unit/s2n_protocol_version_getter_test.c b/tests/unit/s2n_protocol_version_getter_test.c index 08a1baae6f3..d69ff8f06ff 100644 --- a/tests/unit/s2n_protocol_version_getter_test.c +++ b/tests/unit/s2n_protocol_version_getter_test.c @@ -78,13 +78,14 @@ int main(int argc, char **argv) uint8_t supported_versions_data[3] = { 0 }; struct s2n_blob supported_versions_blob = { 0 }; - EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, s2n_array_len(supported_versions_data))); + EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, + s2n_array_len(supported_versions_data))); struct s2n_stuffer supported_versions_stuffer = { 0 }; EXPECT_SUCCESS(s2n_stuffer_init(&supported_versions_stuffer, &supported_versions_blob)); - /* Write length byte */ + /* Write the length byte. */ EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 2)); - /* Write supported version */ + /* Write the supported version. */ POSIX_GUARD(s2n_stuffer_write_uint8(&supported_versions_stuffer, client_supported_version / 10)); POSIX_GUARD(s2n_stuffer_write_uint8(&supported_versions_stuffer, client_supported_version % 10)); @@ -106,7 +107,6 @@ int main(int argc, char **argv) EXPECT_NOT_NULL(server); EXPECT_SUCCESS(s2n_connection_set_config(server, config)); - struct s2n_stuffer *hello_stuffer = &client->handshake.io; EXPECT_SUCCESS(s2n_client_hello_send(client)); /* Overwrite the client hello version according to the test case. */ @@ -114,11 +114,19 @@ int main(int argc, char **argv) protocol_version[0] = client_hello_version / 10; protocol_version[1] = client_hello_version % 10; + struct s2n_stuffer *hello_stuffer = &client->handshake.io; EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); EXPECT_SUCCESS(s2n_stuffer_write_bytes(hello_stuffer, protocol_version, S2N_TLS_PROTOCOL_VERSION_LEN)); EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_EQUAL(context.invoked_count, 1); + + /* Ensure that a supported versions extension was received. */ + bool supported_versions_received = false; + EXPECT_SUCCESS(s2n_client_hello_has_extension(&server->client_hello, S2N_EXTENSION_SUPPORTED_VERSIONS, + &supported_versions_received)); + EXPECT_TRUE(supported_versions_received); EXPECT_EQUAL(s2n_connection_get_server_protocol_version(server), server_version); @@ -134,7 +142,7 @@ int main(int argc, char **argv) EXPECT_EQUAL(s2n_connection_get_client_hello_version(server), MIN(client_hello_version, S2N_TLS12)); uint8_t actual_protocol_version = s2n_connection_get_actual_protocol_version(server); - if (server_version < S2N_TLS13) { + if (server_version == S2N_TLS12) { /* For backwards compatibility, TLS 1.2 servers always use the client hello * version to determine the client's maximum version, even if a supported * versions extension was received. @@ -152,7 +160,7 @@ int main(int argc, char **argv) /* Test protocol version getters on the server when a supported versions extension isn't received */ for (uint8_t server_version = S2N_TLS12; server_version <= S2N_TLS13; server_version++) { - for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS13; client_hello_version++) { + for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS12; client_hello_version++) { DEFER_CLEANUP(struct s2n_config *client_config = s2n_config_new(), s2n_config_ptr_free); /* A TLS 1.2 security policy is set to prevent the client from sending a supported @@ -202,15 +210,10 @@ int main(int argc, char **argv) &supported_versions_received)); EXPECT_FALSE(supported_versions_received); - /* With no supported versions extension, the maximum version the client can support is - * TLS 1.2. - */ - uint8_t maximum_client_version = S2N_TLS12; - EXPECT_EQUAL(s2n_connection_get_server_protocol_version(server), server_version); - EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), MIN(client_hello_version, maximum_client_version)); - EXPECT_EQUAL(s2n_connection_get_client_hello_version(server), MIN(client_hello_version, maximum_client_version)); - EXPECT_EQUAL(s2n_connection_get_actual_protocol_version(server), MIN(client_hello_version, maximum_client_version)); + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); + EXPECT_EQUAL(s2n_connection_get_client_hello_version(server), client_hello_version); + EXPECT_EQUAL(s2n_connection_get_actual_protocol_version(server), client_hello_version); } } @@ -247,142 +250,146 @@ int main(int argc, char **argv) } /* Test get_client_protocol_version fallback behavior on TLS 1.2 servers */ - { - /* Report client hello version if the supported version extension is malformed */ - for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS12; client_hello_version++) { - for (uint8_t send_valid_extension = 0; send_valid_extension <= 1; send_valid_extension++) { - DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); - EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); - EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); - - uint8_t supported_versions_data[3] = { 0 }; - struct s2n_blob supported_versions_blob = { 0 }; - EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, s2n_array_len(supported_versions_data))); - struct s2n_stuffer supported_versions_stuffer = { 0 }; - EXPECT_SUCCESS(s2n_stuffer_init(&supported_versions_stuffer, &supported_versions_blob)); - - /* Write length byte */ - if (send_valid_extension) { - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 2)); - } else { - /* Create a malformed supported versions extension by writing an invalid length byte */ - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 11)); - } - /* Write supported version */ - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, S2N_TLS13 / 10)); - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, S2N_TLS13 % 10)); + for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS12; client_hello_version++) { + /* Report the client hello version if the supported versions extension is malformed */ + for (uint8_t send_valid_extension = 0; send_valid_extension <= 1; send_valid_extension++) { + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); + + uint8_t supported_versions_data[3] = { 0 }; + struct s2n_blob supported_versions_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, + s2n_array_len(supported_versions_data))); + struct s2n_stuffer supported_versions_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&supported_versions_stuffer, &supported_versions_blob)); + + /* Write the length byte. */ + if (send_valid_extension) { + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 2)); + } else { + /* Create a malformed supported versions extension by writing an invalid length + * byte. */ + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 11)); + } + /* Write the supported version. */ + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, S2N_TLS13 / 10)); + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, S2N_TLS13 % 10)); - struct s2n_override_extension_ctx context = { - .extension_blob = supported_versions_blob - }; - EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); + struct s2n_override_extension_ctx context = { + .extension_blob = supported_versions_blob + }; + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); - DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), - s2n_connection_ptr_free); - EXPECT_NOT_NULL(client); - EXPECT_SUCCESS(s2n_connection_set_config(client, config)); + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(client); + EXPECT_SUCCESS(s2n_connection_set_config(client, config)); - DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), - s2n_connection_ptr_free); - EXPECT_NOT_NULL(server); - EXPECT_SUCCESS(s2n_connection_set_config(server, config)); + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(server); + EXPECT_SUCCESS(s2n_connection_set_config(server, config)); - struct s2n_stuffer *hello_stuffer = &client->handshake.io; - EXPECT_SUCCESS(s2n_client_hello_send(client)); + struct s2n_stuffer *hello_stuffer = &client->handshake.io; + EXPECT_SUCCESS(s2n_client_hello_send(client)); - /* Overwrite the client hello version according to the test case. */ - uint8_t protocol_version[S2N_TLS_PROTOCOL_VERSION_LEN] = { 0 }; - protocol_version[0] = client_hello_version / 10; - protocol_version[1] = client_hello_version % 10; + /* Overwrite the client hello version according to the test case. */ + uint8_t protocol_version[S2N_TLS_PROTOCOL_VERSION_LEN] = { 0 }; + protocol_version[0] = client_hello_version / 10; + protocol_version[1] = client_hello_version % 10; - EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); - EXPECT_SUCCESS(s2n_stuffer_write_bytes(hello_stuffer, protocol_version, S2N_TLS_PROTOCOL_VERSION_LEN)); - EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); + EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); + EXPECT_SUCCESS(s2n_stuffer_write_bytes(hello_stuffer, protocol_version, S2N_TLS_PROTOCOL_VERSION_LEN)); + EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); - EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_EQUAL(context.invoked_count, 1); - if (send_valid_extension) { - /* TLS 1.3 was written to the supported versions extension. If a valid extension was - * sent, the reported client protocol version should be TLS 1.3. - */ - EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), S2N_TLS13); - } else { - /* The reported client protocol version should fall back to the client hello version - * if the supported versions extension is malformed. - */ - EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); - } + if (send_valid_extension) { + /* TLS 1.3 was written to the supported versions extension. If a valid extension was + * sent, the reported client protocol version should be TLS 1.3. + */ + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), S2N_TLS13); + } else { + /* The reported client protocol version should fall back to the client hello version + * if the supported versions extension is malformed. + */ + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); } } - /* Report client hello version if an invalid supported version was received */ - for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS12; client_hello_version++) { - for (uint8_t send_valid_version = 0; send_valid_version <= 1; send_valid_version++) { - DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); - EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); - EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); + /* Report the client hello version if an invalid supported version was received */ + for (uint8_t send_valid_version = 0; send_valid_version <= 1; send_valid_version++) { + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); - uint8_t supported_versions_data[3] = { 0 }; - struct s2n_blob supported_versions_blob = { 0 }; - EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, s2n_array_len(supported_versions_data))); - struct s2n_stuffer supported_versions_stuffer = { 0 }; - EXPECT_SUCCESS(s2n_stuffer_init(&supported_versions_stuffer, &supported_versions_blob)); + uint8_t supported_versions_data[3] = { 0 }; + struct s2n_blob supported_versions_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, + s2n_array_len(supported_versions_data))); + struct s2n_stuffer supported_versions_stuffer = { 0 }; + EXPECT_SUCCESS(s2n_stuffer_init(&supported_versions_stuffer, &supported_versions_blob)); - /* Write length byte */ - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 2)); - /* Write supported version */ - uint8_t valid_supported_version = S2N_TLS13; - uint8_t invalid_supported_version = S2N_TLS13 + 10; - uint8_t supported_version = invalid_supported_version; - if (send_valid_version) { - supported_version = valid_supported_version; - } - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, supported_version / 10)); - EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, supported_version % 10)); + /* Write the length byte. */ + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, 2)); - struct s2n_override_extension_ctx context = { - .extension_blob = supported_versions_blob - }; - EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); + uint8_t valid_supported_version = S2N_TLS13; + uint8_t invalid_supported_version = S2N_TLS13 + 10; - DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), - s2n_connection_ptr_free); - EXPECT_NOT_NULL(client); - EXPECT_SUCCESS(s2n_connection_set_config(client, config)); + uint8_t supported_version = invalid_supported_version; + if (send_valid_version) { + supported_version = valid_supported_version; + } - DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), - s2n_connection_ptr_free); - EXPECT_NOT_NULL(server); - EXPECT_SUCCESS(s2n_connection_set_config(server, config)); + /* Write the supported version. */ + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, supported_version / 10)); + EXPECT_SUCCESS(s2n_stuffer_write_uint8(&supported_versions_stuffer, supported_version % 10)); - struct s2n_stuffer *hello_stuffer = &client->handshake.io; - EXPECT_SUCCESS(s2n_client_hello_send(client)); + struct s2n_override_extension_ctx context = { + .extension_blob = supported_versions_blob + }; + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); - /* Overwrite the client hello version according to the test case. */ - uint8_t protocol_version[S2N_TLS_PROTOCOL_VERSION_LEN] = { 0 }; - protocol_version[0] = client_hello_version / 10; - protocol_version[1] = client_hello_version % 10; + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(client); + EXPECT_SUCCESS(s2n_connection_set_config(client, config)); - EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); - EXPECT_SUCCESS(s2n_stuffer_write_bytes(hello_stuffer, protocol_version, S2N_TLS_PROTOCOL_VERSION_LEN)); - EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(server); + EXPECT_SUCCESS(s2n_connection_set_config(server, config)); - EXPECT_SUCCESS(s2n_client_hello_recv(server)); + struct s2n_stuffer *hello_stuffer = &client->handshake.io; + EXPECT_SUCCESS(s2n_client_hello_send(client)); - if (send_valid_version) { - /* If a valid supported version was sent, the version should be reported regardless of - * the client hello version. - */ - EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), valid_supported_version); - } else { - /* The reported client protocol version should fall back to the client hello version - * if the received supported version is invalid. - */ - EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); - } + /* Overwrite the client hello version according to the test case. */ + uint8_t protocol_version[S2N_TLS_PROTOCOL_VERSION_LEN] = { 0 }; + protocol_version[0] = client_hello_version / 10; + protocol_version[1] = client_hello_version % 10; + + EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); + EXPECT_SUCCESS(s2n_stuffer_write_bytes(hello_stuffer, protocol_version, S2N_TLS_PROTOCOL_VERSION_LEN)); + EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); + + EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_EQUAL(context.invoked_count, 1); + + if (send_valid_version) { + /* If a valid supported version was sent, the version should be reported regardless + * of the client hello version. + */ + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), valid_supported_version); + } else { + /* The reported client protocol version should fall back to the client hello + * version if the received supported version is invalid. + */ + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); } } } END_TEST(); -} \ No newline at end of file +} diff --git a/tls/extensions/s2n_client_supported_versions.c b/tls/extensions/s2n_client_supported_versions.c index 1f104a8c432..577640eecee 100644 --- a/tls/extensions/s2n_client_supported_versions.c +++ b/tls/extensions/s2n_client_supported_versions.c @@ -139,7 +139,7 @@ static S2N_RESULT s2n_client_supported_versions_recv_impl(struct s2n_connection static int s2n_client_supported_versions_recv(struct s2n_connection *conn, struct s2n_stuffer *extension) { - /* For backwards compatibility, the supported versions extension is not used for protocol + /* For backwards compatibility, the supported versions extension isn't used for protocol * version selection if the server doesn't support TLS 1.3. This ensures that TLS 1.2 servers * experience no behavior change due to processing the TLS 1.3 extension. See * https://github.com/aws/s2n-tls/issues/4240. diff --git a/tls/s2n_connection.c b/tls/s2n_connection.c index d2e4757a6da..83cc3caf91b 100644 --- a/tls/s2n_connection.c +++ b/tls/s2n_connection.c @@ -970,12 +970,12 @@ int s2n_connection_get_client_protocol_version(struct s2n_connection *conn) { POSIX_ENSURE_REF(conn); - /* If a server connection doesn't support TLS 1.3, the client protocol version isn't updated - * via the supported versions extension in order to maintain backwards compatibility with - * TLS 1.2 connections. See https://github.com/aws/s2n-tls/issues/4240. + /* For backwards compatibility, the client_protocol_version field isn't updated via the + * supported versions extension on TLS 1.2 servers. See + * https://github.com/aws/s2n-tls/issues/4240. * - * The extension is processed in s2n_connection_get_client_protocol_version() to ensure that - * TLS 1.2 servers report the same client protocol version as TLS 1.3 servers. + * The extension is processed here to ensure that TLS 1.2 servers report the same client + * protocol version to applications as TLS 1.3 servers. */ if (conn->mode == S2N_SERVER && conn->server_protocol_version <= S2N_TLS12) { uint8_t client_supported_version = s2n_unknown_protocol_version;