From f5fabbc72d4a4bfeb7591d1a7a575651b793c6d4 Mon Sep 17 00:00:00 2001 From: Michel Laterman <82832767+michel-laterman@users.noreply.github.com> Date: Fri, 12 Jan 2024 11:27:48 -0600 Subject: [PATCH] add ServerConfig tags, export TLSClientAuth (#176) Export TLSClient auth, change ServerConfig.Client auth to a pointer, and add YAML tags. This will allow the elastic-agent to pass all mTLS attributes to fleet-server without injecting "required" when "none" is passed. --- transport/tlscommon/server_config.go | 23 +-- transport/tlscommon/server_config_test.go | 94 ++++++++++++ transport/tlscommon/types.go | 43 ++++-- transport/tlscommon/types_test.go | 165 ++++++++++++++++++++++ 4 files changed, 306 insertions(+), 19 deletions(-) create mode 100644 transport/tlscommon/server_config_test.go diff --git a/transport/tlscommon/server_config.go b/transport/tlscommon/server_config.go index b02d03be..24ab12a7 100644 --- a/transport/tlscommon/server_config.go +++ b/transport/tlscommon/server_config.go @@ -27,14 +27,14 @@ import ( // ServerConfig defines the user configurable tls options for any TCP based service. type ServerConfig struct { - Enabled *bool `config:"enabled"` - VerificationMode TLSVerificationMode `config:"verification_mode"` // one of 'none', 'full', 'strict', 'certificate' - Versions []TLSVersion `config:"supported_protocols"` - CipherSuites []CipherSuite `config:"cipher_suites"` - CAs []string `config:"certificate_authorities"` - Certificate CertificateConfig `config:",inline"` - CurveTypes []tlsCurveType `config:"curve_types"` - ClientAuth tlsClientAuth `config:"client_authentication"` //`none`, `optional` or `required` + Enabled *bool `config:"enabled" yaml:"enabled,omitempty"` + VerificationMode TLSVerificationMode `config:"verification_mode" yaml:"verification_mode,omitempty"` // one of 'none', 'full', 'strict', 'certificate' + Versions []TLSVersion `config:"supported_protocols" yaml:"supported_protocols,omitempty"` + CipherSuites []CipherSuite `config:"cipher_suites" yaml:"cipher_suites,omitempty"` + CAs []string `config:"certificate_authorities" yaml:"certificate_authorities,omitempty"` + Certificate CertificateConfig `config:",inline" yaml:",inline"` + CurveTypes []tlsCurveType `config:"curve_types" yaml:"curve_types,omitempty"` + ClientAuth *TLSClientAuth `config:"client_authentication" yaml:"client_authentication,omitempty"` //`none`, `optional` or `required` CASha256 []string `config:"ca_sha256" yaml:"ca_sha256,omitempty"` } @@ -80,6 +80,11 @@ func LoadTLSServerConfig(config *ServerConfig) (*TLSConfig, error) { certs = []tls.Certificate{*cert} } + clientAuth := TLSClientAuthNone + if config.ClientAuth != nil { + clientAuth = *config.ClientAuth + } + // return config if no error occurred return &TLSConfig{ Versions: config.Versions, @@ -88,7 +93,7 @@ func LoadTLSServerConfig(config *ServerConfig) (*TLSConfig, error) { ClientCAs: cas, CipherSuites: config.CipherSuites, CurvePreferences: curves, - ClientAuth: tls.ClientAuthType(config.ClientAuth), + ClientAuth: tls.ClientAuthType(clientAuth), CASha256: config.CASha256, }, nil } diff --git a/transport/tlscommon/server_config_test.go b/transport/tlscommon/server_config_test.go new file mode 100644 index 00000000..b12b98be --- /dev/null +++ b/transport/tlscommon/server_config_test.go @@ -0,0 +1,94 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tlscommon + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +// variables so we can use pointers in tests +var ( + required = TLSClientAuthRequired + optional = TLSClientAuthOptional + none = TLSClientAuthNone +) + +func Test_ServerConfig_Serialization_ClientAuth(t *testing.T) { + tests := []struct { + name string + cfg ServerConfig + clientAuth *TLSClientAuth + }{{ + name: "with ca", + cfg: ServerConfig{ + Certificate: CertificateConfig{ + Certificate: "/path/to/cert.crt", + Key: "/path/to/cert.key", + }, + CAs: []string{"/path/to/ca.crt"}, + }, + clientAuth: &required, + }, { + name: "no ca", + cfg: ServerConfig{ + Certificate: CertificateConfig{ + Certificate: "/path/to/cert.crt", + Key: "/path/to/cert.key", + }, + }, + clientAuth: nil, + }, { + name: "with ca and client auth none", + cfg: ServerConfig{ + Certificate: CertificateConfig{ + Certificate: "/path/to/cert.crt", + Key: "/path/to/cert.key", + }, + CAs: []string{"/path/to/ca.crt"}, + ClientAuth: &none, + }, + clientAuth: &none, + }, { + name: "no ca and client auth none", + cfg: ServerConfig{ + Certificate: CertificateConfig{ + Certificate: "/path/to/cert.crt", + Key: "/path/to/cert.key", + }, + ClientAuth: &none, + }, + clientAuth: &none, + }} + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p, err := yaml.Marshal(&tc.cfg) + require.NoError(t, err) + t.Logf("YAML Config:\n%s", string(p)) + scfg := mustLoadServerConfig(t, string(p)) + if tc.clientAuth == nil { + require.Nil(t, scfg.ClientAuth) + } else { + require.Equal(t, *tc.clientAuth, *scfg.ClientAuth) + } + }) + } +} diff --git a/transport/tlscommon/types.go b/transport/tlscommon/types.go index a4e7a731..2dbf5359 100644 --- a/transport/tlscommon/types.go +++ b/transport/tlscommon/types.go @@ -74,6 +74,7 @@ var tlsCipherSuites = map[string]CipherSuite{ var tlsCipherSuitesInverse = make(map[CipherSuite]string, len(tlsCipherSuites)) var tlsRenegotiationSupportTypesInverse = make(map[TLSRenegotiationSupport]string, len(tlsRenegotiationSupportTypes)) var tlsVerificationModesInverse = make(map[TLSVerificationMode]string, len(tlsVerificationModes)) +var tlsClientAuthTypesInverse = make(map[TLSClientAuth]string, len(tlsClientAuthTypes)) // Init creates a inverse representation of the values mapping. func init() { @@ -88,6 +89,10 @@ func init() { for name, t := range tlsVerificationModes { tlsVerificationModesInverse[t] = name } + + for name, t := range tlsClientAuthTypes { + tlsClientAuthTypesInverse[t] = name + } } var tlsCurveTypes = map[string]tlsCurveType{ @@ -103,20 +108,20 @@ var tlsRenegotiationSupportTypes = map[string]TLSRenegotiationSupport{ "freely": TLSRenegotiationSupport(tls.RenegotiateFreelyAsClient), } -type tlsClientAuth int +type TLSClientAuth int const ( - tlsClientAuthNone tlsClientAuth = tlsClientAuth(tls.NoClientCert) - tlsClientAuthOptional = tlsClientAuth(tls.VerifyClientCertIfGiven) - tlsClientAuthRequired = tlsClientAuth(tls.RequireAndVerifyClientCert) + TLSClientAuthNone TLSClientAuth = TLSClientAuth(tls.NoClientCert) + TLSClientAuthOptional = TLSClientAuth(tls.VerifyClientCertIfGiven) + TLSClientAuthRequired = TLSClientAuth(tls.RequireAndVerifyClientCert) unknownType = "unknown" ) -var tlsClientAuthTypes = map[string]tlsClientAuth{ - "none": tlsClientAuthNone, - "optional": tlsClientAuthOptional, - "required": tlsClientAuthRequired, +var tlsClientAuthTypes = map[string]TLSClientAuth{ + "none": TLSClientAuthNone, + "optional": TLSClientAuthOptional, + "required": TLSClientAuthRequired, } // TLSVerificationMode represents the type of verification to do on the remote host: @@ -179,10 +184,28 @@ func (m *TLSVerificationMode) Unpack(in interface{}) error { return nil } -func (m *tlsClientAuth) Unpack(s string) error { +func (m TLSClientAuth) String() string { + if s, ok := tlsClientAuthTypesInverse[m]; ok { + return s + } + return unknownType +} + +func (m TLSClientAuth) MarshalText() ([]byte, error) { + if s, ok := tlsClientAuthTypesInverse[m]; ok { + return []byte(s), nil + } + return nil, fmt.Errorf("could not marshal '%+v' to text", m) +} + +func (m *TLSClientAuth) Unpack(s string) error { + if s == "" { + *m = TLSClientAuthNone + return nil + } mode, found := tlsClientAuthTypes[s] if !found { - return fmt.Errorf("unknown client authentication mode'%v'", s) + return fmt.Errorf("unknown client authentication mode '%v'", s) } *m = mode diff --git a/transport/tlscommon/types_test.go b/transport/tlscommon/types_test.go index 8b031ca3..7a58c5e9 100644 --- a/transport/tlscommon/types_test.go +++ b/transport/tlscommon/types_test.go @@ -18,8 +18,10 @@ package tlscommon import ( + "fmt" "testing" + "github.com/elastic/elastic-agent-libs/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -66,3 +68,166 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) { assert.NoError(t, err) assert.Equal(t, cfg.VerificationMode, VerifyFull) } + +func TestTLSClientAuthUnpack(t *testing.T) { + tests := []struct { + val string + expect TLSClientAuth + err error + }{{ + val: "", + expect: TLSClientAuthNone, + err: nil, + }, { + val: "none", + expect: TLSClientAuthNone, + err: nil, + }, { + val: "optional", + expect: TLSClientAuthOptional, + err: nil, + }, { + val: "required", + expect: TLSClientAuthRequired, + err: nil, + }, { + val: "invalid", + err: fmt.Errorf("unknown client authentication mode 'invalid'"), + }} + for _, tc := range tests { + t.Run(tc.val, func(t *testing.T) { + var auth TLSClientAuth + err := auth.Unpack(tc.val) + assert.Equal(t, tc.expect, auth) + if tc.err != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTLSClientAuthMarshalText(t *testing.T) { + tests := []struct { + name string + val TLSClientAuth + expect []byte + }{{ + name: "no value", + expect: []byte("none"), + }, { + name: "none", + val: TLSClientAuthNone, + expect: []byte("none"), + }, { + name: "optional", + val: TLSClientAuthOptional, + expect: []byte("optional"), + }, { + name: "required", + val: TLSClientAuthRequired, + expect: []byte("required"), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p, err := tc.val.MarshalText() + assert.Equal(t, tc.expect, p) + assert.NoError(t, err) + }) + } +} + +func TestLoadTLSClientAuth(t *testing.T) { + tests := []struct { + name string + yaml string + expect *TLSClientAuth + }{{ + name: "no client auth value", + yaml: ` + certificate: mycert.pem + key: mycert.key`, + expect: nil, + }, { + name: "client auth empty", + yaml: ` + certificate: mycert.pem + key: mycert.key + client_authentication: `, + expect: nil, + }, { + name: "client auth none", + yaml: ` + certificate: mycert.pem + key: mycert.key + client_authentication: none`, + expect: &none, + }, { + name: "client auth optional", + yaml: ` + certificate: mycert.pem + key: mycert.key + client_authentication: optional`, + expect: &optional, + }, { + name: "client auth required", + yaml: ` + certificate: mycert.pem + key: mycert.key + client_authentication: required`, + expect: &required, + }, { + name: "certificate_authorities is not null, no client_authentication", + yaml: ` + certificate: mycert.pem + key: mycert.key + certificate_authorities: [ca.crt]`, + expect: &required, // NOTE Unpack will insert required if cas are present and no client_authentication is passed + }, { + name: "certificate_authorities is not null, client_authentication is none", + yaml: ` + certificate: mycert.pem + key: mycert.key + client_authentication: none + certificate_authorities: [ca.crt]`, + expect: &none, + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := mustLoadServerConfig(t, tc.yaml) + if tc.expect == nil { + assert.Nil(t, cfg.ClientAuth) + } else { + assert.Equal(t, *tc.expect, *cfg.ClientAuth) + } + }) + } + + t.Run("invalid", func(t *testing.T) { + _, err := loadServerConfig(`client_authentication: invalid`) + assert.Error(t, err) + }) +} + +func loadServerConfig(yamlStr string) (*ServerConfig, error) { + var cfg ServerConfig + config, err := config.NewConfigWithYAML([]byte(yamlStr), "") + if err != nil { + return nil, err + } + + if err := config.Unpack(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func mustLoadServerConfig(t *testing.T, yamlStr string) *ServerConfig { + t.Helper() + cfg, err := loadServerConfig(yamlStr) + if err != nil { + t.Fatal(err) + } + return cfg +}