Skip to content

Commit

Permalink
add ServerConfig tags, export TLSClientAuth (#176)
Browse files Browse the repository at this point in the history
Export TLSClient auth, change ServerConfig.Client auth to a pointer, and add YAML tags.
This will allow the elastic-agent to pass all mTLS attributes to fleet-server without injecting
"required" when "none" is passed.
  • Loading branch information
michel-laterman authored Jan 12, 2024
1 parent 1cf0afd commit f5fabbc
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 19 deletions.
23 changes: 14 additions & 9 deletions transport/tlscommon/server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ import (

// ServerConfig defines the user configurable tls options for any TCP based service.
type ServerConfig struct {
Enabled *bool `config:"enabled"`
VerificationMode TLSVerificationMode `config:"verification_mode"` // one of 'none', 'full', 'strict', 'certificate'
Versions []TLSVersion `config:"supported_protocols"`
CipherSuites []CipherSuite `config:"cipher_suites"`
CAs []string `config:"certificate_authorities"`
Certificate CertificateConfig `config:",inline"`
CurveTypes []tlsCurveType `config:"curve_types"`
ClientAuth tlsClientAuth `config:"client_authentication"` //`none`, `optional` or `required`
Enabled *bool `config:"enabled" yaml:"enabled,omitempty"`
VerificationMode TLSVerificationMode `config:"verification_mode" yaml:"verification_mode,omitempty"` // one of 'none', 'full', 'strict', 'certificate'
Versions []TLSVersion `config:"supported_protocols" yaml:"supported_protocols,omitempty"`
CipherSuites []CipherSuite `config:"cipher_suites" yaml:"cipher_suites,omitempty"`
CAs []string `config:"certificate_authorities" yaml:"certificate_authorities,omitempty"`
Certificate CertificateConfig `config:",inline" yaml:",inline"`
CurveTypes []tlsCurveType `config:"curve_types" yaml:"curve_types,omitempty"`
ClientAuth *TLSClientAuth `config:"client_authentication" yaml:"client_authentication,omitempty"` //`none`, `optional` or `required`
CASha256 []string `config:"ca_sha256" yaml:"ca_sha256,omitempty"`
}

Expand Down Expand Up @@ -80,6 +80,11 @@ func LoadTLSServerConfig(config *ServerConfig) (*TLSConfig, error) {
certs = []tls.Certificate{*cert}
}

clientAuth := TLSClientAuthNone
if config.ClientAuth != nil {
clientAuth = *config.ClientAuth
}

// return config if no error occurred
return &TLSConfig{
Versions: config.Versions,
Expand All @@ -88,7 +93,7 @@ func LoadTLSServerConfig(config *ServerConfig) (*TLSConfig, error) {
ClientCAs: cas,
CipherSuites: config.CipherSuites,
CurvePreferences: curves,
ClientAuth: tls.ClientAuthType(config.ClientAuth),
ClientAuth: tls.ClientAuthType(clientAuth),
CASha256: config.CASha256,
}, nil
}
Expand Down
94 changes: 94 additions & 0 deletions transport/tlscommon/server_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package tlscommon

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)

// variables so we can use pointers in tests
var (
required = TLSClientAuthRequired
optional = TLSClientAuthOptional
none = TLSClientAuthNone
)

func Test_ServerConfig_Serialization_ClientAuth(t *testing.T) {
tests := []struct {
name string
cfg ServerConfig
clientAuth *TLSClientAuth
}{{
name: "with ca",
cfg: ServerConfig{
Certificate: CertificateConfig{
Certificate: "/path/to/cert.crt",
Key: "/path/to/cert.key",
},
CAs: []string{"/path/to/ca.crt"},
},
clientAuth: &required,
}, {
name: "no ca",
cfg: ServerConfig{
Certificate: CertificateConfig{
Certificate: "/path/to/cert.crt",
Key: "/path/to/cert.key",
},
},
clientAuth: nil,
}, {
name: "with ca and client auth none",
cfg: ServerConfig{
Certificate: CertificateConfig{
Certificate: "/path/to/cert.crt",
Key: "/path/to/cert.key",
},
CAs: []string{"/path/to/ca.crt"},
ClientAuth: &none,
},
clientAuth: &none,
}, {
name: "no ca and client auth none",
cfg: ServerConfig{
Certificate: CertificateConfig{
Certificate: "/path/to/cert.crt",
Key: "/path/to/cert.key",
},
ClientAuth: &none,
},
clientAuth: &none,
}}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p, err := yaml.Marshal(&tc.cfg)
require.NoError(t, err)
t.Logf("YAML Config:\n%s", string(p))
scfg := mustLoadServerConfig(t, string(p))
if tc.clientAuth == nil {
require.Nil(t, scfg.ClientAuth)
} else {
require.Equal(t, *tc.clientAuth, *scfg.ClientAuth)
}
})
}
}
43 changes: 33 additions & 10 deletions transport/tlscommon/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ var tlsCipherSuites = map[string]CipherSuite{
var tlsCipherSuitesInverse = make(map[CipherSuite]string, len(tlsCipherSuites))
var tlsRenegotiationSupportTypesInverse = make(map[TLSRenegotiationSupport]string, len(tlsRenegotiationSupportTypes))
var tlsVerificationModesInverse = make(map[TLSVerificationMode]string, len(tlsVerificationModes))
var tlsClientAuthTypesInverse = make(map[TLSClientAuth]string, len(tlsClientAuthTypes))

// Init creates a inverse representation of the values mapping.
func init() {
Expand All @@ -88,6 +89,10 @@ func init() {
for name, t := range tlsVerificationModes {
tlsVerificationModesInverse[t] = name
}

for name, t := range tlsClientAuthTypes {
tlsClientAuthTypesInverse[t] = name
}
}

var tlsCurveTypes = map[string]tlsCurveType{
Expand All @@ -103,20 +108,20 @@ var tlsRenegotiationSupportTypes = map[string]TLSRenegotiationSupport{
"freely": TLSRenegotiationSupport(tls.RenegotiateFreelyAsClient),
}

type tlsClientAuth int
type TLSClientAuth int

const (
tlsClientAuthNone tlsClientAuth = tlsClientAuth(tls.NoClientCert)
tlsClientAuthOptional = tlsClientAuth(tls.VerifyClientCertIfGiven)
tlsClientAuthRequired = tlsClientAuth(tls.RequireAndVerifyClientCert)
TLSClientAuthNone TLSClientAuth = TLSClientAuth(tls.NoClientCert)
TLSClientAuthOptional = TLSClientAuth(tls.VerifyClientCertIfGiven)
TLSClientAuthRequired = TLSClientAuth(tls.RequireAndVerifyClientCert)

unknownType = "unknown"
)

var tlsClientAuthTypes = map[string]tlsClientAuth{
"none": tlsClientAuthNone,
"optional": tlsClientAuthOptional,
"required": tlsClientAuthRequired,
var tlsClientAuthTypes = map[string]TLSClientAuth{
"none": TLSClientAuthNone,
"optional": TLSClientAuthOptional,
"required": TLSClientAuthRequired,
}

// TLSVerificationMode represents the type of verification to do on the remote host:
Expand Down Expand Up @@ -179,10 +184,28 @@ func (m *TLSVerificationMode) Unpack(in interface{}) error {
return nil
}

func (m *tlsClientAuth) Unpack(s string) error {
func (m TLSClientAuth) String() string {
if s, ok := tlsClientAuthTypesInverse[m]; ok {
return s
}
return unknownType
}

func (m TLSClientAuth) MarshalText() ([]byte, error) {
if s, ok := tlsClientAuthTypesInverse[m]; ok {
return []byte(s), nil
}
return nil, fmt.Errorf("could not marshal '%+v' to text", m)
}

func (m *TLSClientAuth) Unpack(s string) error {
if s == "" {
*m = TLSClientAuthNone
return nil
}
mode, found := tlsClientAuthTypes[s]
if !found {
return fmt.Errorf("unknown client authentication mode'%v'", s)
return fmt.Errorf("unknown client authentication mode '%v'", s)
}

*m = mode
Expand Down
165 changes: 165 additions & 0 deletions transport/tlscommon/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package tlscommon

import (
"fmt"
"testing"

"github.com/elastic/elastic-agent-libs/config"
"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -66,3 +68,166 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, cfg.VerificationMode, VerifyFull)
}

func TestTLSClientAuthUnpack(t *testing.T) {
tests := []struct {
val string
expect TLSClientAuth
err error
}{{
val: "",
expect: TLSClientAuthNone,
err: nil,
}, {
val: "none",
expect: TLSClientAuthNone,
err: nil,
}, {
val: "optional",
expect: TLSClientAuthOptional,
err: nil,
}, {
val: "required",
expect: TLSClientAuthRequired,
err: nil,
}, {
val: "invalid",
err: fmt.Errorf("unknown client authentication mode 'invalid'"),
}}
for _, tc := range tests {
t.Run(tc.val, func(t *testing.T) {
var auth TLSClientAuth
err := auth.Unpack(tc.val)
assert.Equal(t, tc.expect, auth)
if tc.err != nil {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

func TestTLSClientAuthMarshalText(t *testing.T) {
tests := []struct {
name string
val TLSClientAuth
expect []byte
}{{
name: "no value",
expect: []byte("none"),
}, {
name: "none",
val: TLSClientAuthNone,
expect: []byte("none"),
}, {
name: "optional",
val: TLSClientAuthOptional,
expect: []byte("optional"),
}, {
name: "required",
val: TLSClientAuthRequired,
expect: []byte("required"),
}}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p, err := tc.val.MarshalText()
assert.Equal(t, tc.expect, p)
assert.NoError(t, err)
})
}
}

func TestLoadTLSClientAuth(t *testing.T) {
tests := []struct {
name string
yaml string
expect *TLSClientAuth
}{{
name: "no client auth value",
yaml: `
certificate: mycert.pem
key: mycert.key`,
expect: nil,
}, {
name: "client auth empty",
yaml: `
certificate: mycert.pem
key: mycert.key
client_authentication: `,
expect: nil,
}, {
name: "client auth none",
yaml: `
certificate: mycert.pem
key: mycert.key
client_authentication: none`,
expect: &none,
}, {
name: "client auth optional",
yaml: `
certificate: mycert.pem
key: mycert.key
client_authentication: optional`,
expect: &optional,
}, {
name: "client auth required",
yaml: `
certificate: mycert.pem
key: mycert.key
client_authentication: required`,
expect: &required,
}, {
name: "certificate_authorities is not null, no client_authentication",
yaml: `
certificate: mycert.pem
key: mycert.key
certificate_authorities: [ca.crt]`,
expect: &required, // NOTE Unpack will insert required if cas are present and no client_authentication is passed
}, {
name: "certificate_authorities is not null, client_authentication is none",
yaml: `
certificate: mycert.pem
key: mycert.key
client_authentication: none
certificate_authorities: [ca.crt]`,
expect: &none,
}}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cfg := mustLoadServerConfig(t, tc.yaml)
if tc.expect == nil {
assert.Nil(t, cfg.ClientAuth)
} else {
assert.Equal(t, *tc.expect, *cfg.ClientAuth)
}
})
}

t.Run("invalid", func(t *testing.T) {
_, err := loadServerConfig(`client_authentication: invalid`)
assert.Error(t, err)
})
}

func loadServerConfig(yamlStr string) (*ServerConfig, error) {
var cfg ServerConfig
config, err := config.NewConfigWithYAML([]byte(yamlStr), "")
if err != nil {
return nil, err
}

if err := config.Unpack(&cfg); err != nil {
return nil, err
}
return &cfg, nil
}

func mustLoadServerConfig(t *testing.T, yamlStr string) *ServerConfig {
t.Helper()
cfg, err := loadServerConfig(yamlStr)
if err != nil {
t.Fatal(err)
}
return cfg
}

0 comments on commit f5fabbc

Please sign in to comment.