From f2e48b1a2c937903d8c3762fceb7b65f3615f162 Mon Sep 17 00:00:00 2001 From: jpsrn Date: Thu, 21 Mar 2019 19:08:37 +0200 Subject: [PATCH] Fix KMS encryption context handling (#435) * Fix KMS encryption context handling The code copying encryption context value strings to a map containing string pointers was incorrectly getting a pointer to a string variable which is being re-used by the for loop, causing all keys to point to the same value string. * Extract helper method for KmsKey to KMS MasterKey conversion * Add test for kmsKeyToMasterKey helper function --- keyservice/server.go | 36 ++++++++--------- keyservice/server_test.go | 81 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 20 deletions(-) create mode 100644 keyservice/server_test.go diff --git a/keyservice/server.go b/keyservice/server.go index 35849c6e9..b8603fe41 100644 --- a/keyservice/server.go +++ b/keyservice/server.go @@ -29,16 +29,7 @@ func (ks *Server) encryptWithPgp(key *PgpKey, plaintext []byte) ([]byte, error) } func (ks *Server) encryptWithKms(key *KmsKey, plaintext []byte) ([]byte, error) { - ctx := make(map[string]*string) - for k, v := range key.Context { - ctx[k] = &v - } - kmsKey := kms.MasterKey{ - Arn: key.Arn, - Role: key.Role, - EncryptionContext: ctx, - AwsProfile: key.AwsProfile, - } + kmsKey := kmsKeyToMasterKey(key) err := kmsKey.Encrypt(plaintext) if err != nil { return nil, err @@ -78,16 +69,7 @@ func (ks *Server) decryptWithPgp(key *PgpKey, ciphertext []byte) ([]byte, error) } func (ks *Server) decryptWithKms(key *KmsKey, ciphertext []byte) ([]byte, error) { - ctx := make(map[string]*string) - for k, v := range key.Context { - ctx[k] = &v - } - kmsKey := kms.MasterKey{ - Arn: key.Arn, - Role: key.Role, - EncryptionContext: ctx, - AwsProfile: key.AwsProfile, - } + kmsKey := kmsKeyToMasterKey(key) kmsKey.EncryptedKey = string(ciphertext) plaintext, err := kmsKey.Decrypt() return []byte(plaintext), err @@ -249,3 +231,17 @@ func (ks Server) Decrypt(ctx context.Context, } return response, nil } + +func kmsKeyToMasterKey(key *KmsKey) kms.MasterKey { + ctx := make(map[string]*string) + for k, v := range key.Context { + value := v // Allocate a new string to prevent the pointer below from referring to only the last iteration value + ctx[k] = &value + } + return kms.MasterKey{ + Arn: key.Arn, + Role: key.Role, + EncryptionContext: ctx, + AwsProfile: key.AwsProfile, + } +} diff --git a/keyservice/server_test.go b/keyservice/server_test.go new file mode 100644 index 000000000..147a69a27 --- /dev/null +++ b/keyservice/server_test.go @@ -0,0 +1,81 @@ +package keyservice + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestKmsKeyToMasterKey(t *testing.T) { + + cases := []struct { + description string + expectedArn string + expectedRole string + expectedCtx map[string]string + expectedAwsProfile string + }{ + { + description: "empty context", + expectedArn: "arn:aws:kms:eu-west-1:123456789012:key/d5c90a06-f824-4628-922b-12424571ed4d", + expectedRole: "ExampleRole", + expectedCtx: map[string]string{}, + expectedAwsProfile: "", + }, + { + description: "context with one key-value pair", + expectedArn: "arn:aws:kms:eu-west-1:123456789012:key/d5c90a06-f824-4628-922b-12424571ed4d", + expectedRole: "", + expectedCtx: map[string]string{ + "firstKey": "first value", + }, + expectedAwsProfile: "ExampleProfile", + }, + { + description: "context with three key-value pairs", + expectedArn: "arn:aws:kms:eu-west-1:123456789012:key/d5c90a06-f824-4628-922b-12424571ed4d", + expectedRole: "", + expectedCtx: map[string]string{ + "firstKey": "first value", + "secondKey": "second value", + "thirdKey": "third value", + }, + expectedAwsProfile: "", + }, + } + + for _, c := range cases { + + t.Run(c.description, func(t *testing.T) { + + inputCtx := make(map[string]string) + for k, v := range c.expectedCtx { + inputCtx[k] = v + } + + key := &KmsKey{ + Arn: c.expectedArn, + Role: c.expectedRole, + Context: inputCtx, + AwsProfile: c.expectedAwsProfile, + } + + masterKey := kmsKeyToMasterKey(key) + foundCtx := masterKey.EncryptionContext + + for k, _ := range c.expectedCtx { + require.Containsf(t, foundCtx, k, "Context does not contain expected key '%s'", k) + } + for k, _ := range foundCtx { + require.Containsf(t, c.expectedCtx, k, "Context contains an unexpected key '%s' which cannot be found from expected map", k) + } + for k, expected := range c.expectedCtx { + foundVal := *foundCtx[k] + assert.Equalf(t, expected, foundVal, "Context key '%s' value '%s' does not match expected value '%s'", k, foundVal, expected) + } + assert.Equalf(t, c.expectedArn, masterKey.Arn, "Expected ARN to be '%s', but found '%s'", c.expectedArn, masterKey.Arn) + assert.Equalf(t, c.expectedRole, masterKey.Role, "Expected Role to be '%s', but found '%s'", c.expectedRole, masterKey.Role) + assert.Equalf(t, c.expectedAwsProfile, masterKey.AwsProfile, "Expected AWS profile to be '%s', but found '%s'", c.expectedAwsProfile, masterKey.AwsProfile) + }) + } +}