Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support key versionless #403

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ import (
)

var (
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
// If --key-version not set or is empty, the plugin will use the latest version of the key.
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
Expand Down
93 changes: 71 additions & 22 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
versionAnnotationKey = "version.azure.akv.io"
algorithmAnnotationKey = "algorithm.azure.akv.io"
keyVersionAnnotationKey = "keyversion.azure.akv.io"
dateAnnotationValue = "Date"
requestIDAnnotationValue = "X-Ms-Request-Id"
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
Expand Down Expand Up @@ -70,7 +71,7 @@ type KeyVaultClient struct {
keyName string
keyVersion string
vaultURL string
keyIDHash string
keyIDHash string // keyIDHash is used when key version-less is disabled
azureEnvironment *azure.Environment
}

Expand All @@ -90,9 +91,10 @@ func NewKeyVaultClient(

// this should be the case for bring your own key, clusters bootstrapped with
// aks-engine or aks and standalone kms plugin deployments
if len(vaultName) == 0 || len(keyName) == 0 || len(keyVersion) == 0 {
return nil, fmt.Errorf("key vault name, key name and key version are required")
if len(vaultName) == 0 || len(keyName) == 0 {
return nil, fmt.Errorf("key vault name and key name are required")
}

kvClient := kv.New()
err := kvClient.AddToUserAgent(version.GetUserAgent())
if err != nil {
Expand Down Expand Up @@ -121,9 +123,12 @@ func NewKeyVaultClient(
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
}

keyIDHash, err := getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
keyIDHash := ""
if len(keyVersion) != 0 {
keyIDHash, err = getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

if proxyMode {
Expand Down Expand Up @@ -158,17 +163,39 @@ func (kvc *KeyVaultClient) Encrypt(
Algorithm: encryptionAlgorithm,
Value: &value,
}
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)

keyVersion := kvc.keyVersion
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to encrypt, error: %+v", err)
}

if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
*result.Kid,
)
keyIDHash := ""
if result.Kid == nil {
return nil, fmt.Errorf("key id is nil in encryption result")
}
if len(keyVersion) == 0 {
keyVersion = path.Base(strings.TrimSuffix(*result.Kid, "/"))
keyIDHash, err = getKeyIDHash(kvc.vaultURL, kvc.keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
if keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
keyIDHash,
*result.Kid,
)
}
} else {
if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
*result.Kid,
)
}
keyIDHash = kvc.keyIDHash
}

annotations := map[string][]byte{
Expand All @@ -177,11 +204,13 @@ func (kvc *KeyVaultClient) Encrypt(
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
versionAnnotationKey: []byte(encryptionResponseVersion),
algorithmAnnotationKey: []byte(encryptionAlgorithm),
keyVersionAnnotationKey: []byte(keyVersion),
}

mlog.Info("Encryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return &service.EncryptResponse{
Ciphertext: []byte(*result.Result),
KeyID: kvc.keyIDHash,
KeyID: keyIDHash,
Annotations: annotations,
}, nil
}
Expand All @@ -208,7 +237,12 @@ func (kvc *KeyVaultClient) Decrypt(
Value: &value,
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
keyVersion := kvc.keyVersion
if len(annotations[keyVersionAnnotationKey]) != 0 {
keyVersion = string(annotations[keyVersionAnnotationKey])
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to decrypt, error: %+v", err)
}
Expand All @@ -217,6 +251,7 @@ func (kvc *KeyVaultClient) Decrypt(
return nil, fmt.Errorf("failed to base64 decode result, error: %+v", err)
}

mlog.Info("Decryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return bytes, nil
}

Expand All @@ -234,19 +269,33 @@ func (kvc *KeyVaultClient) GetVaultURL() string {
// It also validates keyID that the API server checks.
func (kvc *KeyVaultClient) validateAnnotations(
annotations map[string][]byte,
keyID string,
keyIDHash string,
encryptionAlgorithm kv.JSONWebKeyEncryptionAlgorithm,
) error {
if len(annotations) == 0 {
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

if keyID != kvc.keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyID,
kvc.keyIDHash,
)
if len(annotations[keyVersionAnnotationKey]) == 0 {
if keyIDHash != kvc.keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyIDHash,
kvc.keyIDHash,
)
}
} else {
keyIDHashLocal, err := getKeyIDHash(kvc.vaultURL, kvc.keyName, string(annotations[keyVersionAnnotationKey]))
if err != nil {
return fmt.Errorf("failed to get key id hash, error: %w", err)
}
if keyIDHashLocal != keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyIDHash,
keyIDHashLocal,
)
}
}

algorithm := string(annotations[algorithmAnnotationKey])
Expand Down
30 changes: 20 additions & 10 deletions pkg/plugin/keyvault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ var (

func TestNewKeyVaultClientError(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
}{
{
desc: "vault name not provided",
Expand All @@ -43,7 +44,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
proxyMode: false,
},
{
desc: "key version not provided",
desc: "key version not provided when not keyVersionlessEnabled",
config: &config.AzureConfig{},
vaultName: "testkv",
keyName: "k8s",
Expand Down Expand Up @@ -127,6 +128,15 @@ func TestNewKeyVaultClient(t *testing.T) {
proxyMode: false,
expectedVaultURL: "https://testkv.managedhsm.azure.net/",
},
{
desc: "no error when no key version (version-less)",
config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"},
vaultName: "testkv",
keyName: "key1",
keyVersion: "",
proxyMode: false,
expectedVaultURL: "https://testkv.vault.azure.net/",
},
}

for _, test := range tests {
Expand Down
Loading