Skip to content

Commit

Permalink
Merge branch 'main' into ui/stretch-text-field-to-fit-content-in-JSON…
Browse files Browse the repository at this point in the history
…-editor
  • Loading branch information
karabanov authored Oct 9, 2024
2 parents e4eae80 + 3b0614a commit fd6a99e
Show file tree
Hide file tree
Showing 33 changed files with 1,310 additions and 148 deletions.
72 changes: 66 additions & 6 deletions CHANGELOG.md

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions builtin/logical/transit/api_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package transit

import (
"fmt"

"github.com/hashicorp/vault/sdk/helper/keysutil"
)

// parsePaddingSchemeArg validate that the provided padding scheme argument received on the api can be used.
func parsePaddingSchemeArg(keyType keysutil.KeyType, rawPs any) (keysutil.PaddingScheme, error) {
ps, ok := rawPs.(string)
if !ok {
return "", fmt.Errorf("argument was not a string: %T", rawPs)
}

paddingScheme, err := keysutil.ParsePaddingScheme(ps)
if err != nil {
return "", err
}

if !keyType.PaddingSchemesSupported() {
return "", fmt.Errorf("unsupported key type %s for padding scheme", keyType.String())
}

return paddingScheme, nil
}
52 changes: 52 additions & 0 deletions builtin/logical/transit/api_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package transit

import (
"testing"

"github.com/hashicorp/vault/sdk/helper/keysutil"
)

// Test_parsePaddingSchemeArg validate the various use cases we have around parsing
// the various padding_scheme arg possible values.
func Test_parsePaddingSchemeArg(t *testing.T) {
type args struct {
keyType keysutil.KeyType
rawPs any
}
tests := []struct {
name string
args args
want keysutil.PaddingScheme
wantErr bool
}{
// Error cases
{name: "nil-ps", args: args{keyType: keysutil.KeyType_RSA2048, rawPs: nil}, wantErr: true},
{name: "nonstring-ps", args: args{keyType: keysutil.KeyType_RSA2048, rawPs: 5}, wantErr: true},
{name: "invalid-ps", args: args{keyType: keysutil.KeyType_RSA2048, rawPs: "unknown"}, wantErr: true},
{name: "bad-keytype-oaep", args: args{keyType: keysutil.KeyType_AES128_CMAC, rawPs: "oaep"}, wantErr: true},
{name: "bad-keytype-pkcs1", args: args{keyType: keysutil.KeyType_ECDSA_P256, rawPs: "pkcs1v15"}, wantErr: true},
{name: "oaep-capped", args: args{keyType: keysutil.KeyType_RSA4096, rawPs: "OAEP"}, wantErr: true},
{name: "pkcs1-whitespace", args: args{keyType: keysutil.KeyType_RSA3072, rawPs: " pkcs1v15 "}, wantErr: true},

// Valid cases
{name: "oaep-2048", args: args{keyType: keysutil.KeyType_RSA2048, rawPs: "oaep"}, want: keysutil.PaddingScheme_OAEP},
{name: "oaep-3072", args: args{keyType: keysutil.KeyType_RSA3072, rawPs: "oaep"}, want: keysutil.PaddingScheme_OAEP},
{name: "oaep-4096", args: args{keyType: keysutil.KeyType_RSA4096, rawPs: "oaep"}, want: keysutil.PaddingScheme_OAEP},
{name: "pkcs1", args: args{keyType: keysutil.KeyType_RSA3072, rawPs: "pkcs1v15"}, want: keysutil.PaddingScheme_PKCS1v15},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parsePaddingSchemeArg(tt.args.keyType, tt.args.rawPs)
if (err != nil) != tt.wantErr {
t.Errorf("parsePaddingSchemeArg() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("parsePaddingSchemeArg() got = %v, want %v", got, tt.want)
}
})
}
}
145 changes: 79 additions & 66 deletions builtin/logical/transit/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,83 +148,96 @@ func testTransit_RSA(t *testing.T, keyType string) {

plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox"

encryptReq := &logical.Request{
Path: "encrypt/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"plaintext": plaintext,
},
}
for _, padding := range []keysutil.PaddingScheme{keysutil.PaddingScheme_OAEP, keysutil.PaddingScheme_PKCS1v15, ""} {
encryptReq := &logical.Request{
Path: "encrypt/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"plaintext": plaintext,
},
}

resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if padding != "" {
encryptReq.Data["padding_scheme"] = padding
}

ciphertext1 := resp.Data["ciphertext"].(string)
resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

decryptReq := &logical.Request{
Path: "decrypt/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"ciphertext": ciphertext1,
},
}
ciphertext1 := resp.Data["ciphertext"].(string)

resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
decryptReq := &logical.Request{
Path: "decrypt/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"ciphertext": ciphertext1,
},
}
if padding != "" {
decryptReq.Data["padding_scheme"] = padding
}

decryptedPlaintext := resp.Data["plaintext"]
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}
decryptedPlaintext := resp.Data["plaintext"]

// Rotate the key
rotateReq := &logical.Request{
Path: "keys/rsa/rotate",
Operation: logical.UpdateOperation,
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), rotateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}

// Encrypt again
resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
ciphertext2 := resp.Data["ciphertext"].(string)
// Rotate the key
rotateReq := &logical.Request{
Path: "keys/rsa/rotate",
Operation: logical.UpdateOperation,
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), rotateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

if ciphertext1 == ciphertext2 {
t.Fatalf("expected different ciphertexts")
}
// Encrypt again
resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
ciphertext2 := resp.Data["ciphertext"].(string)

// See if the older ciphertext can still be decrypted
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["plaintext"].(string) != plaintext {
t.Fatal("failed to decrypt old ciphertext after rotating the key")
}
if ciphertext1 == ciphertext2 {
t.Fatalf("expected different ciphertexts")
}

// Decrypt the new ciphertext
decryptReq.Data = map[string]interface{}{
"ciphertext": ciphertext2,
}
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["plaintext"].(string) != plaintext {
t.Fatal("failed to decrypt ciphertext after rotating the key")
// See if the older ciphertext can still be decrypted
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["plaintext"].(string) != plaintext {
t.Fatal("failed to decrypt old ciphertext after rotating the key")
}

// Decrypt the new ciphertext
decryptReq.Data = map[string]interface{}{
"ciphertext": ciphertext2,
}
if padding != "" {
decryptReq.Data["padding_scheme"] = padding
}

resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["plaintext"].(string) != plaintext {
t.Fatal("failed to decrypt ciphertext after rotating the key")
}
}

signReq := &logical.Request{
Expand Down
22 changes: 18 additions & 4 deletions builtin/logical/transit/path_datakey.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ func (b *backend) pathDatakey() *framework.Path {
ciphertext; "wrapped" will return the ciphertext only.`,
},

"padding_scheme": {
Type: framework.TypeString,
Description: `The padding scheme to use for decrypt. Currently only applies to RSA key types.
Options are 'oaep' or 'pkcs1v15'. Defaults to 'oaep'`,
},

"context": {
Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.",
Expand Down Expand Up @@ -142,23 +148,31 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d
return nil, err
}

var managedKeyFactory ManagedKeyFactory
factories := make([]any, 0)
if ps, ok := d.GetOk("padding_scheme"); ok {
paddingScheme, err := parsePaddingSchemeArg(p.Type, ps)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("padding_scheme argument invalid: %s", err.Error())), logical.ErrInvalidRequest
}
factories = append(factories, paddingScheme)

}
if p.Type == keysutil.KeyType_MANAGED_KEY {
managedKeySystemView, ok := b.System().(logical.ManagedKeySystemView)
if !ok {
return nil, errors.New("unsupported system view")
}

managedKeyFactory = ManagedKeyFactory{
factories = append(factories, ManagedKeyFactory{
managedKeyParams: keysutil.ManagedKeyParameters{
ManagedKeySystemView: managedKeySystemView,
BackendUUID: b.backendUUID,
Context: ctx,
},
}
})
}

ciphertext, err := p.EncryptWithFactory(ver, context, nonce, base64.StdEncoding.EncodeToString(newKey), nil, managedKeyFactory)
ciphertext, err := p.EncryptWithFactory(ver, context, nonce, base64.StdEncoding.EncodeToString(newKey), factories...)
if err != nil {
switch err.(type) {
case errutil.UserError:
Expand Down
Loading

0 comments on commit fd6a99e

Please sign in to comment.