diff --git a/credentials/tls.go b/credentials/tls.go index 4114358545ef..a361bffb259a 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -200,25 +200,41 @@ var tls12ForbiddenCipherSuites = map[uint16]struct{}{ // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { - tc := &tlsCreds{credinternal.CloneTLSConfig(c)} - tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) + cfg := applyDefaults(c) + if cfg.GetConfigForClient != nil { + oldFn := cfg.GetConfigForClient + cfg.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + cfgForClient, err := oldFn(hello) + if err != nil || cfgForClient == nil { + return cfgForClient, err + } + return applyDefaults(cfgForClient), nil + } + } + tc := &tlsCreds{config: cfg} + return tc +} + +func applyDefaults(c *tls.Config) *tls.Config { + config := credinternal.CloneTLSConfig(c) + config.NextProtos = credinternal.AppendH2ToNextProtos(config.NextProtos) // If the user did not configure a MinVersion and did not configure a // MaxVersion < 1.2, use MinVersion=1.2, which is required by // https://datatracker.ietf.org/doc/html/rfc7540#section-9.2 - if tc.config.MinVersion == 0 && (tc.config.MaxVersion == 0 || tc.config.MaxVersion >= tls.VersionTLS12) { - tc.config.MinVersion = tls.VersionTLS12 + if config.MinVersion == 0 && (config.MaxVersion == 0 || config.MaxVersion >= tls.VersionTLS12) { + config.MinVersion = tls.VersionTLS12 } // If the user did not configure CipherSuites, use all "secure" cipher // suites reported by the TLS package, but remove some explicitly forbidden // by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A - if tc.config.CipherSuites == nil { + if config.CipherSuites == nil { for _, cs := range tls.CipherSuites() { if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok { - tc.config.CipherSuites = append(tc.config.CipherSuites, cs.ID) + config.CipherSuites = append(config.CipherSuites, cs.ID) } } } - return tc + return config } // NewClientTLSFromCert constructs TLS credentials from the provided root diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index c817777b2f89..76e6e5a3b5dd 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -79,43 +79,67 @@ func (s) TestTLS_MinVersion12(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - // Create server creds without a minimum version. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ // MinVersion should be set to 1.2 by gRPC by default. Certificates: []tls.Certificate{serverCert}, - }) - ss := stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + } + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, }, } - // Create client creds that supports V1.0-V1.1. - clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - MinVersion: tls.VersionTLS10, - MaxVersion: tls.VersionTLS11, - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds without a minimum version. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } - // Start server and client separately, because Start() blocks on a - // successful connection, which we will not get. - if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { - t.Fatalf("Error starting server: %v", err) - } - defer ss.Stop() + // Create client creds that supports V1.0-V1.1. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS11, + }) - cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) - if err != nil { - t.Fatalf("grpc.NewClient error: %v", err) - } - defer cc.Close() + // Start server and client separately, because Start() blocks on a + // successful connection, which we will not get. + if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() - client := testgrpc.NewTestServiceClient(cc) + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.NewClient error: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + + const wantStr = "authentication handshake failed" + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { + t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + } - const wantStr = "authentication handshake failed" - if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { - t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + }) } } @@ -129,35 +153,58 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { for _, cs := range tls.CipherSuites() { allCipherSuites = append(allCipherSuites, cs.ID) } - - // Create server creds that allow v1.0. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ MinVersion: tls.VersionTLS10, Certificates: []tls.Certificate{serverCert}, CipherSuites: allCipherSuites, - }) - ss := stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + } + + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, }, } - // Create client creds that supports V1.0-V1.1. - clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - CipherSuites: allCipherSuites, - MinVersion: tls.VersionTLS10, - MaxVersion: tls.VersionTLS11, - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds that allow v1.0. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } - if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { - t.Fatalf("Error starting stub server: %v", err) - } - defer ss.Stop() + // Create client creds that supports V1.0-V1.1. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: allCipherSuites, + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS11, + }) - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { - t.Fatalf("EmptyCall err = %v; want ", err) + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } + }) } } @@ -165,43 +212,66 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) { func (s) TestTLS_CipherSuites(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - - // Create server creds without cipher suites. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ Certificates: []tls.Certificate{serverCert}, - }) - ss := stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + } + + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, }, } - // Create client creds that use a forbidden suite only. - clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, - MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds without cipher suites. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } - // Start server and client separately, because Start() blocks on a - // successful connection, which we will not get. - if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { - t.Fatalf("Error starting server: %v", err) - } - defer ss.Stop() + // Create client creds that use a forbidden suite only. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + }) - cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) - if err != nil { - t.Fatalf("grpc.NewClient error: %v", err) - } - defer cc.Close() + // Start server and client separately, because Start() blocks on a + // successful connection, which we will not get. + if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.NewClient error: %v", err) + } + defer cc.Close() - client := testgrpc.NewTestServiceClient(cc) + client := testgrpc.NewTestServiceClient(cc) - const wantStr = "authentication handshake failed" - if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { - t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + const wantStr = "authentication handshake failed" + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { + t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + } + }) } } @@ -210,23 +280,90 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - // Create server that allows only a forbidden cipher suite. - serverCreds := credentials.NewTLS(&tls.Config{ + serverTLS := &tls.Config{ Certificates: []tls.Certificate{serverCert}, CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } + + testCases := []struct { + name string + serverTLS *tls.Config + }{ + { + name: "base_case", + serverTLS: serverTLS, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return serverTLS, nil + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server that allows only a forbidden cipher suite. + serverCreds := credentials.NewTLS(tc.serverTLS) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + // Create server that allows only a forbidden cipher suite. + clientCreds := credentials.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + }) + + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } + }) + } +} + +// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configures +// correctly for a server that doesn't specify the NextProtos field and uses +// GetConfigForClient to provide the TLS config during the handshake. +func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { + initialVal := envconfig.EnforceALPNEnabled + defer func() { + envconfig.EnforceALPNEnabled = initialVal + }() + envconfig.EnforceALPNEnabled = true + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Create a server that doesn't set the NextProtos field. + serverCreds := credentials.NewTLS(&tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, nil + }, }) + ss := stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, } - // Create server that allows only a forbidden cipher suite. clientCreds := credentials.NewTLS(&tls.Config{ - ServerName: serverName, - RootCAs: certPool, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, - MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + ServerName: serverName, + RootCAs: certPool, }) if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {