Skip to content

Commit

Permalink
feat: introduce AWS KMS types.MessageTypeRaw for AWS KMS signing oper…
Browse files Browse the repository at this point in the history
…ations
  • Loading branch information
alexatcanva committed Aug 26, 2024
1 parent 7164ba9 commit 708e1bb
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 37 deletions.
80 changes: 65 additions & 15 deletions kms/awskms/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,48 @@ import (
"go.step.sm/crypto/pemutil"
)

// AWSOptions implements the crypto.SignerOpts interface, it provides a Raw
// boolean field to indicate to the AWS KMS operation that the MessageType is
// RAW.
//
// Example:
//
// // Sign a raw message with KMS
// client := kms.NewFromConfig(cfg)
// kmsSigner, err := awskms.NewSigner(client, "my-key-id")
// if err != nil {
// // handle error ...
// }
// raw := []byte("my raw message")
// sig, err := kmsSigner.Sign(rand.Reader, raw, &awskms.AWSOptions{
// Raw: true,
// Options: crypto.SHA256,
// })
// if err != nil {
// // handle error ...
// }
type AWSOptions struct {
// Raw specifies to the AWS KMS operation that MessageType is RAW.
Raw bool
Options crypto.SignerOpts
}

// HashFunc implements crypto.SignerOpts.
func (a *AWSOptions) HashFunc() crypto.Hash {

Check warning on line 45 in kms/awskms/signer.go

View check run for this annotation

Codecov / codecov/patch

kms/awskms/signer.go#L45

Added line #L45 was not covered by tests
// The GoLang [crypto.SignerOpt] interfaces states that if the [HashFunc]
// returns 0, then it indicates to the [Sign] function that no hashing
// has occured over the message.

Check failure on line 48 in kms/awskms/signer.go

View workflow job for this annotation

GitHub Actions / ci / lint / lint

`occured` is a misspelling of `occurred` (misspell)
// However, the AWS KMS Sign operation always requires that a
// SigningAlgorithm is specified.
// As such, the AWSOptions HashFunc() must return a valid (non-zero) Hash,
// such that the [getMessageTypeAndSigningAlgorithm] function can return a valid AWS KMS
// [types.SigningAlgorithmSpec]
return a.Options.HashFunc()

Check warning on line 54 in kms/awskms/signer.go

View check run for this annotation

Codecov / codecov/patch

kms/awskms/signer.go#L54

Added line #L54 was not covered by tests
}

// compile time check that AWSOptions implements crypto.SignerOpts
var _ crypto.SignerOpts = (*AWSOptions)(nil)

// Signer implements a crypto.Signer using the AWS KMS.
type Signer struct {
client KeyManagementClient
Expand Down Expand Up @@ -63,7 +105,7 @@ func (s *Signer) Public() crypto.PublicKey {

// Sign signs digest with the private key stored in the AWS KMS.
func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
alg, err := getSigningAlgorithm(s.Public(), opts)
messageType, alg, err := getMessageTypeAndSigningAlgorithm(s.Public(), opts)
if err != nil {
return nil, err
}
Expand All @@ -72,7 +114,7 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt
KeyId: pointer(s.keyID),
SigningAlgorithm: alg,
Message: digest,
MessageType: types.MessageTypeDigest,
MessageType: messageType,
}

ctx, cancel := defaultContext()
Expand All @@ -86,41 +128,49 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt
return resp.Signature, nil
}

func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.SigningAlgorithmSpec, error) {
func getMessageTypeAndSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.MessageType, types.SigningAlgorithmSpec, error) {
messageType := types.MessageTypeDigest
if awsOpts, ok := opts.(*AWSOptions); ok {
if awsOpts.Raw {
messageType = types.MessageTypeRaw
}
opts = awsOpts.Options
}

switch key.(type) {
case *rsa.PublicKey:
_, isPSS := opts.(*rsa.PSSOptions)
switch h := opts.HashFunc(); h {
case crypto.SHA256:
if isPSS {
return types.SigningAlgorithmSpecRsassaPssSha256, nil
return messageType, types.SigningAlgorithmSpecRsassaPssSha256, nil
}
return types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
case crypto.SHA384:
if isPSS {
return types.SigningAlgorithmSpecRsassaPssSha384, nil
return messageType, types.SigningAlgorithmSpecRsassaPssSha384, nil
}
return types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
case crypto.SHA512:
if isPSS {
return types.SigningAlgorithmSpecRsassaPssSha512, nil
return messageType, types.SigningAlgorithmSpecRsassaPssSha512, nil
}
return types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
return messageType, "", errors.Errorf("unsupported hash function %v", h)
}
case *ecdsa.PublicKey:
switch h := opts.HashFunc(); h {
case crypto.SHA256:
return types.SigningAlgorithmSpecEcdsaSha256, nil
return messageType, types.SigningAlgorithmSpecEcdsaSha256, nil
case crypto.SHA384:
return types.SigningAlgorithmSpecEcdsaSha384, nil
return messageType, types.SigningAlgorithmSpecEcdsaSha384, nil
case crypto.SHA512:
return types.SigningAlgorithmSpecEcdsaSha512, nil
return messageType, types.SigningAlgorithmSpecEcdsaSha512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
return messageType, "", errors.Errorf("unsupported hash function %v", h)
}
default:
return "", errors.Errorf("unsupported key type %T", key)
return messageType, "", errors.Errorf("unsupported key type %T", key)
}
}
59 changes: 37 additions & 22 deletions kms/awskms/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func TestSigner_Sign(t *testing.T) {
wantErr bool
}{
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
{"(raw) ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), &AWSOptions{Raw: true, Options: crypto.SHA256}}, signature, false},
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), &AWSOptions{Raw: true, Options: crypto.MD5}}, nil, true},
{"(raw) fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
{"fail sign", fields{&MockClient{
sign: func(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) {
Expand All @@ -152,39 +154,52 @@ func TestSigner_Sign(t *testing.T) {
}
}

func Test_getSigningAlgorithm(t *testing.T) {
func Test_getMessageTypeAndSigningAlgorithm(t *testing.T) {
type args struct {
key crypto.PublicKey
opts crypto.SignerOpts
}
tests := []struct {
name string
args args
want types.SigningAlgorithmSpec
wantErr bool
name string
args args
wantMessageType types.MessageType
wantAlgo types.SigningAlgorithmSpec
wantErr bool
}{
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_256", false},
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_384", false},
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_512", false},
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, types.MessageTypeDigest, "ECDSA_SHA_256", false},
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, types.MessageTypeDigest, "ECDSA_SHA_384", false},
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, types.MessageTypeDigest, "ECDSA_SHA_512", false},
{"(raw)rsa+sha256", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA256}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"(raw)rsa+sha384", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA384}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"(raw)rsa+sha512", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA512}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"(raw)pssrsa+sha256", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_256", false},
{"(raw)pssrsa+sha384", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_384", false},
{"(raw)pssrsa+sha512", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_512", false},
{"(raw)P256", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA256}}, types.MessageTypeRaw, "ECDSA_SHA_256", false},
{"(raw)P384", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA384}}, types.MessageTypeRaw, "ECDSA_SHA_384", false},
{"(raw)P521", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA512}}, types.MessageTypeRaw, "ECDSA_SHA_512", false},
{"fail type", args{[]byte("key"), crypto.SHA256}, types.MessageTypeDigest, "", true},
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, types.MessageTypeDigest, "", true},
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, types.MessageTypeDigest, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
gotMessageType, gotAlgo, err := getMessageTypeAndSigningAlgorithm(tt.args.key, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("getMessageTypeAndSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
if gotMessageType != tt.wantMessageType {
t.Errorf("getMessageTypeAndSigningAlgorithm() (message type) = %v, want %v", gotMessageType, tt.wantMessageType)
}
if gotAlgo != tt.wantAlgo {
t.Errorf("getMessageTypeAndSigningAlgorithm() (algorithm) = %v, want %v", gotAlgo, tt.wantAlgo)
}
})
}
Expand Down

0 comments on commit 708e1bb

Please sign in to comment.