diff --git a/tests/unit/s2n_get_protocol_version_test.c b/tests/unit/s2n_connection_protocol_versions_test.c similarity index 65% rename from tests/unit/s2n_get_protocol_version_test.c rename to tests/unit/s2n_connection_protocol_versions_test.c index 6b3650c0d3d..c6970eee4cd 100644 --- a/tests/unit/s2n_get_protocol_version_test.c +++ b/tests/unit/s2n_connection_protocol_versions_test.c @@ -18,46 +18,8 @@ #include "testlib/s2n_testlib.h" #include "tls/s2n_tls.h" -struct s2n_override_extension_ctx { - struct s2n_blob extension_blob; - int invoked_count; -}; - -static int s2n_override_supported_versions_cb(struct s2n_connection *conn, void *ctx) -{ - EXPECT_NOT_NULL(conn); - EXPECT_NOT_NULL(ctx); - - struct s2n_override_extension_ctx *context = (struct s2n_override_extension_ctx *) ctx; - context->invoked_count += 1; - - struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(conn); - EXPECT_NOT_NULL(client_hello); - - s2n_extension_type_id supported_versions_id = 0; - EXPECT_SUCCESS(s2n_extension_supported_iana_value_to_id(S2N_EXTENSION_SUPPORTED_VERSIONS, &supported_versions_id)); - - s2n_parsed_extension *supported_versions_extension = &client_hello->extensions.parsed_extensions[supported_versions_id]; - supported_versions_extension->extension_type = S2N_EXTENSION_SUPPORTED_VERSIONS; - supported_versions_extension->extension = context->extension_blob; - - return S2N_SUCCESS; -} - -S2N_RESULT s2n_write_protocol_version(struct s2n_stuffer *stuffer, uint8_t version) -{ - RESULT_ENSURE_REF(stuffer); - - uint8_t protocol_version[S2N_TLS_PROTOCOL_VERSION_LEN] = { 0 }; - protocol_version[0] = version / 10; - protocol_version[1] = version % 10; - - RESULT_GUARD_POSIX(s2n_stuffer_write_bytes(stuffer, protocol_version, S2N_TLS_PROTOCOL_VERSION_LEN)); - - return S2N_RESULT_OK; -} - -S2N_RESULT s2n_write_test_supported_versions_extension(struct s2n_blob *supported_versions_blob, uint8_t version) +static S2N_RESULT s2n_write_test_supported_versions_extension(struct s2n_blob *supported_versions_blob, uint8_t version, + uint8_t extension_length) { RESULT_ENSURE_REF(supported_versions_blob); @@ -65,7 +27,7 @@ S2N_RESULT s2n_write_test_supported_versions_extension(struct s2n_blob *supporte RESULT_GUARD_POSIX(s2n_stuffer_init(&supported_versions_stuffer, supported_versions_blob)); /* Write the length byte. */ - RESULT_GUARD_POSIX(s2n_stuffer_write_uint8(&supported_versions_stuffer, 2)); + RESULT_GUARD_POSIX(s2n_stuffer_write_uint8(&supported_versions_stuffer, extension_length)); /* Write the supported version. */ RESULT_GUARD_POSIX(s2n_stuffer_write_uint8(&supported_versions_stuffer, version / 10)); RESULT_GUARD_POSIX(s2n_stuffer_write_uint8(&supported_versions_stuffer, version % 10)); @@ -73,17 +35,52 @@ S2N_RESULT s2n_write_test_supported_versions_extension(struct s2n_blob *supporte return S2N_RESULT_OK; } -S2N_RESULT s2n_write_malformed_supported_versions_extension(struct s2n_blob *supported_versions_blob) +struct s2n_overwrite_client_hello_ctx { + uint8_t client_hello_version; + uint8_t client_supported_version; + uint8_t extension_length; + + uint8_t supported_versions_data[3]; + int invoked_count; +}; + +static int s2n_overwrite_client_hello_cb(struct s2n_connection *conn, void *ctx) { - RESULT_ENSURE_REF(supported_versions_blob); + EXPECT_NOT_NULL(conn); + EXPECT_NOT_NULL(ctx); - struct s2n_stuffer supported_versions_stuffer = { 0 }; - RESULT_GUARD_POSIX(s2n_stuffer_init(&supported_versions_stuffer, supported_versions_blob)); + struct s2n_overwrite_client_hello_ctx *context = (struct s2n_overwrite_client_hello_ctx *) ctx; + context->invoked_count += 1; + + struct s2n_client_hello *client_hello = s2n_connection_get_client_hello(conn); + EXPECT_NOT_NULL(client_hello); - /* Write an invalid length byte. */ - RESULT_GUARD_POSIX(s2n_stuffer_write_uint8(&supported_versions_stuffer, 11)); + if (context->extension_length) { + struct s2n_blob supported_versions_blob = { 0 }; + EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, context->supported_versions_data, + sizeof(context->supported_versions_data))); - return S2N_RESULT_OK; + s2n_extension_type_id supported_versions_id = 0; + EXPECT_SUCCESS(s2n_extension_supported_iana_value_to_id(S2N_EXTENSION_SUPPORTED_VERSIONS, &supported_versions_id)); + s2n_parsed_extension *extension = &client_hello->extensions.parsed_extensions[supported_versions_id]; + + EXPECT_OK(s2n_write_test_supported_versions_extension(&supported_versions_blob, + context->client_supported_version, context->extension_length)); + + extension->extension_type = S2N_EXTENSION_SUPPORTED_VERSIONS; + extension->extension = supported_versions_blob; + } + + /* The client version fields are set when parsing the client hello before the client hello + * callback is invoked. The version fields are overridden to fake receiving a client hello with + * a different version. + */ + if (context->client_hello_version) { + conn->client_hello_version = context->client_hello_version; + conn->client_protocol_version = context->client_hello_version; + } + + return S2N_SUCCESS; } int main(int argc, char **argv) @@ -104,32 +101,25 @@ int main(int argc, char **argv) /* Test protocol version getters on the server when a supported versions extension is received */ for (uint8_t server_version = S2N_TLS12; server_version <= S2N_TLS13; server_version++) { - for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS13; client_hello_version++) { + for (uint8_t client_hello_version = S2N_SSLv3; client_hello_version <= S2N_TLS12; client_hello_version++) { for (uint8_t client_supported_version = S2N_SSLv3; client_supported_version <= S2N_TLS13; client_supported_version++) { DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); if (server_version == S2N_TLS12) { EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); - } else { - if (!s2n_is_tls13_fully_supported()) { - continue; - } + } else if (s2n_is_tls13_fully_supported()) { EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all")); + } else { + continue; } - uint8_t supported_versions_data[3] = { 0 }; - struct s2n_blob supported_versions_blob = { 0 }; - EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, sizeof(supported_versions_data))); - EXPECT_OK(s2n_write_test_supported_versions_extension(&supported_versions_blob, client_supported_version)); - - /* The override_supported_versions client hello callback is used to overwrite the - * supported versions extension before the extension is processed. - */ - struct s2n_override_extension_ctx context = { - .extension_blob = supported_versions_blob + struct s2n_overwrite_client_hello_ctx context = { + .client_hello_version = client_hello_version, + .client_supported_version = client_supported_version, + .extension_length = 2, }; - EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_overwrite_client_hello_cb, &context)); DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); @@ -141,15 +131,11 @@ int main(int argc, char **argv) EXPECT_NOT_NULL(server); EXPECT_SUCCESS(s2n_connection_set_config(server, config)); - EXPECT_SUCCESS(s2n_client_hello_send(client)); + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connections_set_io_pair(client, server, &io_pair)); - /* Overwrite the client hello version according to the test case. */ - struct s2n_stuffer *hello_stuffer = &client->handshake.io; - EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); - EXPECT_OK(s2n_write_protocol_version(hello_stuffer, client_hello_version)); - EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); - - EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_OK(s2n_negotiate_test_server_and_client_until_message(server, client, SERVER_HELLO)); EXPECT_EQUAL(context.invoked_count, 1); /* Ensure that a supported versions extension was received. */ @@ -159,6 +145,7 @@ int main(int argc, char **argv) EXPECT_TRUE(supported_versions_received); EXPECT_EQUAL(s2n_connection_get_server_protocol_version(server), server_version); + EXPECT_EQUAL(s2n_connection_get_client_hello_version(server), client_hello_version); /* The reported client protocol version should always match the version specified * in the supported versions extension, even for TLS 1.2 servers which don't @@ -166,11 +153,6 @@ int main(int argc, char **argv) */ EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_supported_version); - /* Clients indicate support for TLS 1.3 in the supported versions extension, not - * the client hello version. A client hello version above TLS 1.2 is never reported. - */ - EXPECT_EQUAL(s2n_connection_get_client_hello_version(server), MIN(client_hello_version, S2N_TLS12)); - uint8_t actual_protocol_version = s2n_connection_get_actual_protocol_version(server); if (server_version == S2N_TLS12) { /* For backwards compatibility, TLS 1.2 servers always use the client hello @@ -203,13 +185,17 @@ int main(int argc, char **argv) if (server_version == S2N_TLS12) { EXPECT_SUCCESS(s2n_config_set_cipher_preferences(server_config, "test_all_tls12")); - } else { - if (!s2n_is_tls13_fully_supported()) { - continue; - } + } else if (s2n_is_tls13_fully_supported()) { EXPECT_SUCCESS(s2n_config_set_cipher_preferences(server_config, "test_all")); + } else { + continue; } + struct s2n_overwrite_client_hello_ctx context = { + .client_hello_version = client_hello_version, + }; + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(server_config, s2n_overwrite_client_hello_cb, &context)); + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); EXPECT_NOT_NULL(client); @@ -220,15 +206,12 @@ int main(int argc, char **argv) EXPECT_NOT_NULL(server); EXPECT_SUCCESS(s2n_connection_set_config(server, server_config)); - EXPECT_SUCCESS(s2n_client_hello_send(client)); + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connections_set_io_pair(client, server, &io_pair)); - /* Overwrite the client hello version according to the test case. */ - struct s2n_stuffer *hello_stuffer = &client->handshake.io; - EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); - EXPECT_OK(s2n_write_protocol_version(hello_stuffer, client_hello_version)); - EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); - - EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_OK(s2n_negotiate_test_server_and_client_until_message(server, client, SERVER_HELLO)); + EXPECT_EQUAL(context.invoked_count, 1); /* Ensure that a supported versions extension wasn't received. */ bool supported_versions_received = false; @@ -264,8 +247,7 @@ int main(int argc, char **argv) DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); - EXPECT_SUCCESS(s2n_connection_set_io_pair(client, &io_pair)); - EXPECT_SUCCESS(s2n_connection_set_io_pair(server, &io_pair)); + EXPECT_SUCCESS(s2n_connections_set_io_pair(client, server, &io_pair)); EXPECT_OK(s2n_negotiate_test_server_and_client_until_message(server, client, SERVER_CERT)); @@ -283,15 +265,13 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); - uint8_t supported_versions_data[1] = { 0 }; - struct s2n_blob supported_versions_blob = { 0 }; - EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, sizeof(supported_versions_data))); - EXPECT_OK(s2n_write_malformed_supported_versions_extension(&supported_versions_blob)); - - struct s2n_override_extension_ctx context = { - .extension_blob = supported_versions_blob + struct s2n_overwrite_client_hello_ctx context = { + .client_hello_version = client_hello_version, + .client_supported_version = S2N_TLS13, + /* Write an invalid length */ + .extension_length = 11, }; - EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_overwrite_client_hello_cb, &context)); DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); @@ -303,15 +283,11 @@ int main(int argc, char **argv) EXPECT_NOT_NULL(server); EXPECT_SUCCESS(s2n_connection_set_config(server, config)); - EXPECT_SUCCESS(s2n_client_hello_send(client)); - - /* Overwrite the client hello version according to the test case. */ - struct s2n_stuffer *hello_stuffer = &client->handshake.io; - EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); - EXPECT_OK(s2n_write_protocol_version(hello_stuffer, client_hello_version)); - EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connections_set_io_pair(client, server, &io_pair)); - EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_OK(s2n_negotiate_test_server_and_client_until_message(server, client, SERVER_HELLO)); EXPECT_EQUAL(context.invoked_count, 1); /* Ensure that a supported versions extension was received. */ @@ -331,17 +307,54 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); - uint8_t invalid_supported_version = S2N_TLS13 + 10; + struct s2n_overwrite_client_hello_ctx context = { + .client_hello_version = client_hello_version, + /* Write an invalid version */ + .client_supported_version = S2N_TLS13 + 10, + .extension_length = 2, + }; + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_overwrite_client_hello_cb, &context)); - uint8_t supported_versions_data[3] = { 0 }; - struct s2n_blob supported_versions_blob = { 0 }; - EXPECT_SUCCESS(s2n_blob_init(&supported_versions_blob, supported_versions_data, sizeof(supported_versions_data))); - EXPECT_OK(s2n_write_test_supported_versions_extension(&supported_versions_blob, invalid_supported_version)); + DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(client); + EXPECT_SUCCESS(s2n_connection_set_config(client, config)); + + DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(server); + EXPECT_SUCCESS(s2n_connection_set_config(server, config)); + + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connections_set_io_pair(client, server, &io_pair)); + + EXPECT_OK(s2n_negotiate_test_server_and_client_until_message(server, client, SERVER_HELLO)); + EXPECT_EQUAL(context.invoked_count, 1); + + /* Ensure that a supported versions extension was received. */ + bool supported_versions_received = false; + EXPECT_SUCCESS(s2n_client_hello_has_extension(&server->client_hello, S2N_EXTENSION_SUPPORTED_VERSIONS, + &supported_versions_received)); + EXPECT_TRUE(supported_versions_received); - struct s2n_override_extension_ctx context = { - .extension_blob = supported_versions_blob + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); + } + + /* Ensure that TLS 1.3 servers report an unknown protocol version if a supported versions + * extension can't be processed + */ + if (s2n_is_tls13_fully_supported()) { + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all")); + + struct s2n_overwrite_client_hello_ctx context = { + .client_supported_version = S2N_TLS13, + /* Write an invalid length */ + .extension_length = 11, }; - EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_override_supported_versions_cb, &context)); + EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_overwrite_client_hello_cb, &context)); DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); @@ -352,16 +365,14 @@ int main(int argc, char **argv) s2n_connection_ptr_free); EXPECT_NOT_NULL(server); EXPECT_SUCCESS(s2n_connection_set_config(server, config)); + EXPECT_SUCCESS(s2n_connection_set_blinding(server, S2N_SELF_SERVICE_BLINDING)); - struct s2n_stuffer *hello_stuffer = &client->handshake.io; - EXPECT_SUCCESS(s2n_client_hello_send(client)); - - /* Overwrite the client hello version according to the test case. */ - EXPECT_SUCCESS(s2n_stuffer_rewrite(hello_stuffer)); - EXPECT_OK(s2n_write_protocol_version(hello_stuffer, client_hello_version)); - EXPECT_SUCCESS(s2n_stuffer_write(&server->handshake.io, &hello_stuffer->blob)); + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connections_set_io_pair(client, server, &io_pair)); - EXPECT_SUCCESS(s2n_client_hello_recv(server)); + EXPECT_ERROR_WITH_ERRNO(s2n_negotiate_test_server_and_client_until_message(server, client, SERVER_HELLO), + S2N_ERR_BAD_MESSAGE); EXPECT_EQUAL(context.invoked_count, 1); /* Ensure that a supported versions extension was received. */ @@ -370,7 +381,9 @@ int main(int argc, char **argv) &supported_versions_received)); EXPECT_TRUE(supported_versions_received); - EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), client_hello_version); + EXPECT_EQUAL(s2n_connection_get_server_protocol_version(server), S2N_TLS13); + EXPECT_EQUAL(s2n_connection_get_client_protocol_version(server), s2n_unknown_protocol_version); + EXPECT_EQUAL(s2n_connection_get_actual_protocol_version(server), s2n_unknown_protocol_version); } END_TEST(); diff --git a/tls/extensions/s2n_client_supported_versions.c b/tls/extensions/s2n_client_supported_versions.c index 577640eecee..307852235fc 100644 --- a/tls/extensions/s2n_client_supported_versions.c +++ b/tls/extensions/s2n_client_supported_versions.c @@ -128,11 +128,16 @@ static S2N_RESULT s2n_client_supported_versions_recv_impl(struct s2n_connection RESULT_ENSURE_REF(conn); RESULT_ENSURE_REF(extension); - RESULT_GUARD_POSIX(s2n_extensions_client_supported_versions_process(conn, extension, &conn->client_protocol_version, - &conn->actual_protocol_version)); + uint8_t client_protocol_version = s2n_unknown_protocol_version; + uint8_t actual_protocol_version = s2n_unknown_protocol_version; + RESULT_GUARD_POSIX(s2n_extensions_client_supported_versions_process(conn, extension, &client_protocol_version, + &actual_protocol_version)); - RESULT_ENSURE(conn->client_protocol_version != s2n_unknown_protocol_version, S2N_ERR_UNKNOWN_PROTOCOL_VERSION); - RESULT_ENSURE(conn->actual_protocol_version != s2n_unknown_protocol_version, S2N_ERR_PROTOCOL_VERSION_UNSUPPORTED); + RESULT_ENSURE(client_protocol_version != s2n_unknown_protocol_version, S2N_ERR_UNKNOWN_PROTOCOL_VERSION); + RESULT_ENSURE(actual_protocol_version != s2n_unknown_protocol_version, S2N_ERR_PROTOCOL_VERSION_UNSUPPORTED); + + conn->client_protocol_version = client_protocol_version; + conn->actual_protocol_version = actual_protocol_version; return S2N_RESULT_OK; } @@ -156,6 +161,9 @@ static int s2n_client_supported_versions_recv(struct s2n_connection *conn, struc return S2N_SUCCESS; } + conn->client_protocol_version = s2n_unknown_protocol_version; + conn->actual_protocol_version = s2n_unknown_protocol_version; + s2n_result result = s2n_client_supported_versions_recv_impl(conn, extension); if (s2n_result_is_error(result)) { s2n_queue_reader_unsupported_protocol_version_alert(conn);