Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support key versionless #403

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: Support key versionless
Retrieve latest key version from akv and
put key version into annotation for decryption.

Signed-off-by: Zhecheng Li <zhechengli@microsoft.com>
lzhecheng committed Dec 17, 2024
commit e6d4654ce757609413f8eee8ebfb921d6a0d2017
7 changes: 4 additions & 3 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
@@ -31,9 +31,10 @@ import (
)

var (
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
// If --key-version not set or is empty, the plugin will use the latest version of the key.
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
93 changes: 71 additions & 22 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ const (
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
versionAnnotationKey = "version.azure.akv.io"
algorithmAnnotationKey = "algorithm.azure.akv.io"
keyVersionAnnotationKey = "keyversion.azure.akv.io"
dateAnnotationValue = "Date"
requestIDAnnotationValue = "X-Ms-Request-Id"
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
@@ -70,7 +71,7 @@ type KeyVaultClient struct {
keyName string
keyVersion string
vaultURL string
keyIDHash string
keyIDHash string // keyIDHash is used when key version-less is disabled
azureEnvironment *azure.Environment
}

@@ -90,9 +91,10 @@ func NewKeyVaultClient(

// this should be the case for bring your own key, clusters bootstrapped with
// aks-engine or aks and standalone kms plugin deployments
if len(vaultName) == 0 || len(keyName) == 0 || len(keyVersion) == 0 {
return nil, fmt.Errorf("key vault name, key name and key version are required")
if len(vaultName) == 0 || len(keyName) == 0 {
return nil, fmt.Errorf("key vault name and key name are required")
}

kvClient := kv.New()
err := kvClient.AddToUserAgent(version.GetUserAgent())
if err != nil {
@@ -121,9 +123,12 @@ func NewKeyVaultClient(
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
}

keyIDHash, err := getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
keyIDHash := ""
if len(keyVersion) != 0 {
keyIDHash, err = getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

if proxyMode {
@@ -158,17 +163,39 @@ func (kvc *KeyVaultClient) Encrypt(
Algorithm: encryptionAlgorithm,
Value: &value,
}
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)

keyVersion := kvc.keyVersion
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to encrypt, error: %+v", err)
}

if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
*result.Kid,
)
keyIDHash := ""
if result.Kid == nil {
return nil, fmt.Errorf("key id is nil in encryption result")
}
if len(keyVersion) == 0 {
keyVersion = path.Base(strings.TrimSuffix(*result.Kid, "/"))
keyIDHash, err = getKeyIDHash(kvc.vaultURL, kvc.keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
if keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
keyIDHash,
*result.Kid,
)
}
} else {
if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
*result.Kid,
)
}
keyIDHash = kvc.keyIDHash
}

annotations := map[string][]byte{
@@ -177,11 +204,13 @@ func (kvc *KeyVaultClient) Encrypt(
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
versionAnnotationKey: []byte(encryptionResponseVersion),
algorithmAnnotationKey: []byte(encryptionAlgorithm),
keyVersionAnnotationKey: []byte(keyVersion),
}

mlog.Info("Encryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return &service.EncryptResponse{
Ciphertext: []byte(*result.Result),
KeyID: kvc.keyIDHash,
KeyID: keyIDHash,
Annotations: annotations,
}, nil
}
@@ -208,7 +237,12 @@ func (kvc *KeyVaultClient) Decrypt(
Value: &value,
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
keyVersion := kvc.keyVersion
if len(annotations[keyVersionAnnotationKey]) != 0 {
keyVersion = string(annotations[keyVersionAnnotationKey])
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to decrypt, error: %+v", err)
}
@@ -217,6 +251,7 @@ func (kvc *KeyVaultClient) Decrypt(
return nil, fmt.Errorf("failed to base64 decode result, error: %+v", err)
}

mlog.Info("Decryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return bytes, nil
}

@@ -234,19 +269,33 @@ func (kvc *KeyVaultClient) GetVaultURL() string {
// It also validates keyID that the API server checks.
func (kvc *KeyVaultClient) validateAnnotations(
annotations map[string][]byte,
keyID string,
keyIDHash string,
encryptionAlgorithm kv.JSONWebKeyEncryptionAlgorithm,
) error {
if len(annotations) == 0 {
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

if keyID != kvc.keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyID,
kvc.keyIDHash,
)
if len(annotations[keyVersionAnnotationKey]) == 0 {
if keyIDHash != kvc.keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyIDHash,
kvc.keyIDHash,
)
}
} else {
keyIDHashLocal, err := getKeyIDHash(kvc.vaultURL, kvc.keyName, string(annotations[keyVersionAnnotationKey]))
if err != nil {
return fmt.Errorf("failed to get key id hash, error: %w", err)
}
if keyIDHashLocal != keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyIDHash,
keyIDHashLocal,
)
}
}

algorithm := string(annotations[algorithmAnnotationKey])
30 changes: 20 additions & 10 deletions pkg/plugin/keyvault_test.go
Original file line number Diff line number Diff line change
@@ -21,15 +21,16 @@ var (

func TestNewKeyVaultClientError(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
}{
{
desc: "vault name not provided",
@@ -43,7 +44,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
proxyMode: false,
},
{
desc: "key version not provided",
desc: "key version not provided when not keyVersionlessEnabled",
config: &config.AzureConfig{},
vaultName: "testkv",
keyName: "k8s",
@@ -127,6 +128,15 @@ func TestNewKeyVaultClient(t *testing.T) {
proxyMode: false,
expectedVaultURL: "https://testkv.managedhsm.azure.net/",
},
{
desc: "no error when no key version (version-less)",
config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"},
vaultName: "testkv",
keyName: "key1",
keyVersion: "",
proxyMode: false,
expectedVaultURL: "https://testkv.vault.azure.net/",
},
}

for _, test := range tests {