diff --git a/gateway/cert.go b/gateway/cert.go index b0b1329520d..284d762e22d 100644 --- a/gateway/cert.go +++ b/gateway/cert.go @@ -582,12 +582,13 @@ func (gw *Gateway) certHandler(w http.ResponseWriter, r *http.Request) { } func getCipherAliases(ciphers []string) (cipherCodes []uint16) { - for _, v := range tls.CipherSuites() { - for _, str := range ciphers { - if str == v.Name { - cipherCodes = append(cipherCodes, v.ID) - } + for _, v := range ciphers { + id, err := crypto.ResolveCipher(v) + if err != nil { + log.Debugf("cipher %s not found; skipped", v) + continue } + cipherCodes = append(cipherCodes, id) } return cipherCodes } diff --git a/internal/crypto/ciphers.go b/internal/crypto/ciphers.go new file mode 100644 index 00000000000..9f5d66dc7e3 --- /dev/null +++ b/internal/crypto/ciphers.go @@ -0,0 +1,75 @@ +package crypto + +import ( + "crypto/tls" + "fmt" + "strings" +) + +// CipherSuite stores information about a cipher suite. +// It shadows tls.CipherSuite but translates TLS versions to strings. +type CipherSuite struct { + ID uint16 `json:"id"` + Name string `json:"name"` + Insecure bool `json:"insecure"` + TLS []string `json:"tls"` +} + +// NewCipher translates tls.CipherSuite to our local type. +func NewCipher(in *tls.CipherSuite) *CipherSuite { + return &CipherSuite{ + ID: in.ID, + Name: in.Name, + Insecure: in.Insecure, + TLS: TLSVersions(in.SupportedVersions), + } +} + +// String returns a human-readable string for the cipher. +func (c *CipherSuite) String() string { + return fmt.Sprintf("Cipher ID: %d, Name: %s, Insecure: %t, TLS: %v", c.ID, c.Name, c.Insecure, c.TLS) +} + +// TLSVersions will return a list of TLS versions as a string. +func TLSVersions(in []uint16) []string { + versions := make([]string, len(in)) + for i, v := range in { + switch v { + case tls.VersionTLS10: + versions[i] = "1.0" + case tls.VersionTLS11: + versions[i] = "1.1" + case tls.VersionTLS12: + versions[i] = "1.2" + case tls.VersionTLS13: + versions[i] = "1.3" + default: + versions[i] = "" + } + } + return versions +} + +// GetCiphers generates a list of CipherSuite from the available ciphers. +func GetCiphers() []*CipherSuite { + ciphers := tls.CipherSuites() + result := make([]*CipherSuite, 0, len(ciphers)) + + for _, cipher := range ciphers { + result = append(result, NewCipher(cipher)) + } + + return result +} + +// ResolveCipher translates a string representation of a cipher to its uint16 ID. +// It's case-insensitive when matching the cipher by name. +func ResolveCipher(cipherName string) (uint16, error) { + ciphers := GetCiphers() + for _, cipher := range ciphers { + if strings.EqualFold(cipher.Name, cipherName) { + return cipher.ID, nil + } + } + return 0, fmt.Errorf("cipher %s not found", cipherName) +} diff --git a/internal/crypto/ciphers_test.go b/internal/crypto/ciphers_test.go new file mode 100644 index 00000000000..a116c06a331 --- /dev/null +++ b/internal/crypto/ciphers_test.go @@ -0,0 +1,99 @@ +package crypto + +import ( + "crypto/tls" + "testing" +) + +func TestNewCipher(t *testing.T) { + mockCipher := &tls.CipherSuite{ + ID: uint16(0x0001), + Name: "TLS_MOCK_CIPHER", + Insecure: false, + SupportedVersions: []uint16{tls.VersionTLS12, tls.VersionTLS13}, + } + + cipher := NewCipher(mockCipher) + + if cipher.ID != mockCipher.ID { + t.Errorf("Expected ID %d, got %d", mockCipher.ID, cipher.ID) + } + if cipher.Name != mockCipher.Name { + t.Errorf("Expected Name %s, got %s", mockCipher.Name, cipher.Name) + } + if cipher.Insecure != mockCipher.Insecure { + t.Errorf("Expected Insecure %t, got %t", mockCipher.Insecure, cipher.Insecure) + } + if len(cipher.TLS) != 2 || cipher.TLS[0] != "1.2" || cipher.TLS[1] != "1.3" { + t.Errorf("Expected TLS versions [1.2, 1.3], got %v", cipher.TLS) + } +} + +func TestGetCiphers(t *testing.T) { + ciphers := GetCiphers() + if len(ciphers) == 0 { + t.Error("Expected non-empty cipher list") + } + + for _, cipher := range ciphers { + if cipher.ID == 0 || cipher.Name == "" { + t.Errorf("Invalid cipher: %v", cipher) + } + } +} + +func TestResolveCipher(t *testing.T) { + testCases := []struct { + name string + input string + expected uint16 + hasError bool + }{ + {"Valid cipher", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", 0xc02f, false}, + {"Invalid cipher", "INVALID_CIPHER", 0, true}, + {"Case insensitive", "tls_ecdhe_rsa_with_aes_128_gcm_sha256", 0xc02f, false}, + {"Empty input", "", 0, true}, + {"Partial match", "TLS_ECDHE", 0, true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := ResolveCipher(tc.input) + if tc.hasError && err == nil { + t.Error("Expected error, got nil") + } + if !tc.hasError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + if result != tc.expected { + t.Errorf("Expected %d, got %d", tc.expected, result) + } + }) + } +} + +func TestTLSVersions(t *testing.T) { + testCases := []struct { + name string + input []uint16 + expected []string + }{ + {"All versions", []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13}, []string{"1.0", "1.1", "1.2", "1.3"}}, + {"Unknown version", []uint16{0x0000}, []string{""}}, + {"Mixed versions", []uint16{tls.VersionTLS12, 0x0000, tls.VersionTLS13}, []string{"1.2", "", "1.3"}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := TLSVersions(tc.input) + if len(result) != len(tc.expected) { + t.Errorf("Expected %d versions, got %d", len(tc.expected), len(result)) + } + for i, v := range result { + if v != tc.expected[i] { + t.Errorf("Expected version %s at index %d, got %s", tc.expected[i], i, v) + } + } + }) + } +}