diff --git a/transport/tlscommon/server_config_test.go b/transport/tlscommon/server_config_test.go index b12b98be..2b7d9756 100644 --- a/transport/tlscommon/server_config_test.go +++ b/transport/tlscommon/server_config_test.go @@ -20,6 +20,7 @@ package tlscommon import ( "testing" + "github.com/elastic/go-ucfg" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -92,3 +93,83 @@ func Test_ServerConfig_Serialization_ClientAuth(t *testing.T) { }) } } + +func Test_ServerConfig_Repack(t *testing.T) { + tests := []struct { + name string + yaml string + auth *TLSClientAuth + }{{ + name: "with client auth", + yaml: ` + enabled: true + verification_mode: certificate + supported_protocols: [TLSv1.1, TLSv1.2] + cipher_suites: + - RSA-AES-256-CBC-SHA + certificate_authorities: + - /path/to/ca.crt + certificate: /path/to/cert.cry + key: /path/to/key/crt + curve_types: + - P-521 + client_authentication: optional + ca_sha256: + - example`, + auth: &optional, + }, { + name: "nil client auth", + yaml: ` + enabled: true + verification_mode: certificate + supported_protocols: [TLSv1.1, TLSv1.2] + cipher_suites: + - RSA-AES-256-CBC-SHA + certificate_authorities: + - /path/to/ca.crt + certificate: /path/to/cert.cry + key: /path/to/key/crt + curve_types: + - P-521 + ca_sha256: + - example`, + auth: &required, + }, { + name: "nil client auth, no cas", + yaml: ` + enabled: true + verification_mode: certificate + supported_protocols: [TLSv1.1, TLSv1.2] + cipher_suites: + - RSA-AES-256-CBC-SHA + certificate: /path/to/cert.cry + key: /path/to/key/crt + curve_types: + - P-521 + ca_sha256: + - example`, + auth: nil, + }} + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := mustLoadServerConfig(t, tc.yaml) + if tc.auth != nil { + require.Equal(t, *tc.auth, *cfg.ClientAuth) + } else { + require.Nil(t, cfg.ClientAuth) + } + + tmp, err := ucfg.NewFrom(cfg) + require.NoError(t, err) + + err = tmp.Unpack(&cfg) + require.NoError(t, err) + if tc.auth != nil { + require.Equal(t, *tc.auth, *cfg.ClientAuth) + } else { + require.Nil(t, cfg.ClientAuth) + } + }) + } +} diff --git a/transport/tlscommon/types.go b/transport/tlscommon/types.go index 5d178302..513db1ed 100644 --- a/transport/tlscommon/types.go +++ b/transport/tlscommon/types.go @@ -165,7 +165,6 @@ func (m *TLSVerificationMode) Unpack(in interface{}) error { *m = VerifyFull return nil } - switch o := in.(type) { case string: if o == "" { @@ -207,17 +206,30 @@ func (m TLSClientAuth) MarshalText() ([]byte, error) { return nil, fmt.Errorf("could not marshal '%+v' to text", m) } -func (m *TLSClientAuth) Unpack(s string) error { - if s == "" { +func (m *TLSClientAuth) Unpack(in interface{}) error { + if in == nil { *m = TLSClientAuthNone return nil } - mode, found := tlsClientAuthTypes[s] - if !found { - return fmt.Errorf("unknown client authentication mode '%v'", s) - } + switch o := in.(type) { + case string: + if o == "" { + *m = TLSClientAuthNone + return nil + } + mode, found := tlsClientAuthTypes[o] + if !found { + return fmt.Errorf("unknown client authentication mode '%v'", o) + } - *m = mode + *m = mode + case uint64: + *m = TLSClientAuth(o) + case int64: // underlying type is int so we need both uint64 and int64 as options for TLSClientAuth + *m = TLSClientAuth(o) + default: + return fmt.Errorf("client auth mode is an unknown type: %T", o) + } return nil }