diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index ad9a9188c254..5dc971a65b74 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -2,7 +2,9 @@ package transit import ( "crypto/elliptic" + "crypto/x509" "encoding/base64" + "encoding/pem" "fmt" "strconv" "time" @@ -131,6 +133,8 @@ func (b *backend) pathPolicyWrite( polReq.KeyType = keysutil.KeyType_ECDSA_P256 case "ed25519": polReq.KeyType = keysutil.KeyType_ED25519 + case "rsa": + polReq.KeyType = keysutil.KeyType_RSA default: return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest } @@ -225,7 +229,7 @@ func (b *backend) pathPolicyRead( } resp.Data["keys"] = retKeys - case keysutil.KeyType_ECDSA_P256, keysutil.KeyType_ED25519: + case keysutil.KeyType_ECDSA_P256, keysutil.KeyType_ED25519, keysutil.KeyType_RSA: retKeys := map[string]map[string]interface{}{} for k, v := range p.Keys { key := asymKey{ @@ -253,6 +257,24 @@ func (b *backend) pathPolicyRead( } } key.Name = "ed25519" + case keysutil.KeyType_RSA: + key.Name = "rsa" + + // Encode the RSA public key in PEM format to return over the + // API + derBytes, err := x509.MarshalPKIXPublicKey(v.RSAKey.Public()) + if err != nil { + return nil, fmt.Errorf("error marshaling RSA public key: %v", err) + } + pemBlock := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + if pemBytes == nil || len(pemBytes) == 0 { + return nil, fmt.Errorf("failed to PEM-encode RSA public key") + } + key.PublicKey = string(pemBytes) } retKeys[strconv.Itoa(k)] = structs.New(key).Map() diff --git a/helper/keysutil/lock_manager.go b/helper/keysutil/lock_manager.go index 75881997340e..a96c106f7c69 100644 --- a/helper/keysutil/lock_manager.go +++ b/helper/keysutil/lock_manager.go @@ -256,7 +256,13 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic case KeyType_ED25519: if req.Convergent { lm.UnlockPolicy(lock, lockType) - return nil, nil, false, fmt.Errorf("convergent encryption not not supported for keys of type %v", req.KeyType) + return nil, nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType) + } + + case KeyType_RSA: + if req.Derived || req.Convergent { + lm.UnlockPolicy(lock, lockType) + return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType) } default: diff --git a/helper/keysutil/policy.go b/helper/keysutil/policy.go index 5e14334f6a93..139b0d0be8ae 100644 --- a/helper/keysutil/policy.go +++ b/helper/keysutil/policy.go @@ -9,6 +9,7 @@ import ( "crypto/elliptic" "crypto/hmac" "crypto/rand" + "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/asn1" @@ -44,6 +45,7 @@ const ( KeyType_AES256_GCM96 = iota KeyType_ECDSA_P256 KeyType_ED25519 + KeyType_RSA ) const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)" @@ -61,7 +63,7 @@ type KeyType int func (kt KeyType) EncryptionSupported() bool { switch kt { - case KeyType_AES256_GCM96: + case KeyType_AES256_GCM96, KeyType_RSA: return true } return false @@ -69,7 +71,7 @@ func (kt KeyType) EncryptionSupported() bool { func (kt KeyType) DecryptionSupported() bool { switch kt { - case KeyType_AES256_GCM96: + case KeyType_AES256_GCM96, KeyType_RSA: return true } return false @@ -77,7 +79,7 @@ func (kt KeyType) DecryptionSupported() bool { func (kt KeyType) SigningSupported() bool { switch kt { - case KeyType_ECDSA_P256, KeyType_ED25519: + case KeyType_ECDSA_P256, KeyType_ED25519, KeyType_RSA: return true } return false @@ -107,6 +109,8 @@ func (kt KeyType) String() string { return "ecdsa-p256" case KeyType_ED25519: return "ed25519" + case KeyType_RSA: + return "rsa" } return "[unknown]" @@ -127,6 +131,8 @@ type KeyEntry struct { EC_Y *big.Int `json:"ec_y"` EC_D *big.Int `json:"ec_d"` + RSAKey *rsa.PrivateKey `json:"rsa_key"` + // The public key in an appropriate format for the type of key FormattedPublicKey string `json:"public_key"` @@ -519,13 +525,6 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)} } - // Guard against a potentially invalid key type - switch p.Type { - case KeyType_AES256_GCM96: - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} - } - // Decode the plaintext value plaintext, err := base64.StdEncoding.DecodeString(value) if err != nil { @@ -543,62 +542,69 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, return "", errutil.UserError{Err: "requested version for encryption is less than the minimum encryption key version"} } - // Derive the key that should be used - key, err := p.DeriveKey(context, ver) - if err != nil { - return "", err - } + var ciphertext []byte - // Guard against a potentially invalid key type switch p.Type { case KeyType_AES256_GCM96: - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} - } - - // Setup the cipher - aesCipher, err := aes.NewCipher(key) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Derive the key that should be used + key, err := p.DeriveKey(context, ver) + if err != nil { + return "", err + } - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Setup the cipher + aesCipher, err := aes.NewCipher(key) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } - if p.ConvergentEncryption { - switch p.ConvergentVersion { - case 1: - if len(nonce) != gcm.NonceSize() { - return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())} - } - default: - nonceHmac := hmac.New(sha256.New, context) - nonceHmac.Write(plaintext) - nonceSum := nonceHmac.Sum(nil) - nonce = nonceSum[:gcm.NonceSize()] - } - } else { - // Compute random nonce - nonce, err = uuid.GenerateRandomBytes(gcm.NonceSize()) + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) if err != nil { return "", errutil.InternalError{Err: err.Error()} } - } - // Encrypt and tag with GCM - out := gcm.Seal(nil, nonce, plaintext, nil) + if p.ConvergentEncryption { + switch p.ConvergentVersion { + case 1: + if len(nonce) != gcm.NonceSize() { + return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())} + } + default: + nonceHmac := hmac.New(sha256.New, context) + nonceHmac.Write(plaintext) + nonceSum := nonceHmac.Sum(nil) + nonce = nonceSum[:gcm.NonceSize()] + } + } else { + // Compute random nonce + nonce, err = uuid.GenerateRandomBytes(gcm.NonceSize()) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } + } + + // Encrypt and tag with GCM + out := gcm.Seal(nil, nonce, plaintext, nil) + + // Place the encrypted data after the nonce + if !p.ConvergentEncryption || p.ConvergentVersion > 1 { + ciphertext = append(nonce, out...) + } - // Place the encrypted data after the nonce - full := out - if !p.ConvergentEncryption || p.ConvergentVersion > 1 { - full = append(nonce, out...) + case KeyType_RSA: + key := p.Keys[ver].RSAKey + ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &key.PublicKey, plaintext, nil) + if err != nil { + return "", errutil.InternalError{Err: fmt.Sprintf("RSA: Error from encryption: %s\n", err)} + } + + default: + return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} } // Convert to base64 - encoded := base64.StdEncoding.EncodeToString(full) + encoded := base64.StdEncoding.EncodeToString(ciphertext) // Prepend some information encoded = "vault:v" + strconv.Itoa(ver) + ":" + encoded @@ -644,50 +650,57 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.UserError{Err: ErrTooOld} } - // Derive the key that should be used - key, err := p.DeriveKey(context, ver) + // Decode the base64 + decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) if err != nil { - return "", err + return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"} } - // Guard against a potentially invalid key type + var plain []byte + switch p.Type { case KeyType_AES256_GCM96: - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} - } + key, err := p.DeriveKey(context, ver) + if err != nil { + return "", err + } - // Decode the base64 - decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) - if err != nil { - return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"} - } + // Setup the cipher + aesCipher, err := aes.NewCipher(key) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } - // Setup the cipher - aesCipher, err := aes.NewCipher(key) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return "", errutil.InternalError{Err: err.Error()} + } - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + // Extract the nonce and ciphertext + var ciphertext []byte + if p.ConvergentEncryption && p.ConvergentVersion < 2 { + ciphertext = decoded + } else { + nonce = decoded[:gcm.NonceSize()] + ciphertext = decoded[gcm.NonceSize():] + } - // Extract the nonce and ciphertext - var ciphertext []byte - if p.ConvergentEncryption && p.ConvergentVersion < 2 { - ciphertext = decoded - } else { - nonce = decoded[:gcm.NonceSize()] - ciphertext = decoded[gcm.NonceSize():] - } + // Verify and Decrypt + plain, err = gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + } - // Verify and Decrypt - plain, err := gcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + case KeyType_RSA: + key := p.Keys[ver].RSAKey + plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil) + if err != nil { + return "", errutil.InternalError{Err: fmt.Sprintf("error while decrypting RSA ciphertext: %v", err)} + } + + default: + return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} } return base64.StdEncoding.EncodeToString(plain), nil @@ -773,6 +786,20 @@ func (p *Policy) Sign(ver int, context, input []byte) (*SigningResult, error) { return nil, err } + case KeyType_RSA: + key := p.Keys[ver].RSAKey + + hash := crypto.SHA256.New() + hash.Write(input) + inputHash := hash.Sum(nil) + + sig, err = rsa.SignPSS(rand.Reader, key, crypto.SHA256, inputHash, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }) + if err != nil { + return nil, err + } + default: return nil, fmt.Errorf("unsupported key type %v", p.Type) } @@ -857,6 +884,19 @@ func (p *Policy) VerifySignature(context, input []byte, sig string) (bool, error return ed25519.Verify(key.Public().(ed25519.PublicKey), input, sigBytes), nil + case KeyType_RSA: + key := p.Keys[ver].RSAKey + + hash := crypto.SHA256.New() + hash.Write(input) + inputHash := hash.Sum(nil) + + err = rsa.VerifyPSS(&key.PublicKey, crypto.SHA256, inputHash, sigBytes, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }) + + return err == nil, nil + default: return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} } @@ -923,6 +963,12 @@ func (p *Policy) Rotate(storage logical.Storage) error { } entry.Key = pri entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub) + + case KeyType_RSA: + entry.RSAKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } } p.Keys[p.LatestVersion] = entry