From 87016c732850981514e402f386d268ab30aa7af2 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 17 Oct 2024 17:03:27 +0530 Subject: [PATCH 1/5] Apply secure defaults to TLS configs provided through GetConfigForClient --- credentials/tls.go | 30 +++- credentials/tls_ext_test.go | 307 ++++++++++++++++++++++++++---------- 2 files changed, 245 insertions(+), 92 deletions(-) diff --git a/credentials/tls.go b/credentials/tls.go index 4114358545ef..a361bffb259a 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -200,25 +200,41 @@ var tls12ForbiddenCipherSuites = map[uint16]struct{}{ // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { - tc := &tlsCreds{credinternal.CloneTLSConfig(c)} - tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) + cfg := applyDefaults(c) + if cfg.GetConfigForClient != nil { + oldFn := cfg.GetConfigForClient + cfg.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + cfgForClient, err := oldFn(hello) + if err != nil || cfgForClient == nil { + return cfgForClient, err + } + return applyDefaults(cfgForClient), nil + } + } + tc := &tlsCreds{config: cfg} + return tc +} + +func applyDefaults(c *tls.Config) *tls.Config { + config := credinternal.CloneTLSConfig(c) + config.NextProtos = credinternal.AppendH2ToNextProtos(config.NextProtos) // If the user did not configure a MinVersion and did not configure a // MaxVersion < 1.2, use MinVersion=1.2, which is required by // https://datatracker.ietf.org/doc/html/rfc7540#section-9.2 - if tc.config.MinVersion == 0 && (tc.config.MaxVersion == 0 || tc.config.MaxVersion >= tls.VersionTLS12) { - tc.config.MinVersion = tls.VersionTLS12 + if config.MinVersion == 0 && (config.MaxVersion == 0 || config.MaxVersion >= tls.VersionTLS12) { + config.MinVersion = tls.VersionTLS12 } // If the user did not configure CipherSuites, use all "secure" cipher // suites reported by the TLS package, but remove some explicitly forbidden // by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A - if tc.config.CipherSuites == nil { + if config.CipherSuites == nil { for _, cs := range tls.CipherSuites() { if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok { - tc.config.CipherSuites = append(tc.config.CipherSuites, cs.ID) + config.CipherSuites = append(config.CipherSuites, cs.ID) } } } - return tc + return config } // NewClientTLSFromCert constructs TLS credentials from the provided root diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index c817777b2f89..76e6e5a3b5dd 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -79,43 +79,67 @@ func (s) TestTLS_MinVersion12(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - // Create server creds without a minimum version. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ // MinVersion should be set to 1.2 by gRPC by default. Certificates: []tls.Certificate{serverCert}, - }) - ss := stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + } + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, }, } - // Create client creds that supports V1.0-V1.1. - clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - MinVersion: tls.VersionTLS10, - MaxVersion: tls.VersionTLS11, - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds without a minimum version. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } - // Start server and client separately, because Start() blocks on a - // successful connection, which we will not get. - if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { - t.Fatalf("Error starting server: %v", err) - } - defer ss.Stop() + // Create client creds that supports V1.0-V1.1. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS11, + }) - cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) - if err != nil { - t.Fatalf("grpc.NewClient error: %v", err) - } - defer cc.Close() + // Start server and client separately, because Start() blocks on a + // successful connection, which we will not get. + if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() - client := testgrpc.NewTestServiceClient(cc) + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.NewClient error: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + + const wantStr = "authentication handshake failed" + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { + t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + } - const wantStr = "authentication handshake failed" - if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { - t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + }) } } @@ -129,35 +153,58 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { for _, cs := range tls.CipherSuites() { allCipherSuites = append(allCipherSuites, cs.ID) } - - // Create server creds that allow v1.0. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ MinVersion: tls.VersionTLS10, Certificates: []tls.Certificate{serverCert}, CipherSuites: allCipherSuites, - }) - ss := stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + } + + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, }, } - // Create client creds that supports V1.0-V1.1. - clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - CipherSuites: allCipherSuites, - MinVersion: tls.VersionTLS10, - MaxVersion: tls.VersionTLS11, - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds that allow v1.0. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } - if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { - t.Fatalf("Error starting stub server: %v", err) - } - defer ss.Stop() + // Create client creds that supports V1.0-V1.1. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: allCipherSuites, + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS11, + }) - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { - t.Fatalf("EmptyCall err = %v; want ", err) + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } + }) } } @@ -165,43 +212,66 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { func (s) TestTLS_CipherSuites(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - - // Create server creds without cipher suites. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ Certificates: []tls.Certificate{serverCert}, - }) - ss := stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + } + + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, }, } - // Create client creds that use a forbidden suite only. - clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, - MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds without cipher suites. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } - // Start server and client separately, because Start() blocks on a - // successful connection, which we will not get. - if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { - t.Fatalf("Error starting server: %v", err) - } - defer ss.Stop() + // Create client creds that use a forbidden suite only. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + }) - cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) - if err != nil { - t.Fatalf("grpc.NewClient error: %v", err) - } - defer cc.Close() + // Start server and client separately, because Start() blocks on a + // successful connection, which we will not get. + if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.NewClient error: %v", err) + } + defer cc.Close() - client := testgrpc.NewTestServiceClient(cc) + client := testgrpc.NewTestServiceClient(cc) - const wantStr = "authentication handshake failed" - if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { - t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + const wantStr = "authentication handshake failed" + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { + t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + } + }) } } @@ -210,23 +280,90 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - // Create server that allows only a forbidden cipher suite. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ Certificates: []tls.Certificate{serverCert}, CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } + + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server that allows only a forbidden cipher suite. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + // Create server that allows only a forbidden cipher suite. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + }) + + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } + }) + } +} + +// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configures +// correctly for a server that doesn't specify the NextProtos field and uses +// GetConfigForClient to provide the TLS config during the handshake. +func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { + initialVal := envconfig.EnforceALPNEnabled + defer func() { + envconfig.EnforceALPNEnabled = initialVal + }() + envconfig.EnforceALPNEnabled = true + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Create a server that doesn't set the NextProtos field. + serverCreds := credentials.NewTLS(&tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, nil + }, }) + ss := stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, } - // Create server that allows only a forbidden cipher suite. clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, - MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + ServerName: serverName, + RootCAs: certPool, }) if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { From 67dfb321df2c0ae9c176d831488abe832ad6553d Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 17 Oct 2024 20:05:29 +0530 Subject: [PATCH 2/5] Add cases which return nil from GetConfigForClient --- credentials/tls_ext_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index 76e6e5a3b5dd..ee7cfcdaf725 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -83,6 +83,11 @@ func (s) TestTLS_MinVersion12(t *testing.T) { // MinVersion should be set to 1.2 by gRPC by default. Certificates: []tls.Certificate{serverCert}, } + noOpGetCfgForClient := serverTLS.Clone() + noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + testCases := []struct { name string serverTLS *tls.Config @@ -91,6 +96,10 @@ func (s) TestTLS_MinVersion12(t *testing.T) { name: "base_case", serverTLS: serverTLS, }, + { + name: "fallback_to_base", + serverTLS: noOpGetCfgForClient, + }, { name: "dynamic_using_get_config_for_client", serverTLS: &tls.Config{ @@ -159,6 +168,11 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { CipherSuites: allCipherSuites, } + noOpGetCfgForClient := serverTLS.Clone() + noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + testCases := []struct { name string serverTLS *tls.Config @@ -167,6 +181,10 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { name: "base_case", serverTLS: serverTLS, }, + { + name: "fallback_to_base", + serverTLS: noOpGetCfgForClient, + }, { name: "dynamic_using_get_config_for_client", serverTLS: &tls.Config{ @@ -215,6 +233,10 @@ func (s) TestTLS_CipherSuites(t *testing.T) { serverTLS := &tls.Config{ Certificates: []tls.Certificate{serverCert}, } + noOpGetCfgForClient := serverTLS.Clone() + noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } testCases := []struct { name string @@ -224,6 +246,10 @@ func (s) TestTLS_CipherSuites(t *testing.T) { name: "base_case", serverTLS: serverTLS, }, + { + name: "fallback_to_base", + serverTLS: noOpGetCfgForClient, + }, { name: "dynamic_using_get_config_for_client", serverTLS: &tls.Config{ @@ -284,6 +310,10 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { Certificates: []tls.Certificate{serverCert}, CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, } + noOpGetCfgForClient := serverTLS.Clone() + noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } testCases := []struct { name string @@ -293,6 +323,10 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { name: "base_case", serverTLS: serverTLS, }, + { + name: "fallback_to_base", + serverTLS: noOpGetCfgForClient, + }, { name: "dynamic_using_get_config_for_client", serverTLS: &tls.Config{ From 042ec4ba8bda7e6530ea717a747bdb2c49f33ef9 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 17 Oct 2024 22:48:58 +0530 Subject: [PATCH 3/5] better variable names --- credentials/tls.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/credentials/tls.go b/credentials/tls.go index a361bffb259a..e163a473df93 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -200,10 +200,10 @@ var tls12ForbiddenCipherSuites = map[uint16]struct{}{ // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { - cfg := applyDefaults(c) - if cfg.GetConfigForClient != nil { - oldFn := cfg.GetConfigForClient - cfg.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + config := applyDefaults(c) + if config.GetConfigForClient != nil { + oldFn := config.GetConfigForClient + config.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { cfgForClient, err := oldFn(hello) if err != nil || cfgForClient == nil { return cfgForClient, err @@ -211,8 +211,7 @@ func NewTLS(c *tls.Config) TransportCredentials { return applyDefaults(cfgForClient), nil } } - tc := &tlsCreds{config: cfg} - return tc + return &tlsCreds{config: config} } func applyDefaults(c *tls.Config) *tls.Config { From 0b1146d634efba63fe6d815ea3084dd351b37fff Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Fri, 18 Oct 2024 10:53:56 +0530 Subject: [PATCH 4/5] Use a getter for tls config in tests --- credentials/tls_ext_test.go | 125 ++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index ee7cfcdaf725..91726cd19b7e 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -83,29 +83,33 @@ func (s) TestTLS_MinVersion12(t *testing.T) { // MinVersion should be set to 1.2 by gRPC by default. Certificates: []tls.Certificate{serverCert}, } - noOpGetCfgForClient := serverTLS.Clone() - noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { - return nil, nil - } testCases := []struct { name string - serverTLS *tls.Config + serverTLS func() *tls.Config }{ { name: "base_case", - serverTLS: serverTLS, + serverTLS: func() *tls.Config { return serverTLS }, }, { - name: "fallback_to_base", - serverTLS: noOpGetCfgForClient, + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := serverTLS.Clone() + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, }, { name: "dynamic_using_get_config_for_client", - serverTLS: &tls.Config{ - GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil - }, + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + } }, }, } @@ -113,7 +117,7 @@ func (s) TestTLS_MinVersion12(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create server creds without a minimum version. - serverCreds := credentials.NewTLS(tc.serverTLS) + serverCreds := credentials.NewTLS(tc.serverTLS()) ss := stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil @@ -168,29 +172,32 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { CipherSuites: allCipherSuites, } - noOpGetCfgForClient := serverTLS.Clone() - noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { - return nil, nil - } - testCases := []struct { name string - serverTLS *tls.Config + serverTLS func() *tls.Config }{ { name: "base_case", - serverTLS: serverTLS, + serverTLS: func() *tls.Config { return serverTLS }, }, { - name: "fallback_to_base", - serverTLS: noOpGetCfgForClient, + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := serverTLS.Clone() + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, }, { name: "dynamic_using_get_config_for_client", - serverTLS: &tls.Config{ - GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil - }, + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + } }, }, } @@ -198,7 +205,7 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create server creds that allow v1.0. - serverCreds := credentials.NewTLS(tc.serverTLS) + serverCreds := credentials.NewTLS(tc.serverTLS()) ss := stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil @@ -233,29 +240,32 @@ func (s) TestTLS_CipherSuites(t *testing.T) { serverTLS := &tls.Config{ Certificates: []tls.Certificate{serverCert}, } - noOpGetCfgForClient := serverTLS.Clone() - noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { - return nil, nil - } - testCases := []struct { name string - serverTLS *tls.Config + serverTLS func() *tls.Config }{ { name: "base_case", - serverTLS: serverTLS, + serverTLS: func() *tls.Config { return serverTLS }, }, { - name: "fallback_to_base", - serverTLS: noOpGetCfgForClient, + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := serverTLS.Clone() + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, }, { name: "dynamic_using_get_config_for_client", - serverTLS: &tls.Config{ - GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil - }, + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + } }, }, } @@ -263,7 +273,7 @@ func (s) TestTLS_CipherSuites(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create server creds without cipher suites. - serverCreds := credentials.NewTLS(tc.serverTLS) + serverCreds := credentials.NewTLS(tc.serverTLS()) ss := stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil @@ -310,29 +320,32 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { Certificates: []tls.Certificate{serverCert}, CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, } - noOpGetCfgForClient := serverTLS.Clone() - noOpGetCfgForClient.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { - return nil, nil - } - testCases := []struct { name string - serverTLS *tls.Config + serverTLS func() *tls.Config }{ { name: "base_case", - serverTLS: serverTLS, + serverTLS: func() *tls.Config { return serverTLS }, }, { - name: "fallback_to_base", - serverTLS: noOpGetCfgForClient, + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := serverTLS.Clone() + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, }, { name: "dynamic_using_get_config_for_client", - serverTLS: &tls.Config{ - GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil - }, + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + } }, }, } @@ -340,7 +353,7 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create server that allows only a forbidden cipher suite. - serverCreds := credentials.NewTLS(tc.serverTLS) + serverCreds := credentials.NewTLS(tc.serverTLS()) ss := stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil @@ -367,7 +380,7 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { } } -// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configures +// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured // correctly for a server that doesn't specify the NextProtos field and uses // GetConfigForClient to provide the TLS config during the handshake. func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { From 5da078b9c2fa7b5199b1bd0a7ad928055914247b Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Mon, 21 Oct 2024 10:38:17 +0530 Subject: [PATCH 5/5] Inline TLS configs in test cases --- credentials/tls_ext_test.go | 94 +++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index 91726cd19b7e..22881a6f497a 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -79,23 +79,26 @@ func (s) TestTLS_MinVersion12(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - serverTLS := &tls.Config{ - // MinVersion should be set to 1.2 by gRPC by default. - Certificates: []tls.Certificate{serverCert}, - } - testCases := []struct { name string serverTLS func() *tls.Config }{ { - name: "base_case", - serverTLS: func() *tls.Config { return serverTLS }, + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + // MinVersion should be set to 1.2 by gRPC by default. + Certificates: []tls.Certificate{serverCert}, + } + }, }, { name: "fallback_to_base", serverTLS: func() *tls.Config { - config := serverTLS.Clone() + config := &tls.Config{ + // MinVersion should be set to 1.2 by gRPC by default. + Certificates: []tls.Certificate{serverCert}, + } config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } @@ -107,7 +110,10 @@ func (s) TestTLS_MinVersion12(t *testing.T) { serverTLS: func() *tls.Config { return &tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil + return &tls.Config{ + // MinVersion should be set to 1.2 by gRPC by default. + Certificates: []tls.Certificate{serverCert}, + }, nil }, } }, @@ -166,24 +172,28 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { for _, cs := range tls.CipherSuites() { allCipherSuites = append(allCipherSuites, cs.ID) } - serverTLS := &tls.Config{ - MinVersion: tls.VersionTLS10, - Certificates: []tls.Certificate{serverCert}, - CipherSuites: allCipherSuites, - } - testCases := []struct { name string serverTLS func() *tls.Config }{ { - name: "base_case", - serverTLS: func() *tls.Config { return serverTLS }, + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + MinVersion: tls.VersionTLS10, + Certificates: []tls.Certificate{serverCert}, + CipherSuites: allCipherSuites, + } + }, }, { name: "fallback_to_base", serverTLS: func() *tls.Config { - config := serverTLS.Clone() + config := &tls.Config{ + MinVersion: tls.VersionTLS10, + Certificates: []tls.Certificate{serverCert}, + CipherSuites: allCipherSuites, + } config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } @@ -195,7 +205,11 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { serverTLS: func() *tls.Config { return &tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil + return &tls.Config{ + MinVersion: tls.VersionTLS10, + Certificates: []tls.Certificate{serverCert}, + CipherSuites: allCipherSuites, + }, nil }, } }, @@ -237,21 +251,24 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { func (s) TestTLS_CipherSuites(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - serverTLS := &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - } testCases := []struct { name string serverTLS func() *tls.Config }{ { - name: "base_case", - serverTLS: func() *tls.Config { return serverTLS }, + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + }, }, { name: "fallback_to_base", serverTLS: func() *tls.Config { - config := serverTLS.Clone() + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } @@ -263,7 +280,9 @@ func (s) TestTLS_CipherSuites(t *testing.T) { serverTLS: func() *tls.Config { return &tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, nil }, } }, @@ -316,22 +335,26 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - serverTLS := &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, - } testCases := []struct { name string serverTLS func() *tls.Config }{ { - name: "base_case", - serverTLS: func() *tls.Config { return serverTLS }, + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } + }, }, { name: "fallback_to_base", serverTLS: func() *tls.Config { - config := serverTLS.Clone() + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } @@ -343,7 +366,10 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { serverTLS: func() *tls.Config { return &tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return serverTLS, nil + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + }, nil }, } },