Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support skipping certificate private key check on request #475

Merged
merged 16 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 172 additions & 9 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.Get(HashArg)
sha1Hash := u.GetHexEncoded(HashArg)
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)
Expand Down Expand Up @@ -544,18 +544,15 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
var certHandle *windows.CertContext

switch {
case sha1Hash != "":
sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial

sha1Bytes, err := hex.DecodeString(sha1Hash)
if err != nil {
return nil, fmt.Errorf("%s must be in hex format: %w", HashArg, err)
case len(sha1Hash) > 0:
if len(sha1Hash) != 20 {
return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
maraino marked this conversation as resolved.
Show resolved Hide resolved
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_SHA1_HASH,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(sha1Bytes)),
data: uintptr(unsafe.Pointer(&sha1Bytes[0])),
len: uint32(len(sha1Hash)),
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
},
}
certHandle, err = findCertificateInStore(st,
Expand Down Expand Up @@ -714,6 +711,172 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
return nil
}

// DeleteCertificate deletes a certificate from the Windows certificate store. It uses
// largely the same logic for searching for the certificate as [LoadCertificate], but
// deletes it as soon as it's found.
//
// # Experimental
//
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
// release.
func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
u, err := uri.ParseWithScheme(Scheme, req.Name)
if err != nil {
return fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.GetHexEncoded(HashArg)
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)

var storeLocation string
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
storeLocation = "user"
}

var certStoreLocation uint32
switch storeLocation {
case "user":
certStoreLocation = certStoreCurrentUser
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return fmt.Errorf("invalid cert store location %v", storeLocation)
hslatman marked this conversation as resolved.
Show resolved Hide resolved
}

var storeName string
if storeName = u.Get(StoreNameArg); storeName == "" {
storeName = "My"
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
0,
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
hslatman marked this conversation as resolved.
Show resolved Hide resolved
}

var certHandle *windows.CertContext

switch {
case len(sha1Hash) > 0:
if len(sha1Hash) != 20 {
return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_SHA1_HASH,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(sha1Hash)),
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
},
}
certHandle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findCertID,
uintptr(unsafe.Pointer(&searchData)), nil)
if err != nil {
return fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil
}
defer windows.CertFreeCertificateContext(certHandle)

if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
return fmt.Errorf("failed removing certificate: %w", err)
}
return nil
case keyID != "":
keyID = strings.TrimPrefix(keyID, "0x") // Support specifying the hash as 0x like with serial

keyIDBytes, err := hex.DecodeString(keyID)
if err != nil {
return fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err)
hslatman marked this conversation as resolved.
Show resolved Hide resolved
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_KEY_IDENTIFIER,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(keyIDBytes)),
data: uintptr(unsafe.Pointer(&keyIDBytes[0])),
},
}
certHandle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findCertID,
uintptr(unsafe.Pointer(&searchData)), nil)
if err != nil {
return fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil
}
defer windows.CertFreeCertificateContext(certHandle)

if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
return fmt.Errorf("failed removing certificate: %w", err)
}
return nil
case issuerName != "" && serialNumber != "":
//TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
var serialBytes []byte
if strings.HasPrefix(serialNumber, "0x") {
serialNumber = strings.TrimPrefix(serialNumber, "0x")
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
serialBytes, err = hex.DecodeString(serialNumber)
if err != nil {
return fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err)
}
} else {
bi := new(big.Int)
bi, ok := bi.SetString(serialNumber, 10)
if !ok {
return fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg)
}
serialBytes = bi.Bytes()
}
var prevCert *windows.CertContext
for {
certHandle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(issuerName))), prevCert)

if err != nil {
return fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil
}
defer windows.CertFreeCertificateContext(certHandle)

x509Cert, err := certContextToX509(certHandle)
if err != nil {
return fmt.Errorf("could not unmarshal certificate to DER: %w", err)
}

if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) {
if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
return fmt.Errorf("failed removing certificate: %w", err)
}

return nil
}
prevCert = certHandle
}
maraino marked this conversation as resolved.
Show resolved Hide resolved
default:
return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
}
}

type CAPISigner struct {
algorithmGroup string
keyHandle uintptr
Expand Down
126 changes: 125 additions & 1 deletion kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,15 +778,25 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store
if o.store != "" {
store = o.store
}
skipFindCertificateKey := "false"
if o.skipFindCertificateKey {
skipFindCertificateKey = "true"
}

leaf := req.CertificateChain[0]
fp, err := fingerprint.New(leaf.Raw, crypto.SHA1, fingerprint.HexFingerprint)
if err != nil {
return fmt.Errorf("failed calculating certificate SHA1 fingerprint: %w", err)
}

uv := url.Values{}
uv.Set("sha1", fp)
uv.Set("store-location", location)
uv.Set("store", store)
uv.Set("skip-find-certificate-key", skipFindCertificateKey)

if err := k.windowsCertificateManager.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: fmt.Sprintf("capi:sha1=%s;store-location=%s;store=%s;", fp, location, store),
Name: uri.New("capi", uv).String(),
Certificate: leaf,
}); err != nil {
return fmt.Errorf("failed storing certificate using Windows platform cryptography provider: %w", err)
Expand Down Expand Up @@ -850,6 +860,109 @@ func (k *TPMKMS) storeIntermediateToWindowsCertificateStore(c *x509.Certificate,
return nil
}

// DeleteCertificate deletes a certificate for the key identified by name from the
// TPMKMS. If the instance is configured to use the Windows certificate store, it'll
// delete the certificate from the certificate store, backed by a CAPIKMS instance.
//
// It's possible to delete a specific certificate for a key by specifying it's SHA1
// or serial. This is only supported if the instance is configured to use the Windows
// certificate store.
//
// # Experimental
//
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
// release.
func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
if req.Name == "" {
return errors.New("deleteCertificateRequest 'name' cannot be empty")
}

if k.usesWindowsCertificateStore() {
if err := k.deleteCertificateFromWindowsCertificateStore(&apiv1.DeleteCertificateRequest{
Name: req.Name,
}); err != nil {
return fmt.Errorf("failed deleting certificate from Windows platform cryptography provider: %w", err)
}

return nil
}

// TODO(hs): support delete by serial? If not, the behavior for TPM storage and Windows
// certificate store storage will be different, and may need different behavior when
// implementing certificate management.

properties, err := parseNameURI(req.Name)
if err != nil {
return fmt.Errorf("failed parsing %q: %w", req.Name, err)
}

ctx := context.Background()
if properties.ak {
ak, err := k.tpm.GetAK(ctx, properties.name)
if err != nil {
return err
}
if err := ak.SetCertificateChain(ctx, nil); err != nil {
return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err)
}
} else {
key, err := k.tpm.GetKey(ctx, properties.name)
if err != nil {
return err
}
if err := key.SetCertificateChain(ctx, nil); err != nil {
return fmt.Errorf("failed storing certificate for key %q: %w", properties.name, err)
}
}

return nil
}

func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteCertificateRequest) error {
o, err := parseNameURI(req.Name)
if err != nil {
return fmt.Errorf("failed parsing %q: %w", req.Name, err)
}

location := k.windowsCertificateStoreLocation
if o.storeLocation != "" {
location = o.storeLocation
}
store := k.windowsCertificateStore
if o.store != "" {
store = o.store
}

uv := url.Values{}
uv.Set("store-location", location)
uv.Set("store", store)

switch {
case o.serial != "":
uv.Set("serial", o.serial)
uv.Set("issuer", o.issuer)
case o.keyID != "":
uv.Set("key-id", o.keyID)
case o.sha1 != "":
uv.Set("sha1", o.sha1)
default:
return errors.New(`at least one of "serial", "key-id" or "sha1" is expected to be set`)
}

dk, ok := k.windowsCertificateManager.(deletingCertificateManager)
if !ok {
return fmt.Errorf("expected Windows certificate manager to implement DeleteCertificate")
}

if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{
Name: uri.New("capi", uv).String(),
}); err != nil {
return fmt.Errorf("failed deleting certificate using Windows platform cryptography provider: %w", err)
}

return nil
}

// attestationClient is a wrapper for [attestation.Client], containing
// all of the required references to perform attestation against the
// Smallstep Attestation CA.
Expand Down Expand Up @@ -1173,8 +1286,19 @@ func generateWindowsSubjectKeyID(pub crypto.PublicKey) (string, error) {
return hex.EncodeToString(hash[:]), nil
}

type deletingCertificateManager interface {
apiv1.CertificateManager
DeleteCertificate(req *apiv1.DeleteCertificateRequest) error
}

type deletingCertificateChainManager interface {
apiv1.CertificateChainManager
DeleteCertificate(req *apiv1.DeleteCertificateRequest) error
}

var _ apiv1.KeyManager = (*TPMKMS)(nil)
var _ apiv1.Attester = (*TPMKMS)(nil)
var _ apiv1.CertificateManager = (*TPMKMS)(nil)
var _ apiv1.CertificateChainManager = (*TPMKMS)(nil)
var _ deletingCertificateChainManager = (*TPMKMS)(nil)
var _ apiv1.AttestationClient = (*attestationClient)(nil)
10 changes: 10 additions & 0 deletions kms/tpmkms/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ type objectProperties struct {
store string
intermediateStoreLocation string
intermediateStore string
skipFindCertificateKey bool
keyID string
sha1 string
serial string
issuer string
}

func parseNameURI(nameURI string) (o objectProperties, err error) {
Expand Down Expand Up @@ -59,6 +64,11 @@ func parseNameURI(nameURI string) (o objectProperties, err error) {
o.store = u.Get("store")
o.intermediateStoreLocation = u.Get("intermediate-store-location")
o.intermediateStore = u.Get("intermediate-store")
o.skipFindCertificateKey = u.GetBool("skip-find-certificate-key")
o.keyID = u.Get("key-id")
o.sha1 = u.Get("sha1")
o.serial = u.Get("serial")
o.issuer = u.Get("issuer")

// validation
if o.ak && o.attestBy != "" {
Expand Down
Loading