diff --git a/revocation/internal/crl/crl.go b/revocation/internal/crl/crl.go new file mode 100644 index 00000000..48a5930a --- /dev/null +++ b/revocation/internal/crl/crl.go @@ -0,0 +1,288 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package crl provides methods for checking the revocation status of a +// certificate using CRL +package crl + +import ( + "context" + "crypto/x509" + "encoding/asn1" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/notaryproject/notation-core-go/revocation/result" +) + +var ( + // oidFreshestCRL is the object identifier for the distribution point + // for the delta CRL. (See RFC 5280, Section 5.2.6) + oidFreshestCRL = asn1.ObjectIdentifier{2, 5, 29, 46} + + // oidIssuingDistributionPoint is the object identifier for the issuing + // distribution point CRL extension. (See RFC 5280, Section 5.2.5) + oidIssuingDistributionPoint = asn1.ObjectIdentifier{2, 5, 29, 28} + + // oidInvalidityDate is the object identifier for the invalidity date + // CRL entry extension. (See RFC 5280, Section 5.3.2) + oidInvalidityDate = asn1.ObjectIdentifier{2, 5, 29, 24} +) + +// maxCRLSize is the maximum size of CRL in bytes +// +// CRL examples: https://chasersystems.com/blog/an-analysis-of-certificate-revocation-list-sizes/ +const maxCRLSize = 32 * 1024 * 1024 // 32 MiB + +// CertCheckStatusOptions specifies values that are needed to check CRL +type CertCheckStatusOptions struct { + // HTTPClient is the HTTP client used to download CRL + HTTPClient *http.Client + + // SigningTime is used to compare with the invalidity date during revocation + // check + SigningTime time.Time +} + +// CertCheckStatus checks the revocation status of a certificate using CRL +// +// The function checks the revocation status of the certificate by downloading +// the CRL from the CRL distribution points specified in the certificate. +// +// If the invalidity date extension is present in the CRL entry and SigningTime +// is not zero, the certificate is considered revoked if the SigningTime is +// after the invalidity date. (See RFC 5280, Section 5.3.2) +func CertCheckStatus(ctx context.Context, cert, issuer *x509.Certificate, opts CertCheckStatusOptions) *result.CertRevocationResult { + if !Supported(cert) { + // CRL not enabled for this certificate. + return &result.CertRevocationResult{ + Result: result.ResultNonRevokable, + ServerResults: []*result.ServerResult{{ + RevocationMethod: result.RevocationMethodCRL, + Result: result.ResultNonRevokable, + }}, + RevocationMethod: result.RevocationMethodCRL, + } + } + + // The CRLDistributionPoints contains the URIs of all the CRL distribution + // points. Since it does not distinguish the reason field, it needs to check + // all the URIs to avoid missing any partial CRLs. + // + // For the majority of the certificates, there is only one CRL distribution + // point with one CRL URI, which will be cached, so checking all the URIs is + // not a performance issue. + var ( + serverResults = make([]*result.ServerResult, 0, len(cert.CRLDistributionPoints)) + lastErr error + crlURL string + ) + for _, crlURL = range cert.CRLDistributionPoints { + baseCRL, err := download(ctx, crlURL, opts.HTTPClient) + if err != nil { + lastErr = fmt.Errorf("failed to download CRL from %s: %w", crlURL, err) + break + } + + if err = validate(baseCRL, issuer); err != nil { + lastErr = fmt.Errorf("failed to validate CRL from %s: %w", crlURL, err) + break + } + + crlResult, err := checkRevocation(cert, baseCRL, opts.SigningTime, crlURL) + if err != nil { + lastErr = fmt.Errorf("failed to check revocation status from %s: %w", crlURL, err) + break + } + if crlResult.Result == result.ResultRevoked { + return &result.CertRevocationResult{ + Result: result.ResultRevoked, + ServerResults: []*result.ServerResult{crlResult}, + RevocationMethod: result.RevocationMethodCRL, + } + } + + serverResults = append(serverResults, crlResult) + } + + if lastErr != nil { + return &result.CertRevocationResult{ + Result: result.ResultUnknown, + ServerResults: []*result.ServerResult{ + { + Result: result.ResultUnknown, + Server: crlURL, + Error: lastErr, + RevocationMethod: result.RevocationMethodCRL, + }}, + RevocationMethod: result.RevocationMethodCRL, + } + } + + return &result.CertRevocationResult{ + Result: result.ResultOK, + ServerResults: serverResults, + RevocationMethod: result.RevocationMethodCRL, + } +} + +// Supported checks if the certificate supports CRL. +func Supported(cert *x509.Certificate) bool { + return cert != nil && len(cert.CRLDistributionPoints) > 0 +} + +func validate(crl *x509.RevocationList, issuer *x509.Certificate) error { + // check signature + if err := crl.CheckSignatureFrom(issuer); err != nil { + return fmt.Errorf("CRL is not signed by CA %s: %w,", issuer.Subject, err) + } + + // check validity + now := time.Now() + if !crl.NextUpdate.IsZero() && now.After(crl.NextUpdate) { + return fmt.Errorf("expired CRL. Current time %v is after CRL NextUpdate %v", now, crl.NextUpdate) + } + + for _, ext := range crl.Extensions { + switch { + case ext.Id.Equal(oidFreshestCRL): + return ErrDeltaCRLNotSupported + case ext.Id.Equal(oidIssuingDistributionPoint): + // IssuingDistributionPoint is a critical extension that identifies + // the scope of the CRL. Since we will check all the CRL + // distribution points, it is not necessary to check this extension. + default: + if ext.Critical { + // unsupported critical extensions is not allowed. (See RFC 5280, Section 5.2) + return fmt.Errorf("unsupported critical extension found in CRL: %v", ext.Id) + } + } + } + + return nil +} + +// checkRevocation checks if the certificate is revoked or not +func checkRevocation(cert *x509.Certificate, baseCRL *x509.RevocationList, signingTime time.Time, crlURL string) (*result.ServerResult, error) { + if cert == nil { + return nil, errors.New("certificate cannot be nil") + } + + if baseCRL == nil { + return nil, errors.New("baseCRL cannot be nil") + } + + for _, revocationEntry := range baseCRL.RevokedCertificateEntries { + if revocationEntry.SerialNumber.Cmp(cert.SerialNumber) == 0 { + extensions, err := parseEntryExtensions(revocationEntry) + if err != nil { + return nil, err + } + + // validate signingTime and invalidityDate + if !signingTime.IsZero() && !extensions.invalidityDate.IsZero() && + signingTime.Before(extensions.invalidityDate) { + // signing time is before the invalidity date which means the + // certificate is not revoked at the time of signing. + break + } + + // revoked + return &result.ServerResult{ + Result: result.ResultRevoked, + Server: crlURL, + RevocationMethod: result.RevocationMethodCRL, + }, nil + } + } + + return &result.ServerResult{ + Result: result.ResultOK, + Server: crlURL, + RevocationMethod: result.RevocationMethodCRL, + }, nil +} + +type entryExtensions struct { + // invalidityDate is the date when the key is invalid. + invalidityDate time.Time +} + +func parseEntryExtensions(entry x509.RevocationListEntry) (entryExtensions, error) { + var extensions entryExtensions + for _, ext := range entry.Extensions { + switch { + case ext.Id.Equal(oidInvalidityDate): + var invalidityDate time.Time + rest, err := asn1.UnmarshalWithParams(ext.Value, &invalidityDate, "generalized") + if err != nil { + return entryExtensions{}, fmt.Errorf("failed to parse invalidity date: %w", err) + } + if len(rest) > 0 { + return entryExtensions{}, fmt.Errorf("invalid invalidity date extension: trailing data") + } + + extensions.invalidityDate = invalidityDate + default: + if ext.Critical { + // unsupported critical extensions is not allowed. (See RFC 5280, Section 5.2) + return entryExtensions{}, fmt.Errorf("unsupported critical extension found in CRL: %v", ext.Id) + } + } + } + + return extensions, nil +} + +func download(ctx context.Context, crlURL string, client *http.Client) (*x509.RevocationList, error) { + // validate URL + parsedURL, err := url.Parse(crlURL) + if err != nil { + return nil, fmt.Errorf("invalid CRL URL: %w", err) + } + if parsedURL.Scheme != "http" { + return nil, fmt.Errorf("unsupported CRL endpoint: %s. Only urls with HTTP scheme is supported", crlURL) + } + + // download CRL + req, err := http.NewRequestWithContext(ctx, http.MethodGet, crlURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create CRL request %q: %w", crlURL, err) + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed for %q: %w", crlURL, err) + } + defer resp.Body.Close() + + // check response + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("%s %q: failed to download with status code: %d", resp.Request.Method, resp.Request.URL, resp.StatusCode) + } + + // read with size limit + limitedReader := io.LimitReader(resp.Body, maxCRLSize) + data, err := io.ReadAll(limitedReader) + if err != nil { + return nil, fmt.Errorf("failed to read CRL response from %q: %w", resp.Request.URL, err) + } + if len(data) == maxCRLSize { + return nil, fmt.Errorf("%s %q: CRL size reached the %d MiB size limit", resp.Request.Method, resp.Request.URL, maxCRLSize/1024/1024) + } + + return x509.ParseRevocationList(data) +} diff --git a/revocation/internal/crl/crl_test.go b/revocation/internal/crl/crl_test.go new file mode 100644 index 00000000..2129fb79 --- /dev/null +++ b/revocation/internal/crl/crl_test.go @@ -0,0 +1,685 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crl + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "testing" + "time" + + "github.com/notaryproject/notation-core-go/revocation/result" + "github.com/notaryproject/notation-core-go/testhelper" +) + +func TestCertCheckStatus(t *testing.T) { + t.Run("certtificate does not have CRLDistributionPoints", func(t *testing.T) { + cert := &x509.Certificate{} + r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{}) + if r.Result != result.ResultNonRevokable { + t.Fatalf("expected NonRevokable, got %s", r.Result) + } + }) + + t.Run("download error", func(t *testing.T) { + cert := &x509.Certificate{ + CRLDistributionPoints: []string{"http://example.com"}, + } + r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: errorRoundTripperMock{}, + }, + }) + if r.ServerResults[0].Error == nil { + t.Fatal("expected error") + } + }) + + t.Run("CRL validate failed", func(t *testing.T) { + cert := &x509.Certificate{ + CRLDistributionPoints: []string{"http://example.com"}, + } + r := CertCheckStatus(context.Background(), cert, &x509.Certificate{}, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: expiredCRLRoundTripperMock{}, + }, + }) + if r.ServerResults[0].Error == nil { + t.Fatal("expected error") + } + }) + + // prepare a certificate chain + chain := testhelper.GetRevokableRSAChainWithRevocations(2, false, true) + issuerCert := chain[1].Cert + issuerKey := chain[1].PrivateKey + + t.Run("revoked", func(t *testing.T) { + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: chain[0].Cert.SerialNumber, + RevocationTime: time.Now().Add(-time.Hour), + }, + }, + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: expectedRoundTripperMock{Body: crlBytes}, + }, + }) + if r.Result != result.ResultRevoked { + t.Fatalf("expected revoked, got %s", r.Result) + } + }) + + t.Run("unknown critical extension", func(t *testing.T) { + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: chain[0].Cert.SerialNumber, + RevocationTime: time.Now().Add(-time.Hour), + ExtraExtensions: []pkix.Extension{ + { + Id: []int{1, 2, 3}, + Critical: true, + }, + }, + }, + }, + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: expectedRoundTripperMock{Body: crlBytes}, + }, + }) + if r.ServerResults[0].Error == nil { + t.Fatal("expected error") + } + }) + + t.Run("Not revoked", func(t *testing.T) { + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: expectedRoundTripperMock{Body: crlBytes}, + }, + }) + if r.Result != result.ResultOK { + t.Fatalf("expected OK, got %s", r.Result) + } + }) + + t.Run("CRL with delta CRL is not checked", func(t *testing.T) { + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + ExtraExtensions: []pkix.Extension{ + { + Id: oidFreshestCRL, + Critical: false, + }, + }, + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + r := CertCheckStatus(context.Background(), chain[0].Cert, issuerCert, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: expectedRoundTripperMock{Body: crlBytes}, + }, + }) + if !errors.Is(r.ServerResults[0].Error, ErrDeltaCRLNotSupported) { + t.Fatal("expected ErrDeltaCRLNotChecked") + } + }) +} + +func TestValidate(t *testing.T) { + t.Run("expired CRL", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(1, false, true) + issuerCert := chain[0].Cert + issuerKey := chain[0].PrivateKey + + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(-time.Hour), + Number: big.NewInt(20240720), + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + crl, err := x509.ParseRevocationList(crlBytes) + if err != nil { + t.Fatal(err) + } + + if err := validate(crl, issuerCert); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("check signature failed", func(t *testing.T) { + crl := &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + } + + if err := validate(crl, &x509.Certificate{}); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("unsupported CRL critical extensions", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(1, false, true) + issuerCert := chain[0].Cert + issuerKey := chain[0].PrivateKey + + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + crl, err := x509.ParseRevocationList(crlBytes) + if err != nil { + t.Fatal(err) + } + + // add unsupported critical extension + crl.Extensions = []pkix.Extension{ + { + Id: []int{1, 2, 3}, + Critical: true, + }, + } + + if err := validate(crl, issuerCert); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("issuing distribution point extension exists", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(1, false, true) + issuerCert := chain[0].Cert + issuerKey := chain[0].PrivateKey + + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + ExtraExtensions: []pkix.Extension{ + { + Id: oidIssuingDistributionPoint, + Critical: true, + }, + }, + }, issuerCert, issuerKey) + if err != nil { + t.Fatal(err) + } + + crl, err := x509.ParseRevocationList(crlBytes) + if err != nil { + t.Fatal(err) + } + + if err := validate(crl, issuerCert); err != nil { + t.Fatal(err) + } + }) +} + +func TestCheckRevocation(t *testing.T) { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1), + } + signingTime := time.Now() + + t.Run("certificate is nil", func(t *testing.T) { + _, err := checkRevocation(nil, &x509.RevocationList{}, signingTime, "") + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("CRL is nil", func(t *testing.T) { + _, err := checkRevocation(cert, nil, signingTime, "") + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("not revoked", func(t *testing.T) { + baseCRL := &x509.RevocationList{ + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(2), + }, + }, + } + r, err := checkRevocation(cert, baseCRL, signingTime, "") + if err != nil { + t.Fatal(err) + } + if r.Result != result.ResultOK { + t.Fatalf("unexpected result, got %s", r.Result) + } + }) + + t.Run("revoked", func(t *testing.T) { + baseCRL := &x509.RevocationList{ + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(1), + RevocationTime: time.Now().Add(-time.Hour), + }, + }, + } + r, err := checkRevocation(cert, baseCRL, signingTime, "") + if err != nil { + t.Fatal(err) + } + if r.Result != result.ResultRevoked { + t.Fatalf("expected revoked, got %s", r.Result) + } + }) + + t.Run("revoked but signing time is before invalidityDate", func(t *testing.T) { + invalidityDate := time.Now().Add(time.Hour) + invalidityDateBytes, err := marshalGeneralizedTimeToBytes(invalidityDate) + if err != nil { + t.Fatal(err) + } + + extensions := []pkix.Extension{ + { + Id: oidInvalidityDate, + Critical: false, + Value: invalidityDateBytes, + }, + } + + baseCRL := &x509.RevocationList{ + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(1), + RevocationTime: time.Now().Add(time.Hour), + Extensions: extensions, + }, + }, + } + r, err := checkRevocation(cert, baseCRL, signingTime, "") + if err != nil { + t.Fatal(err) + } + if r.Result != result.ResultOK { + t.Fatalf("unexpected result, got %s", r.Result) + } + }) + + t.Run("revoked; signing time is after invalidityDate", func(t *testing.T) { + invalidityDate := time.Now().Add(-time.Hour) + invalidityDateBytes, err := marshalGeneralizedTimeToBytes(invalidityDate) + if err != nil { + t.Fatal(err) + } + + extensions := []pkix.Extension{ + { + Id: oidInvalidityDate, + Critical: false, + Value: invalidityDateBytes, + }, + } + + baseCRL := &x509.RevocationList{ + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(1), + RevocationTime: time.Now().Add(-time.Hour), + Extensions: extensions, + }, + }, + } + r, err := checkRevocation(cert, baseCRL, signingTime, "") + if err != nil { + t.Fatal(err) + } + if r.Result != result.ResultRevoked { + t.Fatalf("expected revoked, got %s", r.Result) + } + }) + + t.Run("revoked and signing time is zero", func(t *testing.T) { + baseCRL := &x509.RevocationList{ + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(1), + RevocationTime: time.Time{}, + }, + }, + } + r, err := checkRevocation(cert, baseCRL, time.Time{}, "") + if err != nil { + t.Fatal(err) + } + if r.Result != result.ResultRevoked { + t.Fatalf("expected revoked, got %s", r.Result) + } + }) + + t.Run("revocation entry validation error", func(t *testing.T) { + baseCRL := &x509.RevocationList{ + RevokedCertificateEntries: []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(1), + Extensions: []pkix.Extension{ + { + Id: []int{1, 2, 3}, + Critical: true, + }, + }, + }, + }, + } + _, err := checkRevocation(cert, baseCRL, signingTime, "") + if err == nil { + t.Fatal("expected error") + } + }) +} + +func TestParseEntryExtension(t *testing.T) { + t.Run("unsupported critical extension", func(t *testing.T) { + entry := x509.RevocationListEntry{ + Extensions: []pkix.Extension{ + { + Id: []int{1, 2, 3}, + Critical: true, + }, + }, + } + if _, err := parseEntryExtensions(entry); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("valid extension", func(t *testing.T) { + entry := x509.RevocationListEntry{ + Extensions: []pkix.Extension{ + { + Id: []int{1, 2, 3}, + Critical: false, + }, + }, + } + if _, err := parseEntryExtensions(entry); err != nil { + t.Fatal(err) + } + }) + + t.Run("parse invalidityDate", func(t *testing.T) { + + // create a time and marshal it to be generalizedTime + invalidityDate := time.Now() + invalidityDateBytes, err := marshalGeneralizedTimeToBytes(invalidityDate) + if err != nil { + t.Fatal(err) + } + + entry := x509.RevocationListEntry{ + Extensions: []pkix.Extension{ + { + Id: oidInvalidityDate, + Critical: false, + Value: invalidityDateBytes, + }, + }, + } + extensions, err := parseEntryExtensions(entry) + if err != nil { + t.Fatal(err) + } + + if extensions.invalidityDate.IsZero() { + t.Fatal("expected invalidityDate") + } + }) + + t.Run("parse invalidityDate with error", func(t *testing.T) { + // invalid invalidityDate extension + entry := x509.RevocationListEntry{ + Extensions: []pkix.Extension{ + { + Id: oidInvalidityDate, + Critical: false, + Value: []byte{0x00, 0x01, 0x02, 0x03}, + }, + }, + } + _, err := parseEntryExtensions(entry) + if err == nil { + t.Fatal("expected error") + } + + // invalidityDate extension with extra bytes + invalidityDate := time.Now() + invalidityDateBytes, err := marshalGeneralizedTimeToBytes(invalidityDate) + if err != nil { + t.Fatal(err) + } + invalidityDateBytes = append(invalidityDateBytes, 0x00) + + entry = x509.RevocationListEntry{ + Extensions: []pkix.Extension{ + { + Id: oidInvalidityDate, + Critical: false, + Value: invalidityDateBytes, + }, + }, + } + _, err = parseEntryExtensions(entry) + if err == nil { + t.Fatal("expected error") + } + }) +} + +// marshalGeneralizedTimeToBytes converts a time.Time to ASN.1 GeneralizedTime bytes. +func marshalGeneralizedTimeToBytes(t time.Time) ([]byte, error) { + // ASN.1 GeneralizedTime requires the time to be in UTC + t = t.UTC() + // Use asn1.Marshal to directly get the ASN.1 GeneralizedTime bytes + return asn1.Marshal(t) +} + +func TestDownload(t *testing.T) { + t.Run("parse url error", func(t *testing.T) { + _, err := download(context.Background(), ":", http.DefaultClient) + if err == nil { + t.Fatal("expected error") + } + }) + t.Run("https download", func(t *testing.T) { + _, err := download(context.Background(), "https://example.com", http.DefaultClient) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("http.NewRequestWithContext error", func(t *testing.T) { + var ctx context.Context = nil + _, err := download(ctx, "http://example.com", &http.Client{}) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("client.Do error", func(t *testing.T) { + _, err := download(context.Background(), "http://example.com", &http.Client{ + Transport: errorRoundTripperMock{}, + }) + + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("status code is not 2xx", func(t *testing.T) { + _, err := download(context.Background(), "http://example.com", &http.Client{ + Transport: serverErrorRoundTripperMock{}, + }) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("readAll error", func(t *testing.T) { + _, err := download(context.Background(), "http://example.com", &http.Client{ + Transport: readFailedRoundTripperMock{}, + }) + if err == nil { + t.Fatal("expected error") + } + }) + + t.Run("exceed the size limit", func(t *testing.T) { + _, err := download(context.Background(), "http://example.com", &http.Client{ + Transport: expectedRoundTripperMock{Body: make([]byte, maxCRLSize+1)}, + }) + if err == nil { + t.Fatal("expected error") + } + }) +} + +func TestSupported(t *testing.T) { + t.Run("supported", func(t *testing.T) { + cert := &x509.Certificate{ + CRLDistributionPoints: []string{"http://example.com"}, + } + if !Supported(cert) { + t.Fatal("expected supported") + } + }) + + t.Run("unsupported", func(t *testing.T) { + cert := &x509.Certificate{} + if Supported(cert) { + t.Fatal("expected unsupported") + } + }) +} + +type errorRoundTripperMock struct{} + +func (rt errorRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("error") +} + +type serverErrorRoundTripperMock struct{} + +func (rt serverErrorRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Request: req, + StatusCode: http.StatusInternalServerError, + }, nil +} + +type readFailedRoundTripperMock struct{} + +func (rt readFailedRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: errorReaderMock{}, + Request: &http.Request{ + Method: http.MethodGet, + URL: req.URL, + }, + }, nil +} + +type expiredCRLRoundTripperMock struct{} + +func (rt expiredCRLRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + chain := testhelper.GetRevokableRSAChainWithRevocations(1, false, true) + issuerCert := chain[0].Cert + issuerKey := chain[0].PrivateKey + + crlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{ + NextUpdate: time.Now().Add(-time.Hour), + Number: big.NewInt(20240720), + }, issuerCert, issuerKey) + if err != nil { + return nil, err + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBuffer(crlBytes)), + }, nil +} + +type errorReaderMock struct{} + +func (r errorReaderMock) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("error") +} + +func (r errorReaderMock) Close() error { + return nil +} + +type expectedRoundTripperMock struct { + Body []byte +} + +func (rt expectedRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Request: req, + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBuffer(rt.Body)), + }, nil +} diff --git a/revocation/internal/crl/errors.go b/revocation/internal/crl/errors.go new file mode 100644 index 00000000..37866551 --- /dev/null +++ b/revocation/internal/crl/errors.go @@ -0,0 +1,22 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crl + +import "errors" + +var ( + // ErrDeltaCRLNotSupported is returned when the CRL contains a delta CRL but + // the delta CRL is not supported. + ErrDeltaCRLNotSupported = errors.New("delta CRL is not supported") +) diff --git a/revocation/ocsp/errors.go b/revocation/internal/ocsp/errors.go similarity index 100% rename from revocation/ocsp/errors.go rename to revocation/internal/ocsp/errors.go diff --git a/revocation/ocsp/errors_test.go b/revocation/internal/ocsp/errors_test.go similarity index 100% rename from revocation/ocsp/errors_test.go rename to revocation/internal/ocsp/errors_test.go diff --git a/revocation/internal/ocsp/ocsp.go b/revocation/internal/ocsp/ocsp.go new file mode 100644 index 00000000..25410ed3 --- /dev/null +++ b/revocation/internal/ocsp/ocsp.go @@ -0,0 +1,242 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ocsp provides methods for checking the OCSP revocation status of a +// certificate chain, as well as errors related to these checks +package ocsp + +import ( + "bytes" + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/notaryproject/notation-core-go/revocation/result" + "golang.org/x/crypto/ocsp" +) + +// CertCheckStatusOptions specifies values that are needed to check OCSP revocation +type CertCheckStatusOptions struct { + // HTTPClient is the HTTP client used to perform the OCSP request + HTTPClient *http.Client + + // SigningTime is used to compare with the invalidity date during revocation + SigningTime time.Time +} + +const ( + pkixNoCheckOID string = "1.3.6.1.5.5.7.48.1.5" + invalidityDateOID string = "2.5.29.24" + // Max size determined from https://www.ibm.com/docs/en/sva/9.0.6?topic=stanza-ocsp-max-response-size. + // Typical size is ~4 KB + ocspMaxResponseSize int64 = 20480 //bytes +) + +// CertCheckStatus checks the revocation status of a certificate using OCSP +func CertCheckStatus(cert, issuer *x509.Certificate, opts CertCheckStatusOptions) *result.CertRevocationResult { + if !Supported(cert) { + // OCSP not enabled for this certificate. + return &result.CertRevocationResult{ + Result: result.ResultNonRevokable, + ServerResults: []*result.ServerResult{toServerResult("", NoServerError{})}, + RevocationMethod: result.RevocationMethodOCSP, + } + } + ocspURLs := cert.OCSPServer + + serverResults := make([]*result.ServerResult, len(ocspURLs)) + for serverIndex, server := range ocspURLs { + serverResult := checkStatusFromServer(cert, issuer, server, opts) + if serverResult.Result == result.ResultOK || + serverResult.Result == result.ResultRevoked || + (serverResult.Result == result.ResultUnknown && errors.Is(serverResult.Error, UnknownStatusError{})) { + // A valid response has been received from an OCSP server + // Result should be based on only this response, not any errors from + // other servers + return serverResultsToCertRevocationResult([]*result.ServerResult{serverResult}) + } + serverResults[serverIndex] = serverResult + } + return serverResultsToCertRevocationResult(serverResults) +} + +// Supported returns true if the certificate supports OCSP. +func Supported(cert *x509.Certificate) bool { + return cert != nil && len(cert.OCSPServer) > 0 +} + +func checkStatusFromServer(cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) *result.ServerResult { + // Check valid server + if serverURL, err := url.Parse(server); err != nil || !strings.EqualFold(serverURL.Scheme, "http") { + // This function is only able to check servers that are accessible via HTTP + return toServerResult(server, GenericError{Err: fmt.Errorf("OCSPServer protocol %s is not supported", serverURL.Scheme)}) + } + + // Create OCSP Request + resp, err := executeOCSPCheck(cert, issuer, server, opts) + if err != nil { + // If there is a server error, attempt all servers before determining what to return + // to the user + return toServerResult(server, err) + } + + // Validate OCSP response isn't expired + if time.Now().After(resp.NextUpdate) { + return toServerResult(server, GenericError{Err: errors.New("expired OCSP response")}) + } + + // Handle pkix-ocsp-no-check and id-ce-invalidityDate extensions if present + // in response + extensionMap := extensionsToMap(resp.Extensions) + if _, foundNoCheck := extensionMap[pkixNoCheckOID]; !foundNoCheck { + // This will be ignored until CRL is implemented + // If it isn't found, CRL should be used to verify the OCSP response + _ = foundNoCheck // needed to bypass linter warnings (Remove after adding CRL) + // TODO: add CRL support + // https://github.com/notaryproject/notation-core-go/issues/125 + } + if invalidityDateBytes, foundInvalidityDate := extensionMap[invalidityDateOID]; foundInvalidityDate && !opts.SigningTime.IsZero() && resp.Status == ocsp.Revoked { + var invalidityDate time.Time + rest, err := asn1.UnmarshalWithParams(invalidityDateBytes, &invalidityDate, "generalized") + if len(rest) == 0 && err == nil && opts.SigningTime.Before(invalidityDate) { + return toServerResult(server, nil) + } + } + + // No errors, valid server response + switch resp.Status { + case ocsp.Good: + return toServerResult(server, nil) + case ocsp.Revoked: + return toServerResult(server, RevokedError{}) + default: + // ocsp.Unknown + return toServerResult(server, UnknownStatusError{}) + } +} + +func extensionsToMap(extensions []pkix.Extension) map[string][]byte { + extensionMap := make(map[string][]byte) + for _, extension := range extensions { + extensionMap[extension.Id.String()] = extension.Value + } + return extensionMap +} + +func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) (*ocsp.Response, error) { + // TODO: Look into other alternatives for specifying the Hash + // https://github.com/notaryproject/notation-core-go/issues/139 + // The following do not support SHA256 hashes: + // - Microsoft + // - Entrust + // - Let's Encrypt + // - Digicert (sometimes) + // As this represents a large percentage of public CAs, we are using the + // hashing algorithm SHA1, which has been confirmed to be supported by all + // that were tested. + ocspRequest, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA1}) + if err != nil { + return nil, GenericError{Err: err} + } + + var resp *http.Response + postRequired := base64.StdEncoding.EncodedLen(len(ocspRequest)) >= 255 + if !postRequired { + encodedReq := url.QueryEscape(base64.StdEncoding.EncodeToString(ocspRequest)) + if len(encodedReq) < 255 { + var reqURL string + reqURL, err = url.JoinPath(server, encodedReq) + if err != nil { + return nil, GenericError{Err: err} + } + resp, err = opts.HTTPClient.Get(reqURL) + } else { + resp, err = postRequest(ocspRequest, server, opts.HTTPClient) + } + } else { + resp, err = postRequest(ocspRequest, server, opts.HTTPClient) + } + + if err != nil { + var urlErr *url.Error + if errors.As(err, &urlErr) && urlErr.Timeout() { + return nil, TimeoutError{} + } + return nil, GenericError{Err: err} + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("failed to retrieve OCSP: response had status code %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, ocspMaxResponseSize)) + if err != nil { + return nil, GenericError{Err: err} + } + + switch { + case bytes.Equal(body, ocsp.UnauthorizedErrorResponse): + return nil, GenericError{Err: errors.New("OCSP unauthorized")} + case bytes.Equal(body, ocsp.MalformedRequestErrorResponse): + return nil, GenericError{Err: errors.New("OCSP malformed")} + case bytes.Equal(body, ocsp.InternalErrorErrorResponse): + return nil, GenericError{Err: errors.New("OCSP internal error")} + case bytes.Equal(body, ocsp.TryLaterErrorResponse): + return nil, GenericError{Err: errors.New("OCSP try later")} + case bytes.Equal(body, ocsp.SigRequredErrorResponse): + return nil, GenericError{Err: errors.New("OCSP signature required")} + } + + return ocsp.ParseResponseForCert(body, cert, issuer) +} + +func postRequest(req []byte, server string, httpClient *http.Client) (*http.Response, error) { + reader := bytes.NewReader(req) + return httpClient.Post(server, "application/ocsp-request", reader) +} + +func toServerResult(server string, err error) *result.ServerResult { + var serverResult *result.ServerResult + switch t := err.(type) { + case nil: + serverResult = result.NewServerResult(result.ResultOK, server, nil) + case NoServerError: + serverResult = result.NewServerResult(result.ResultNonRevokable, server, nil) + case RevokedError: + serverResult = result.NewServerResult(result.ResultRevoked, server, t) + default: + // Includes GenericError, UnknownStatusError, result.InvalidChainError, + // and TimeoutError + serverResult = result.NewServerResult(result.ResultUnknown, server, t) + } + serverResult.RevocationMethod = result.RevocationMethodOCSP + return serverResult +} + +func serverResultsToCertRevocationResult(serverResults []*result.ServerResult) *result.CertRevocationResult { + return &result.CertRevocationResult{ + Result: serverResults[len(serverResults)-1].Result, + ServerResults: serverResults, + RevocationMethod: result.RevocationMethodOCSP, + } +} diff --git a/revocation/internal/ocsp/ocsp_test.go b/revocation/internal/ocsp/ocsp_test.go new file mode 100644 index 00000000..f1f0f118 --- /dev/null +++ b/revocation/internal/ocsp/ocsp_test.go @@ -0,0 +1,208 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ocsp + +import ( + "crypto/x509" + "fmt" + "net/http" + "strings" + "testing" + "time" + + "github.com/notaryproject/notation-core-go/revocation/result" + "github.com/notaryproject/notation-core-go/testhelper" + "golang.org/x/crypto/ocsp" +) + +func validateEquivalentCertResults(certResults, expectedCertResults []*result.CertRevocationResult, t *testing.T) { + if len(certResults) != len(expectedCertResults) { + t.Errorf("Length of certResults (%d) did not match expected length (%d)", len(certResults), len(expectedCertResults)) + return + } + for i, certResult := range certResults { + if certResult.Result != expectedCertResults[i].Result { + t.Errorf("Expected certResults[%d].Result to be %s, but got %s", i, expectedCertResults[i].Result, certResult.Result) + } + if len(certResult.ServerResults) != len(expectedCertResults[i].ServerResults) { + t.Errorf("Length of certResults[%d].ServerResults (%d) did not match expected length (%d)", i, len(certResult.ServerResults), len(expectedCertResults[i].ServerResults)) + return + } + for j, serverResult := range certResult.ServerResults { + if serverResult.Result != expectedCertResults[i].ServerResults[j].Result { + t.Errorf("Expected certResults[%d].ServerResults[%d].Result to be %s, but got %s", i, j, expectedCertResults[i].ServerResults[j].Result, serverResult.Result) + } + if serverResult.Server != expectedCertResults[i].ServerResults[j].Server { + t.Errorf("Expected certResults[%d].ServerResults[%d].Server to be %s, but got %s", i, j, expectedCertResults[i].ServerResults[j].Server, serverResult.Server) + } + if serverResult.Error == nil { + if expectedCertResults[i].ServerResults[j].Error == nil { + continue + } + t.Errorf("certResults[%d].ServerResults[%d].Error was nil, but expected %v", i, j, expectedCertResults[i].ServerResults[j].Error) + } else if expectedCertResults[i].ServerResults[j].Error == nil { + t.Errorf("Unexpected error for certResults[%d].ServerResults[%d].Error: %v", i, j, serverResult.Error) + } else if serverResult.Error.Error() != expectedCertResults[i].ServerResults[j].Error.Error() { + t.Errorf("Expected certResults[%d].ServerResults[%d].Error to be %v, but got %v", i, j, expectedCertResults[i].ServerResults[j].Error, serverResult.Error) + } + } + } +} + +func getOKCertResult(server string) *result.CertRevocationResult { + return &result.CertRevocationResult{ + Result: result.ResultOK, + ServerResults: []*result.ServerResult{ + result.NewServerResult(result.ResultOK, server, nil), + }, + } +} + +func TestCheckStatus(t *testing.T) { + revokableCertTuple := testhelper.GetRevokableRSALeafCertificate() + revokableIssuerTuple := testhelper.GetRSARootCertificate() + ocspServer := revokableCertTuple.Cert.OCSPServer[0] + revokableChain := []*x509.Certificate{revokableCertTuple.Cert, revokableIssuerTuple.Cert} + testChain := []testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple} + + t.Run("check non-revoked cert", func(t *testing.T) { + client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good}, nil, true) + opts := CertCheckStatusOptions{ + SigningTime: time.Now(), + HTTPClient: client, + } + + certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)} + validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) + }) + t.Run("check cert with Unknown OCSP response", func(t *testing.T) { + client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Unknown}, nil, true) + opts := CertCheckStatusOptions{ + SigningTime: time.Now(), + HTTPClient: client, + } + + certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + expectedCertResults := []*result.CertRevocationResult{{ + Result: result.ResultUnknown, + ServerResults: []*result.ServerResult{ + result.NewServerResult(result.ResultUnknown, ocspServer, UnknownStatusError{}), + }, + }} + validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) + }) + t.Run("check OCSP revoked cert", func(t *testing.T) { + client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Revoked}, nil, true) + opts := CertCheckStatusOptions{ + SigningTime: time.Now(), + HTTPClient: client, + } + + certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + expectedCertResults := []*result.CertRevocationResult{{ + Result: result.ResultRevoked, + ServerResults: []*result.ServerResult{ + result.NewServerResult(result.ResultRevoked, ocspServer, RevokedError{}), + }, + }} + validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) + }) + t.Run("check OCSP future revoked cert", func(t *testing.T) { + revokedTime := time.Now().Add(time.Hour) + client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Revoked}, &revokedTime, true) + opts := CertCheckStatusOptions{ + SigningTime: time.Now(), + HTTPClient: client, + } + + certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)} + validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) + }) + + t.Run("certificate doesn't support OCSP", func(t *testing.T) { + ocspResult := CertCheckStatus(&x509.Certificate{}, revokableIssuerTuple.Cert, CertCheckStatusOptions{}) + expectedResult := &result.CertRevocationResult{ + Result: result.ResultNonRevokable, + ServerResults: []*result.ServerResult{toServerResult("", NoServerError{})}, + } + + validateEquivalentCertResults([]*result.CertRevocationResult{ocspResult}, []*result.CertRevocationResult{expectedResult}, t) + }) +} + +func TestCheckStatusFromServer(t *testing.T) { + revokableCertTuple := testhelper.GetRevokableRSALeafCertificate() + revokableIssuerTuple := testhelper.GetRSARootCertificate() + + t.Run("server url is not http", func(t *testing.T) { + server := "https://example.com" + serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{}) + expectedResult := toServerResult(server, GenericError{Err: fmt.Errorf("OCSPServer protocol %s is not supported", "https")}) + if serverResult.Result != expectedResult.Result { + t.Errorf("Expected Result to be %s, but got %s", expectedResult.Result, serverResult.Result) + } + if serverResult.Server != expectedResult.Server { + t.Errorf("Expected Server to be %s, but got %s", expectedResult.Server, serverResult.Server) + } + if serverResult.Error == nil { + t.Errorf("Expected Error to be %v, but got nil", expectedResult.Error) + } else if serverResult.Error.Error() != expectedResult.Error.Error() { + t.Errorf("Expected Error to be %v, but got %v", expectedResult.Error, serverResult.Error) + } + }) + + t.Run("request error", func(t *testing.T) { + server := "http://example.com" + serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Transport: &failedTransport{}, + }, + }) + errorMessage := "failed to execute request" + if !strings.Contains(serverResult.Error.Error(), errorMessage) { + t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error) + } + }) + + t.Run("ocsp expired", func(t *testing.T) { + client := testhelper.MockClient([]testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true) + server := "http://example.com/expired_ocsp" + serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ + HTTPClient: client, + }) + errorMessage := "expired OCSP response" + if !strings.Contains(serverResult.Error.Error(), errorMessage) { + t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error) + } + }) +} + +func TestPostRequest(t *testing.T) { + t.Run("failed to execute request", func(t *testing.T) { + _, err := postRequest(nil, "http://example.com", &http.Client{ + Transport: &failedTransport{}, + }) + if err == nil { + t.Errorf("Expected error, but got nil") + } + }) +} + +type failedTransport struct{} + +func (f *failedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("failed to execute request") +} diff --git a/revocation/internal/x509util/validate.go b/revocation/internal/x509util/validate.go new file mode 100644 index 00000000..134ef22b --- /dev/null +++ b/revocation/internal/x509util/validate.go @@ -0,0 +1,47 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package x509util provides the method to validate the certificate chain for a +// specific purpose, including code signing and timestamping. +package x509util + +import ( + "crypto/x509" + "fmt" + + "github.com/notaryproject/notation-core-go/revocation/purpose" + "github.com/notaryproject/notation-core-go/revocation/result" + coreX509 "github.com/notaryproject/notation-core-go/x509" +) + +// ValidateChain checks the certificate chain for a specific purpose, including +// code signing and timestamping. +func ValidateChain(certChain []*x509.Certificate, certChainPurpose purpose.Purpose) error { + switch certChainPurpose { + case purpose.CodeSigning: + // Since ValidateCodeSigningCertChain is using authentic signing time, + // signing time may be zero. + // Thus, it is better to pass nil here than fail for a cert's NotBefore + // being after zero time + if err := coreX509.ValidateCodeSigningCertChain(certChain, nil); err != nil { + return result.InvalidChainError{Err: err} + } + case purpose.Timestamping: + if err := coreX509.ValidateTimestampingCertChain(certChain); err != nil { + return result.InvalidChainError{Err: err} + } + default: + return result.InvalidChainError{Err: fmt.Errorf("unsupported certificate chain purpose %v", certChainPurpose)} + } + return nil +} diff --git a/revocation/internal/x509util/validate_test.go b/revocation/internal/x509util/validate_test.go new file mode 100644 index 00000000..22022e36 --- /dev/null +++ b/revocation/internal/x509util/validate_test.go @@ -0,0 +1,60 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package x509util + +import ( + "crypto/x509" + "testing" + + "github.com/notaryproject/notation-core-go/revocation/purpose" + "github.com/notaryproject/notation-core-go/testhelper" +) + +func TestValidate(t *testing.T) { + t.Run("unsupported_certificate_chain_purpose", func(t *testing.T) { + certChain := []*x509.Certificate{} + certChainPurpose := purpose.Purpose(-1) + err := ValidateChain(certChain, certChainPurpose) + if err == nil { + t.Errorf("Validate() failed, expected error, got nil") + } + }) + + t.Run("invalid code signing certificate chain", func(t *testing.T) { + certChain := []*x509.Certificate{} + certChainPurpose := purpose.CodeSigning + err := ValidateChain(certChain, certChainPurpose) + if err == nil { + t.Errorf("Validate() failed, expected error, got nil") + } + }) + + t.Run("invalid timestamping certificate chain", func(t *testing.T) { + certChain := []*x509.Certificate{} + certChainPurpose := purpose.Timestamping + err := ValidateChain(certChain, certChainPurpose) + if err == nil { + t.Errorf("Validate() failed, expected error, got nil") + } + }) + + t.Run("valid code signing certificate chain", func(t *testing.T) { + certChain := testhelper.GetRevokableRSAChain(2) + certChainPurpose := purpose.CodeSigning + err := ValidateChain([]*x509.Certificate{certChain[0].Cert, certChain[1].Cert}, certChainPurpose) + if err != nil { + t.Errorf("Validate() failed, expected nil, got %v", err) + } + }) +} diff --git a/revocation/ocsp/error.go b/revocation/ocsp/error.go new file mode 100644 index 00000000..005c700f --- /dev/null +++ b/revocation/ocsp/error.go @@ -0,0 +1,37 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ocsp + +import "github.com/notaryproject/notation-core-go/revocation/internal/ocsp" + +type ( + // RevokedError is returned when the certificate's status for OCSP is + // ocsp.Revoked + RevokedError = ocsp.RevokedError + + // UnknownStatusError is returned when the certificate's status for OCSP is + // ocsp.Unknown + UnknownStatusError = ocsp.UnknownStatusError + + // GenericError is returned when there is an error during the OCSP revocation + // check, not necessarily a revocation + GenericError = ocsp.GenericError + + // NoServerError is returned when the OCSPServer is not specified. + NoServerError = ocsp.NoServerError + + // TimeoutError is returned when the connection attempt to an OCSP URL exceeds + // the specified threshold + TimeoutError = ocsp.TimeoutError +) diff --git a/revocation/ocsp/ocsp.go b/revocation/ocsp/ocsp.go index cb9e97cc..c2274f75 100644 --- a/revocation/ocsp/ocsp.go +++ b/revocation/ocsp/ocsp.go @@ -16,25 +16,16 @@ package ocsp import ( - "bytes" - "crypto" "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "encoding/base64" "errors" - "fmt" - "io" "net/http" - "net/url" - "strings" "sync" "time" + "github.com/notaryproject/notation-core-go/revocation/internal/ocsp" + "github.com/notaryproject/notation-core-go/revocation/internal/x509util" "github.com/notaryproject/notation-core-go/revocation/purpose" "github.com/notaryproject/notation-core-go/revocation/result" - coreX509 "github.com/notaryproject/notation-core-go/x509" - "golang.org/x/crypto/ocsp" ) // Options specifies values that are needed to check OCSP revocation @@ -45,19 +36,10 @@ type Options struct { // values are CodeSigning and Timestamping. // When not provided, the default value is CodeSigning. CertChainPurpose purpose.Purpose - - SigningTime time.Time - HTTPClient *http.Client + SigningTime time.Time + HTTPClient *http.Client } -const ( - pkixNoCheckOID string = "1.3.6.1.5.5.7.48.1.5" - invalidityDateOID string = "2.5.29.24" - // Max size determined from https://www.ibm.com/docs/en/sva/9.0.6?topic=stanza-ocsp-max-response-size. - // Typical size is ~4 KB - ocspMaxResponseSize int64 = 20480 //bytes -) - // CheckStatus checks OCSP based on the passed options and returns an array of // result.CertRevocationResult objects that contains the results and error. The // length of this array will always be equal to the length of the certificate @@ -67,24 +49,15 @@ func CheckStatus(opts Options) ([]*result.CertRevocationResult, error) { return nil, result.InvalidChainError{Err: errors.New("chain does not contain any certificates")} } - switch opts.CertChainPurpose { - case purpose.CodeSigning: - // Since ValidateCodeSigningCertChain is using authentic signing time, - // signing time may be zero. - // Thus, it is better to pass nil here than fail for a cert's NotBefore - // being after zero time - if err := coreX509.ValidateCodeSigningCertChain(opts.CertChain, nil); err != nil { - return nil, result.InvalidChainError{Err: err} - } - case purpose.Timestamping: - if err := coreX509.ValidateTimestampingCertChain(opts.CertChain); err != nil { - return nil, result.InvalidChainError{Err: err} - } - default: - return nil, result.InvalidChainError{Err: fmt.Errorf("unsupported certificate chain purpose %v", opts.CertChainPurpose)} + if err := x509util.ValidateChain(opts.CertChain, opts.CertChainPurpose); err != nil { + return nil, err } certResults := make([]*result.CertRevocationResult, len(opts.CertChain)) + certCheckStatusOptions := ocsp.CertCheckStatusOptions{ + SigningTime: opts.SigningTime, + HTTPClient: opts.HTTPClient, + } // Check status for each cert in cert chain var wg sync.WaitGroup @@ -93,7 +66,7 @@ func CheckStatus(opts Options) ([]*result.CertRevocationResult, error) { // Assume cert chain is accurate and next cert in chain is the issuer go func(i int, cert *x509.Certificate) { defer wg.Done() - certResults[i] = certCheckStatus(cert, opts.CertChain[i+1], opts) + certResults[i] = ocsp.CertCheckStatus(cert, opts.CertChain[i+1], certCheckStatusOptions) }(i, cert) } // Last is root cert, which will never be revoked by OCSP @@ -108,182 +81,3 @@ func CheckStatus(opts Options) ([]*result.CertRevocationResult, error) { wg.Wait() return certResults, nil } - -func certCheckStatus(cert, issuer *x509.Certificate, opts Options) *result.CertRevocationResult { - ocspURLs := cert.OCSPServer - if len(ocspURLs) == 0 { - // OCSP not enabled for this certificate. - return &result.CertRevocationResult{ - Result: result.ResultNonRevokable, - ServerResults: []*result.ServerResult{toServerResult("", NoServerError{})}, - } - } - - serverResults := make([]*result.ServerResult, len(ocspURLs)) - for serverIndex, server := range ocspURLs { - serverResult := checkStatusFromServer(cert, issuer, server, opts) - if serverResult.Result == result.ResultOK || - serverResult.Result == result.ResultRevoked || - (serverResult.Result == result.ResultUnknown && errors.Is(serverResult.Error, UnknownStatusError{})) { - // A valid response has been received from an OCSP server - // Result should be based on only this response, not any errors from - // other servers - return serverResultsToCertRevocationResult([]*result.ServerResult{serverResult}) - } - serverResults[serverIndex] = serverResult - } - return serverResultsToCertRevocationResult(serverResults) -} - -func checkStatusFromServer(cert, issuer *x509.Certificate, server string, opts Options) *result.ServerResult { - // Check valid server - if serverURL, err := url.Parse(server); err != nil || !strings.EqualFold(serverURL.Scheme, "http") { - // This function is only able to check servers that are accessible via HTTP - return toServerResult(server, GenericError{Err: fmt.Errorf("OCSPServer protocol %s is not supported", serverURL.Scheme)}) - } - - // Create OCSP Request - resp, err := executeOCSPCheck(cert, issuer, server, opts) - if err != nil { - // If there is a server error, attempt all servers before determining what to return - // to the user - return toServerResult(server, err) - } - - // Validate OCSP response isn't expired - if time.Now().After(resp.NextUpdate) { - return toServerResult(server, GenericError{Err: errors.New("expired OCSP response")}) - } - - // Handle pkix-ocsp-no-check and id-ce-invalidityDate extensions if present - // in response - extensionMap := extensionsToMap(resp.Extensions) - if _, foundNoCheck := extensionMap[pkixNoCheckOID]; !foundNoCheck { - // This will be ignored until CRL is implemented - // If it isn't found, CRL should be used to verify the OCSP response - _ = foundNoCheck // needed to bypass linter warnings (Remove after adding CRL) - // TODO: add CRL support - // https://github.com/notaryproject/notation-core-go/issues/125 - } - if invalidityDateBytes, foundInvalidityDate := extensionMap[invalidityDateOID]; foundInvalidityDate && !opts.SigningTime.IsZero() && resp.Status == ocsp.Revoked { - var invalidityDate time.Time - rest, err := asn1.UnmarshalWithParams(invalidityDateBytes, &invalidityDate, "generalized") - if len(rest) == 0 && err == nil && opts.SigningTime.Before(invalidityDate) { - return toServerResult(server, nil) - } - } - - // No errors, valid server response - switch resp.Status { - case ocsp.Good: - return toServerResult(server, nil) - case ocsp.Revoked: - return toServerResult(server, RevokedError{}) - default: - // ocsp.Unknown - return toServerResult(server, UnknownStatusError{}) - } -} - -func extensionsToMap(extensions []pkix.Extension) map[string][]byte { - extensionMap := make(map[string][]byte) - for _, extension := range extensions { - extensionMap[extension.Id.String()] = extension.Value - } - return extensionMap -} - -func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts Options) (*ocsp.Response, error) { - // TODO: Look into other alternatives for specifying the Hash - // https://github.com/notaryproject/notation-core-go/issues/139 - // The following do not support SHA256 hashes: - // - Microsoft - // - Entrust - // - Let's Encrypt - // - Digicert (sometimes) - // As this represents a large percentage of public CAs, we are using the - // hashing algorithm SHA1, which has been confirmed to be supported by all - // that were tested. - ocspRequest, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA1}) - if err != nil { - return nil, GenericError{Err: err} - } - - var resp *http.Response - postRequired := base64.StdEncoding.EncodedLen(len(ocspRequest)) >= 255 - if !postRequired { - encodedReq := url.QueryEscape(base64.StdEncoding.EncodeToString(ocspRequest)) - if len(encodedReq) < 255 { - var reqURL string - reqURL, err = url.JoinPath(server, encodedReq) - if err != nil { - return nil, GenericError{Err: err} - } - resp, err = opts.HTTPClient.Get(reqURL) - } else { - resp, err = postRequest(ocspRequest, server, opts.HTTPClient) - } - } else { - resp, err = postRequest(ocspRequest, server, opts.HTTPClient) - } - - if err != nil { - var urlErr *url.Error - if errors.As(err, &urlErr) && urlErr.Timeout() { - return nil, TimeoutError{} - } - return nil, GenericError{Err: err} - } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("failed to retrieve OCSP: response had status code %d", resp.StatusCode) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, ocspMaxResponseSize)) - if err != nil { - return nil, GenericError{Err: err} - } - - switch { - case bytes.Equal(body, ocsp.UnauthorizedErrorResponse): - return nil, GenericError{Err: errors.New("OCSP unauthorized")} - case bytes.Equal(body, ocsp.MalformedRequestErrorResponse): - return nil, GenericError{Err: errors.New("OCSP malformed")} - case bytes.Equal(body, ocsp.InternalErrorErrorResponse): - return nil, GenericError{Err: errors.New("OCSP internal error")} - case bytes.Equal(body, ocsp.TryLaterErrorResponse): - return nil, GenericError{Err: errors.New("OCSP try later")} - case bytes.Equal(body, ocsp.SigRequredErrorResponse): - return nil, GenericError{Err: errors.New("OCSP signature required")} - } - - return ocsp.ParseResponseForCert(body, cert, issuer) -} - -func postRequest(req []byte, server string, httpClient *http.Client) (*http.Response, error) { - reader := bytes.NewReader(req) - return httpClient.Post(server, "application/ocsp-request", reader) -} - -func toServerResult(server string, err error) *result.ServerResult { - switch t := err.(type) { - case nil: - return result.NewServerResult(result.ResultOK, server, nil) - case NoServerError: - return result.NewServerResult(result.ResultNonRevokable, server, nil) - case RevokedError: - return result.NewServerResult(result.ResultRevoked, server, t) - default: - // Includes GenericError, UnknownStatusError, result.InvalidChainError, - // and TimeoutError - return result.NewServerResult(result.ResultUnknown, server, t) - } -} - -func serverResultsToCertRevocationResult(serverResults []*result.ServerResult) *result.CertRevocationResult { - return &result.CertRevocationResult{ - Result: serverResults[len(serverResults)-1].Result, - ServerResults: serverResults, - } -} diff --git a/revocation/ocsp/ocsp_test.go b/revocation/ocsp/ocsp_test.go index 8a1479da..f677e6de 100644 --- a/revocation/ocsp/ocsp_test.go +++ b/revocation/ocsp/ocsp_test.go @@ -79,81 +79,14 @@ func getRootCertResult() *result.CertRevocationResult { } } -func TestCheckStatus(t *testing.T) { - revokableCertTuple := testhelper.GetRevokableRSALeafCertificate() - revokableIssuerTuple := testhelper.GetRSARootCertificate() - ocspServer := revokableCertTuple.Cert.OCSPServer[0] - revokableChain := []*x509.Certificate{revokableCertTuple.Cert, revokableIssuerTuple.Cert} - testChain := []testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple} - - t.Run("check non-revoked cert", func(t *testing.T) { - client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good}, nil, true) - opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, - } - - certResult := certCheckStatus(revokableChain[0], revokableChain[1], opts) - expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)} - validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) - }) - t.Run("check cert with Unknown OCSP response", func(t *testing.T) { - client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Unknown}, nil, true) - opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, - } - - certResult := certCheckStatus(revokableChain[0], revokableChain[1], opts) - expectedCertResults := []*result.CertRevocationResult{{ - Result: result.ResultUnknown, - ServerResults: []*result.ServerResult{ - result.NewServerResult(result.ResultUnknown, ocspServer, UnknownStatusError{}), - }, - }} - validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) - }) - t.Run("check OCSP revoked cert", func(t *testing.T) { - client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Revoked}, nil, true) - opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, - } - - certResult := certCheckStatus(revokableChain[0], revokableChain[1], opts) - expectedCertResults := []*result.CertRevocationResult{{ - Result: result.ResultRevoked, - ServerResults: []*result.ServerResult{ - result.NewServerResult(result.ResultRevoked, ocspServer, RevokedError{}), - }, - }} - validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) - }) - t.Run("check OCSP future revoked cert", func(t *testing.T) { - revokedTime := time.Now().Add(time.Hour) - client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Revoked}, &revokedTime, true) - opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, - } - - certResult := certCheckStatus(revokableChain[0], revokableChain[1], opts) - expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)} - validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) - }) -} - func TestCheckStatusForSelfSignedCert(t *testing.T) { selfSignedTuple := testhelper.GetRSASelfSignedSigningCertTuple("Notation revocation test self-signed cert") client := testhelper.MockClient([]testhelper.RSACertTuple{selfSignedTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: []*x509.Certificate{selfSignedTuple.Cert}, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: []*x509.Certificate{selfSignedTuple.Cert}, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -168,9 +101,10 @@ func TestCheckStatusForRootCert(t *testing.T) { rootTuple := testhelper.GetRSARootCertificate() client := testhelper.MockClient([]testhelper.RSACertTuple{rootTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: []*x509.Certificate{rootTuple.Cert}, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: []*x509.Certificate{rootTuple.Cert}, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -187,9 +121,10 @@ func TestCheckStatusForNonSelfSignedSingleCert(t *testing.T) { certTuple := testhelper.GetRSALeafCertificate() client := testhelper.MockClient([]testhelper.RSACertTuple{certTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: []*x509.Certificate{certTuple.Cert}, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: []*x509.Certificate{certTuple.Cert}, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -213,9 +148,10 @@ func TestCheckStatusForChain(t *testing.T) { t.Run("empty chain", func(t *testing.T) { opts := Options{ - CertChain: []*x509.Certificate{}, - SigningTime: time.Now(), - HTTPClient: http.DefaultClient, + CertChain: []*x509.Certificate{}, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: http.DefaultClient, } certResults, err := CheckStatus(opts) expectedErr := result.InvalidChainError{Err: errors.New("chain does not contain any certificates")} @@ -253,9 +189,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Unknown, ocsp.Good}, nil, true) // 3rd cert will be unknown, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -281,9 +218,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Revoked, ocsp.Good}, nil, true) // 3rd cert will be revoked, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -309,9 +247,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Unknown, ocsp.Good, ocsp.Revoked, ocsp.Good}, nil, true) // 3rd cert will be unknown, 5th will be revoked, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -343,9 +282,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Revoked, ocsp.Good}, &revokedTime, true) // 3rd cert will be revoked, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -367,9 +307,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Unknown, ocsp.Good, ocsp.Revoked, ocsp.Good}, &revokedTime, true) // 3rd cert will be unknown, 5th will be revoked, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -395,9 +336,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Revoked, ocsp.Good}, nil, true) // 3rd cert will be revoked, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: time.Now().Add(time.Hour), - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now().Add(time.Hour), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -424,9 +366,10 @@ func TestCheckStatusForChain(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good, ocsp.Good, ocsp.Revoked, ocsp.Good}, &revokedTime, true) // 3rd cert will be revoked, the rest will be good opts := Options{ - CertChain: revokableChain, - SigningTime: zeroTime, - HTTPClient: client, + CertChain: revokableChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: zeroTime, + HTTPClient: client, } if !zeroTime.IsZero() { @@ -484,9 +427,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("no OCSPServer specified", func(t *testing.T) { opts := Options{ - CertChain: noOCSPChain, - SigningTime: time.Now(), - HTTPClient: http.DefaultClient, + CertChain: noOCSPChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: http.DefaultClient, } certResults, err := CheckStatus(opts) if err != nil { @@ -506,9 +450,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("chain missing root", func(t *testing.T) { opts := Options{ - CertChain: noRootChain, - SigningTime: time.Now(), - HTTPClient: http.DefaultClient, + CertChain: noRootChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: http.DefaultClient, } certResults, err := CheckStatus(opts) if err == nil || err.Error() != chainRootErr.Error() { @@ -521,9 +466,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("backwards chain", func(t *testing.T) { opts := Options{ - CertChain: backwardsChain, - SigningTime: time.Now(), - HTTPClient: http.DefaultClient, + CertChain: backwardsChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: http.DefaultClient, } certResults, err := CheckStatus(opts) if err == nil || err.Error() != backwardsChainErr.Error() { @@ -581,9 +527,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("timeout", func(t *testing.T) { timeoutClient := &http.Client{Timeout: 1 * time.Nanosecond} opts := Options{ - CertChain: okChain, - SigningTime: time.Now(), - HTTPClient: timeoutClient, + CertChain: okChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: timeoutClient, } certResults, err := CheckStatus(opts) if err != nil { @@ -610,9 +557,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("expired ocsp response", func(t *testing.T) { client := testhelper.MockClient(revokableTuples, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: expiredChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: expiredChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) if err != nil { @@ -634,9 +582,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("pkixNoCheck missing", func(t *testing.T) { client := testhelper.MockClient(revokableTuples, []ocsp.ResponseStatus{ocsp.Good}, nil, false) opts := Options{ - CertChain: okChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: okChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) @@ -654,9 +603,10 @@ func TestCheckStatusErrors(t *testing.T) { t.Run("non-HTTP URI error", func(t *testing.T) { client := testhelper.MockClient(revokableTuples, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: noHTTPChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: noHTTPChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) if err != nil { @@ -701,9 +651,10 @@ func TestCheckOCSPInvalidChain(t *testing.T) { t.Run("chain missing intermediate", func(t *testing.T) { client := testhelper.MockClient(revokableTuples, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: missingIntermediateChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: missingIntermediateChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) if err == nil || err.Error() != missingIntermediateErr.Error() { @@ -717,9 +668,10 @@ func TestCheckOCSPInvalidChain(t *testing.T) { t.Run("chain out of order", func(t *testing.T) { client := testhelper.MockClient(misorderedIntermediateTuples, []ocsp.ResponseStatus{ocsp.Good}, nil, true) opts := Options{ - CertChain: misorderedIntermediateChain, - SigningTime: time.Now(), - HTTPClient: client, + CertChain: misorderedIntermediateChain, + CertChainPurpose: purpose.CodeSigning, + SigningTime: time.Now(), + HTTPClient: client, } certResults, err := CheckStatus(opts) if err == nil || err.Error() != misorderedChainErr.Error() { diff --git a/revocation/result/results.go b/revocation/result/results.go index c7ecba51..09718cc3 100644 --- a/revocation/result/results.go +++ b/revocation/result/results.go @@ -16,23 +16,27 @@ package result import "strconv" -// Result is a type of enumerated value to help characterize errors. It can be -// OK, Unknown, or Revoked +// Result is a type of enumerated value to help characterize revocation result. +// It can be OK, Unknown, NonRevokable, or Revoked type Result int const ( // ResultUnknown is a Result that indicates that some error other than a - // revocation was encountered during the revocation check + // revocation was encountered during the revocation check. ResultUnknown Result = iota - // ResultOK is a Result that indicates that the revocation check resulted in no - // important errors + + // ResultOK is a Result that indicates that the revocation check resulted in + // no important errors. ResultOK - // ResultNonRevokable is a Result that indicates that the certificate cannot be - // checked for revocation. This may be a result of no OCSP servers being - // specified, the cert is a root certificate, or other related situations. + + // ResultNonRevokable is a Result that indicates that the certificate cannot + // be checked for revocation. This may be due to the absence of OCSP servers + // or CRL distribution points, or because the certificate is a root + // certificate. ResultNonRevokable + // ResultRevoked is a Result that indicates that at least one certificate was - // revoked when performing a revocation check on the certificate chain + // revoked when performing a revocation check on the certificate chain. ResultRevoked ) @@ -52,8 +56,45 @@ func (r Result) String() string { } } -// ServerResult encapsulates the result for a single server for a single -// certificate in the chain +// RevocationMethod defines the method used to check the revocation status of a +// certificate. +type RevocationMethod int + +const ( + // RevocationMethodUnknown is used for root certificates or when the method + // used to check the revocation status of a certificate is unknown. + RevocationMethodUnknown RevocationMethod = iota + + // RevocationMethodOCSP represents OCSP as the method used to check the + // revocation status of a certificate. + RevocationMethodOCSP + + // RevocationMethodCRL represents CRL as the method used to check the + // revocation status of a certificate. + RevocationMethodCRL + + // RevocationMethodOCSPFallbackCRL represents OCSP check with unknown error + // fallback to CRL as the method used to check the revocation status of a + // certificate. + RevocationMethodOCSPFallbackCRL +) + +// String provides a conversion from a Method to a string +func (m RevocationMethod) String() string { + switch m { + case RevocationMethodOCSP: + return "OCSP" + case RevocationMethodCRL: + return "CRL" + case RevocationMethodOCSPFallbackCRL: + return "OCSPFallbackCRL" + default: + return "Unknown" + } +} + +// ServerResult encapsulates the OCSP result for a single server or the CRL +// result for a single CRL URI for a certificate in the chain type ServerResult struct { // Result of revocation for this server (Unknown if there is an error which // prevents the retrieval of a valid status) @@ -67,6 +108,11 @@ type ServerResult struct { // Error is set if there is an error associated with the revocation check // to this server Error error + + // RevocationMethod is the method used to check the revocation status of the + // certificate, including RevocationMethodUnknown, RevocationMethodOCSP, + // RevocationMethodCRL + RevocationMethod RevocationMethod } // NewServerResult creates a ServerResult object from its individual parts: a @@ -83,21 +129,31 @@ func NewServerResult(result Result, server string, err error) *ServerResult { // chain as well as the results from individual servers associated with this // certificate type CertRevocationResult struct { - // Result of revocation for a specific cert in the chain - // - // If there are multiple ServerResults, this is because no responses were - // able to be retrieved, leaving each ServerResult with a Result of Unknown. - // Thus, in the case of more than one ServerResult, this will be ResultUnknown + // Result of revocation for a specific certificate in the chain. Result Result - // An array of results for each server associated with the certificate. - // The length will be either 1 or the number of OCSPServers for the cert. + // ServerResults is an array of results for each server associated with the + // certificate. // - // If the length is 1, then a valid status was able to be retrieved. Only + // When RevocationMethod is MethodOCSP, the length will be + // either 1 or the number of OCSPServers for the certificate. + // If the length is 1, then a valid status was retrieved. Only // this server result is contained. Any errors for other servers are // discarded in favor of this valid response. - // // Otherwise, every server specified had some error that prevented the - // status from being retrieved. These are all contained here for evaluation + // status from being retrieved. These are all contained here for evaluation. + // + // When RevocationMethod is MethodCRL, the length will be the number of + // CRL distribution points' URIs checked. If the result is Revoked, or + // there is an error, the length will be 1. + // + // When RevocationMethod is MethodOCSPFallbackCRL, the length + // will be the sum of the previous two cases. The CRL result will be + // appended after the OCSP results. ServerResults []*ServerResult + + // RevocationMethod is the method used to check the revocation status of the + // certificate, including RevocationMethodUnknown, RevocationMethodOCSP, + // RevocationMethodCRL and RevocationMethodOCSPFallbackCRL + RevocationMethod RevocationMethod } diff --git a/revocation/result/results_test.go b/revocation/result/results_test.go index 1c5a503a..75b42ae3 100644 --- a/revocation/result/results_test.go +++ b/revocation/result/results_test.go @@ -46,6 +46,27 @@ func TestResultString(t *testing.T) { }) } +func TestMethodString(t *testing.T) { + tests := []struct { + method RevocationMethod + expected string + }{ + {RevocationMethodOCSP, "OCSP"}, + {RevocationMethodCRL, "CRL"}, + {RevocationMethodOCSPFallbackCRL, "OCSPFallbackCRL"}, + {RevocationMethod(999), "Unknown"}, // Test for default case + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + result := tt.method.String() + if result != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, result) + } + }) + } +} + func TestNewServerResult(t *testing.T) { expectedR := &ServerResult{ Result: ResultNonRevokable, diff --git a/revocation/revocation.go b/revocation/revocation.go index 25ac11da..a20c63c1 100644 --- a/revocation/revocation.go +++ b/revocation/revocation.go @@ -21,9 +21,12 @@ import ( "errors" "fmt" "net/http" + "sync" "time" - "github.com/notaryproject/notation-core-go/revocation/ocsp" + "github.com/notaryproject/notation-core-go/revocation/internal/crl" + "github.com/notaryproject/notation-core-go/revocation/internal/ocsp" + "github.com/notaryproject/notation-core-go/revocation/internal/x509util" "github.com/notaryproject/notation-core-go/revocation/purpose" "github.com/notaryproject/notation-core-go/revocation/result" ) @@ -34,8 +37,9 @@ import ( // To perform revocation check, use [Validator]. type Revocation interface { // Validate checks the revocation status for a certificate chain using OCSP - // and returns an array of CertRevocationResults that contain the results - // and any errors that are encountered during the process + // and CRL if OCSP is not available. It returns an array of + // CertRevocationResults that contain the results and any errors that are + // encountered during the process Validate(certChain []*x509.Certificate, signingTime time.Time) ([]*result.CertRevocationResult, error) } @@ -64,7 +68,8 @@ type Validator interface { // revocation is an internal struct used for revocation checking type revocation struct { - httpClient *http.Client + ocspHTTPClient *http.Client + crlHTTPClient *http.Client certChainPurpose purpose.Purpose } @@ -77,7 +82,8 @@ func New(httpClient *http.Client) (Revocation, error) { return nil, errors.New("invalid input: a non-nil httpClient must be specified") } return &revocation{ - httpClient: httpClient, + ocspHTTPClient: httpClient, + crlHTTPClient: httpClient, certChainPurpose: purpose.CodeSigning, }, nil } @@ -89,6 +95,11 @@ type Options struct { // OPTIONAL. OCSPHTTPClient *http.Client + // CRLHTTPClient is the HTTP client for CRL request. If not provided, + // a default *http.Client with timeout of 5 seconds will be used. + // OPTIONAL. + CRLHTTPClient *http.Client + // CertChainPurpose is the purpose of the certificate chain. Supported // values are CodeSigning and Timestamping. Default value is CodeSigning. // OPTIONAL. @@ -101,6 +112,10 @@ func NewWithOptions(opts Options) (Validator, error) { opts.OCSPHTTPClient = &http.Client{Timeout: 2 * time.Second} } + if opts.CRLHTTPClient == nil { + opts.CRLHTTPClient = &http.Client{Timeout: 5 * time.Second} + } + switch opts.CertChainPurpose { case purpose.CodeSigning, purpose.Timestamping: default: @@ -108,17 +123,22 @@ func NewWithOptions(opts Options) (Validator, error) { } return &revocation{ - httpClient: opts.OCSPHTTPClient, + ocspHTTPClient: opts.OCSPHTTPClient, + crlHTTPClient: opts.CRLHTTPClient, certChainPurpose: opts.CertChainPurpose, }, nil } // Validate checks the revocation status for a certificate chain using OCSP and -// returns an array of CertRevocationResults that contain the results and any -// errors that are encountered during the process +// CRL if OCSP is not available. It returns an array of CertRevocationResults +// that contain the results and any errors that are encountered during the +// process. // -// TODO: add CRL support -// https://github.com/notaryproject/notation-core-go/issues/125 +// This function tries OCSP and falls back to CRL when: +// - OCSP is not supported by the certificate +// - OCSP returns an unknown status +// +// NOTE: The certificate chain is expected to be in the order of leaf to root. func (r *revocation) Validate(certChain []*x509.Certificate, signingTime time.Time) ([]*result.CertRevocationResult, error) { return r.ValidateContext(context.Background(), ValidateContextOptions{ CertChain: certChain, @@ -126,24 +146,114 @@ func (r *revocation) Validate(certChain []*x509.Certificate, signingTime time.Ti }) } -// ValidateContext checks the revocation status for a certificate chain using -// OCSP and returns an array of CertRevocationResults that contain the results -// and any errors that are encountered during the process +// ValidateContext checks the revocation status for a certificate chain using OCSP and +// CRL if OCSP is not available. It returns an array of CertRevocationResults +// that contain the results and any errors that are encountered during the +// process. +// +// This function tries OCSP and falls back to CRL when: +// - OCSP is not supported by the certificate +// - OCSP returns an unknown status // -// TODO: add CRL support -// https://github.com/notaryproject/notation-core-go/issues/125 +// NOTE: The certificate chain is expected to be in the order of leaf to root. func (r *revocation) ValidateContext(ctx context.Context, validateContextOpts ValidateContextOptions) ([]*result.CertRevocationResult, error) { + // validate certificate chain if len(validateContextOpts.CertChain) == 0 { return nil, result.InvalidChainError{Err: errors.New("chain does not contain any certificates")} } + certChain := validateContextOpts.CertChain + if err := x509util.ValidateChain(certChain, r.certChainPurpose); err != nil { + return nil, err + } - return ocsp.CheckStatus(ocsp.Options{ - CertChain: validateContextOpts.CertChain, - CertChainPurpose: r.certChainPurpose, - SigningTime: validateContextOpts.AuthenticSigningTime, - HTTPClient: r.httpClient, - }) + ocspOpts := ocsp.CertCheckStatusOptions{ + HTTPClient: r.ocspHTTPClient, + SigningTime: validateContextOpts.AuthenticSigningTime, + } + crlOpts := crl.CertCheckStatusOptions{ + HTTPClient: r.crlHTTPClient, + SigningTime: validateContextOpts.AuthenticSigningTime, + } + + // panicChain is used to store the panic in goroutine and handle it + panicChan := make(chan any, len(certChain)) + defer close(panicChan) + + certResults := make([]*result.CertRevocationResult, len(certChain)) + var wg sync.WaitGroup + for i, cert := range certChain[:len(certChain)-1] { + switch { + case ocsp.Supported(cert): + // do OCSP check for the certificate + wg.Add(1) + + go func(i int, cert *x509.Certificate) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + // catch panic and send it to panicChan to avoid + // losing the panic + panicChan <- r + } + }() + + ocspResult := ocsp.CertCheckStatus(cert, certChain[i+1], ocspOpts) + if ocspResult != nil && ocspResult.Result == result.ResultUnknown && crl.Supported(cert) { + // try CRL check if OCSP serverResult is unknown + serverResult := crl.CertCheckStatus(ctx, cert, certChain[i+1], crlOpts) + // append CRL result to OCSP result + serverResult.ServerResults = append(ocspResult.ServerResults, serverResult.ServerResults...) + serverResult.RevocationMethod = result.RevocationMethodOCSPFallbackCRL + certResults[i] = serverResult + } else { + certResults[i] = ocspResult + } + }(i, cert) + case crl.Supported(cert): + // do CRL check for the certificate + wg.Add(1) + + go func(i int, cert *x509.Certificate) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + // catch panic and send it to panicChan to avoid + // losing the panic + panicChan <- r + } + }() + + certResults[i] = crl.CertCheckStatus(ctx, cert, certChain[i+1], crlOpts) + }(i, cert) + default: + certResults[i] = &result.CertRevocationResult{ + Result: result.ResultNonRevokable, + ServerResults: []*result.ServerResult{{ + Result: result.ResultNonRevokable, + RevocationMethod: result.RevocationMethodUnknown, + }}, + RevocationMethod: result.RevocationMethodUnknown, + } + } + } + + // Last is root cert, which will never be revoked by OCSP or CRL + certResults[len(certChain)-1] = &result.CertRevocationResult{ + Result: result.ResultNonRevokable, + ServerResults: []*result.ServerResult{{ + Result: result.ResultNonRevokable, + RevocationMethod: result.RevocationMethodUnknown, + }}, + RevocationMethod: result.RevocationMethodUnknown, + } + wg.Wait() + + // handle panic + select { + case p := <-panicChan: + panic(p) + default: + } - // TODO: add CRL support - // https://github.com/notaryproject/notation-core-go/issues/125 + return certResults, nil } diff --git a/revocation/revocation_test.go b/revocation/revocation_test.go index f663880c..2ac8b4c9 100644 --- a/revocation/revocation_test.go +++ b/revocation/revocation_test.go @@ -14,15 +14,21 @@ package revocation import ( + "bytes" "context" + "crypto/rand" "crypto/x509" "errors" "fmt" + "io" + "math/big" "net/http" + "strconv" + "strings" "testing" "time" - revocationocsp "github.com/notaryproject/notation-core-go/revocation/ocsp" + revocationocsp "github.com/notaryproject/notation-core-go/revocation/internal/ocsp" "github.com/notaryproject/notation-core-go/revocation/purpose" "github.com/notaryproject/notation-core-go/revocation/result" "github.com/notaryproject/notation-core-go/testhelper" @@ -60,6 +66,9 @@ func validateEquivalentCertResults(certResults, expectedCertResults []*result.Ce t.Errorf("Expected certResults[%d].ServerResults[%d].Error to be %v, but got %v", i, j, expectedCertResults[i].ServerResults[j].Error, serverResult.Error) } } + if certResult.RevocationMethod != expectedCertResults[i].RevocationMethod { + t.Errorf("Expected certResults[%d].RevocationMethod to be %d, but got %d", i, expectedCertResults[i].RevocationMethod, certResult.RevocationMethod) + } } } @@ -69,6 +78,7 @@ func getOKCertResult(server string) *result.CertRevocationResult { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultOK, server, nil), }, + RevocationMethod: result.RevocationMethodOCSP, } } @@ -96,8 +106,8 @@ func TestNew(t *testing.T) { revR, ok := r.(*revocation) if !ok { t.Error("Expected New to create an object matching the internal revocation struct") - } else if revR.httpClient != client { - t.Errorf("Expected New to set client to %v, but it was set to %v", client, revR.httpClient) + } else if revR.ocspHTTPClient != client { + t.Errorf("Expected New to set client to %v, but it was set to %v", client, revR.ocspHTTPClient) } } @@ -161,6 +171,7 @@ func TestCheckRevocationStatusForSingleCert(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[0].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getRootCertResult(), } @@ -183,6 +194,7 @@ func TestCheckRevocationStatusForSingleCert(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[0].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getRootCertResult(), } @@ -305,6 +317,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[2].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -332,6 +345,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[2].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -359,6 +373,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[2].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), { @@ -366,6 +381,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[4].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getRootCertResult(), } @@ -415,6 +431,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[2].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -442,6 +459,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[2].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -475,6 +493,7 @@ func TestCheckRevocationStatusForChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[2].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -493,6 +512,18 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { revokableChain[i].NotBefore = zeroTime } + t.Run("invalid revocation purpose", func(t *testing.T) { + revocationClient := &revocation{ + ocspHTTPClient: &http.Client{Timeout: 5 * time.Second}, + certChainPurpose: -1, + } + + _, err := revocationClient.Validate(revokableChain, time.Now()) + if err == nil { + t.Error("Expected Validate to fail with an error, but it succeeded") + } + }) + t.Run("empty chain", func(t *testing.T) { r, err := NewWithOptions(Options{ OCSPHTTPClient: &http.Client{Timeout: 5 * time.Second}, @@ -564,6 +595,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[2].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -596,6 +628,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[2].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -628,6 +661,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[2].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), { @@ -635,6 +669,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[4].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getRootCertResult(), } @@ -694,6 +729,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, revokableChain[2].OCSPServer[0], revocationocsp.UnknownStatusError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -726,6 +762,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[2].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -762,6 +799,7 @@ func TestCheckRevocationStatusForTimestampChain(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultRevoked, revokableChain[2].OCSPServer[0], revocationocsp.RevokedError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(revokableChain[3].OCSPServer[0]), getOKCertResult(revokableChain[4].OCSPServer[0]), @@ -856,12 +894,14 @@ func TestCheckRevocationErrors(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, okChain[0].OCSPServer[0], revocationocsp.TimeoutError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, { Result: result.ResultUnknown, ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, okChain[1].OCSPServer[0], revocationocsp.TimeoutError{}), }, + RevocationMethod: result.RevocationMethodOCSP, }, getRootCertResult(), } @@ -884,6 +924,7 @@ func TestCheckRevocationErrors(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, expiredChain[0].OCSPServer[0], expiredRespErr), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(expiredChain[1].OCSPServer[0]), getRootCertResult(), @@ -926,6 +967,7 @@ func TestCheckRevocationErrors(t *testing.T) { ServerResults: []*result.ServerResult{ result.NewServerResult(result.ResultUnknown, noHTTPChain[0].OCSPServer[0], noHTTPErr), }, + RevocationMethod: result.RevocationMethodOCSP, }, getOKCertResult(noHTTPChain[1].OCSPServer[0]), getRootCertResult(), @@ -989,9 +1031,306 @@ func TestCheckRevocationInvalidChain(t *testing.T) { }) } +func TestCRL(t *testing.T) { + t.Run("CRL check valid", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(3, false, true) + + revocationClient, err := NewWithOptions(Options{ + CRLHTTPClient: &http.Client{ + Timeout: 5 * time.Second, + Transport: &crlRoundTripper{ + CertChain: chain, + Revoked: false, + }, + }, + OCSPHTTPClient: &http.Client{}, + CertChainPurpose: purpose.CodeSigning, + }) + if err != nil { + t.Errorf("Expected successful creation of revocation, but received error: %v", err) + } + + certResults, err := revocationClient.ValidateContext(context.Background(), ValidateContextOptions{ + CertChain: []*x509.Certificate{chain[0].Cert, chain[1].Cert, chain[2].Cert}, + AuthenticSigningTime: time.Now(), + }) + if err != nil { + t.Errorf("Expected CheckStatus to succeed, but got error: %v", err) + } + + expectedCertResults := []*result.CertRevocationResult{ + { + Result: result.ResultOK, + ServerResults: []*result.ServerResult{{ + Result: result.ResultOK, + Server: "http://example.com/chain_crl/0", + }}, + RevocationMethod: result.RevocationMethodCRL, + }, + { + Result: result.ResultOK, + ServerResults: []*result.ServerResult{{ + Result: result.ResultOK, + Server: "http://example.com/chain_crl/1", + }}, + RevocationMethod: result.RevocationMethodCRL, + }, + getRootCertResult(), + } + + validateEquivalentCertResults(certResults, expectedCertResults, t) + }) + + t.Run("CRL check with revoked status", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(3, false, true) + + revocationClient, err := NewWithOptions(Options{ + CRLHTTPClient: &http.Client{ + Timeout: 5 * time.Second, + Transport: &crlRoundTripper{ + CertChain: chain, + Revoked: true, + }, + }, + OCSPHTTPClient: &http.Client{}, + CertChainPurpose: purpose.CodeSigning, + }) + if err != nil { + t.Errorf("Expected successful creation of revocation, but received error: %v", err) + } + + certResults, err := revocationClient.ValidateContext(context.Background(), ValidateContextOptions{ + CertChain: []*x509.Certificate{ + chain[0].Cert, // leaf + chain[1].Cert, // intermediate + chain[2].Cert, // root + }, + AuthenticSigningTime: time.Now(), + }) + if err != nil { + t.Errorf("Expected CheckStatus to succeed, but got error: %v", err) + } + + expectedCertResults := []*result.CertRevocationResult{ + { + Result: result.ResultRevoked, + ServerResults: []*result.ServerResult{ + { + Result: result.ResultRevoked, + Server: "http://example.com/chain_crl/0", + }, + }, + RevocationMethod: result.RevocationMethodCRL, + }, + { + Result: result.ResultRevoked, + ServerResults: []*result.ServerResult{ + { + Result: result.ResultRevoked, + Server: "http://example.com/chain_crl/1", + }, + }, + RevocationMethod: result.RevocationMethodCRL, + }, + getRootCertResult(), + } + + validateEquivalentCertResults(certResults, expectedCertResults, t) + }) + + t.Run("OCSP fallback to CRL", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(3, true, true) + + revocationClient, err := NewWithOptions(Options{ + CRLHTTPClient: &http.Client{ + Timeout: 5 * time.Second, + Transport: &crlRoundTripper{ + CertChain: chain, + Revoked: true, + FailOCSP: true, + }, + }, + OCSPHTTPClient: &http.Client{}, + CertChainPurpose: purpose.CodeSigning, + }) + if err != nil { + t.Errorf("Expected successful creation of revocation, but received error: %v", err) + } + + certResults, err := revocationClient.ValidateContext(context.Background(), ValidateContextOptions{ + CertChain: []*x509.Certificate{ + chain[0].Cert, // leaf + chain[1].Cert, // intermediate + chain[2].Cert, // root + }, + AuthenticSigningTime: time.Now(), + }) + if err != nil { + t.Errorf("Expected CheckStatus to succeed, but got error: %v", err) + } + + expectedCertResults := []*result.CertRevocationResult{ + { + Result: result.ResultRevoked, + ServerResults: []*result.ServerResult{ + { + Result: result.ResultUnknown, + Server: "http://example.com/chain_ocsp/0", + Error: errors.New("failed to retrieve OCSP: response had status code 500"), + RevocationMethod: result.RevocationMethodOCSP, + }, + { + Result: result.ResultRevoked, + Server: "http://example.com/chain_crl/0", + RevocationMethod: result.RevocationMethodCRL, + }, + }, + RevocationMethod: result.RevocationMethodOCSPFallbackCRL, + }, + { + Result: result.ResultRevoked, + ServerResults: []*result.ServerResult{ + { + Result: result.ResultUnknown, + Server: "http://example.com/chain_ocsp/1", + Error: errors.New("failed to retrieve OCSP: response had status code 500"), + RevocationMethod: result.RevocationMethodOCSPFallbackCRL, + }, + { + Result: result.ResultRevoked, + Server: "http://example.com/chain_crl/1", + RevocationMethod: result.RevocationMethodCRL, + }, + }, + RevocationMethod: result.RevocationMethodOCSPFallbackCRL, + }, + getRootCertResult(), + } + + validateEquivalentCertResults(certResults, expectedCertResults, t) + }) +} + +func TestPanicHandling(t *testing.T) { + t.Run("panic in OCSP", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(2, true, false) + client := &http.Client{ + Transport: panicTransport{}, + } + + r, err := NewWithOptions(Options{ + OCSPHTTPClient: client, + CRLHTTPClient: client, + CertChainPurpose: purpose.CodeSigning, + }) + if err != nil { + t.Errorf("Expected successful creation of revocation, but received error: %v", err) + } + + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic, but got nil") + } + }() + _, _ = r.ValidateContext(context.Background(), ValidateContextOptions{ + CertChain: []*x509.Certificate{chain[0].Cert, chain[1].Cert}, + AuthenticSigningTime: time.Now(), + }) + + }) + + t.Run("panic in CRL", func(t *testing.T) { + chain := testhelper.GetRevokableRSAChainWithRevocations(2, false, true) + client := &http.Client{ + Transport: panicTransport{}, + } + + r, err := NewWithOptions(Options{ + OCSPHTTPClient: client, + CRLHTTPClient: client, + CertChainPurpose: purpose.CodeSigning, + }) + if err != nil { + t.Errorf("Expected successful creation of revocation, but received error: %v", err) + } + + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic, but got nil") + } + }() + _, _ = r.ValidateContext(context.Background(), ValidateContextOptions{ + CertChain: []*x509.Certificate{chain[0].Cert, chain[1].Cert}, + AuthenticSigningTime: time.Now(), + }) + }) +} + +type crlRoundTripper struct { + CertChain []testhelper.RSACertTuple + Revoked bool + FailOCSP bool +} + +func (rt *crlRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // e.g. ocsp URL: http://example.com/chain_ocsp/0 + // e.g. crl URL: http://example.com/chain_crl/0 + parts := strings.Split(req.URL.Path, "/") + + isOCSP := parts[len(parts)-2] == "chain_ocsp" + // fail OCSP + if rt.FailOCSP && isOCSP { + return nil, errors.New("OCSP failed") + } + + // choose the cert suffix based on suffix of request url + // e.g. http://example.com/chain_crl/0 -> 0 + i, err := strconv.Atoi(parts[len(parts)-1]) + if err != nil { + return nil, err + } + if i >= len(rt.CertChain) { + return nil, errors.New("invalid index") + } + + cert := rt.CertChain[i].Cert + crl := &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(20240720), + } + + if rt.Revoked { + crl.RevokedCertificateEntries = []x509.RevocationListEntry{ + { + SerialNumber: cert.SerialNumber, + RevocationTime: time.Now().Add(-time.Hour), + }, + } + } + + issuerCert := rt.CertChain[i+1].Cert + issuerKey := rt.CertChain[i+1].PrivateKey + crlBytes, err := x509.CreateRevocationList(rand.Reader, crl, issuerCert, issuerKey) + if err != nil { + return nil, err + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(crlBytes)), + }, nil +} + +type panicTransport struct{} + +func (t panicTransport) RoundTrip(req *http.Request) (*http.Response, error) { + panic("panic") +} + func TestValidateContext(t *testing.T) { r, err := NewWithOptions(Options{ - OCSPHTTPClient: &http.Client{}, + OCSPHTTPClient: &http.Client{}, + CertChainPurpose: purpose.CodeSigning, }) if err != nil { t.Fatal(err) diff --git a/testhelper/certificatetest.go b/testhelper/certificatetest.go index 54a31c0e..a06e5035 100644 --- a/testhelper/certificatetest.go +++ b/testhelper/certificatetest.go @@ -75,16 +75,31 @@ func GetRevokableRSALeafCertificate() RSACertTuple { return revokableRSALeaf } +// GetRevokableRSAChainWithRevocations returns a certificate chain with OCSP +// and CRL enabled for revocation checks. +func GetRevokableRSAChainWithRevocations(size int, enabledOCSP, enabledCRL bool) []RSACertTuple { + setupCertificates() + chain := make([]RSACertTuple, size) + chain[size-1] = getRevokableRSARootChainCertTuple("Notation Test Revokable RSA Chain Cert Root", size-1, enabledCRL) + for i := size - 2; i > 0; i-- { + chain[i] = getRevokableRSAChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size-i), &chain[i+1], i, enabledOCSP, enabledCRL) + } + if size > 1 { + chain[0] = getRevokableRSALeafChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size), &chain[1], 0, true, false, enabledOCSP, enabledCRL) + } + return chain +} + // GetRevokableRSAChain returns a chain of certificates that specify a local OCSP server signed using RSA algorithm func GetRevokableRSAChain(size int) []RSACertTuple { setupCertificates() chain := make([]RSACertTuple, size) - chain[size-1] = getRevokableRSARootChainCertTuple("Notation Test Revokable RSA Chain Cert Root", size-1) + chain[size-1] = getRevokableRSARootChainCertTuple("Notation Test Revokable RSA Chain Cert Root", size-1, false) for i := size - 2; i > 0; i-- { - chain[i] = getRevokableRSAChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size-i), &chain[i+1], i) + chain[i] = getRevokableRSAChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size-i), &chain[i+1], i, true, false) } if size > 1 { - chain[0] = getRevokableRSALeafChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size), &chain[1], 0, true, false) + chain[0] = getRevokableRSALeafChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size), &chain[1], 0, true, false, true, false) } return chain } @@ -94,12 +109,12 @@ func GetRevokableRSAChain(size int) []RSACertTuple { func GetRevokableRSATimestampChain(size int) []RSACertTuple { setupCertificates() chain := make([]RSACertTuple, size) - chain[size-1] = getRevokableRSARootChainCertTuple("Notation Test Revokable RSA Chain Cert Root", size-1) + chain[size-1] = getRevokableRSARootChainCertTuple("Notation Test Revokable RSA Chain Cert Root", size-1, false) for i := size - 2; i > 0; i-- { - chain[i] = getRevokableRSAChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size-i), &chain[i+1], i) + chain[i] = getRevokableRSAChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size-i), &chain[i+1], i, true, false) } if size > 1 { - chain[0] = getRevokableRSALeafChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size), &chain[1], 0, false, true) + chain[0] = getRevokableRSALeafChainCertTuple(fmt.Sprintf("Notation Test Revokable RSA Chain Cert %d", size), &chain[1], 0, false, true, true, false) } return chain } @@ -171,31 +186,45 @@ func getRevokableRSACertTuple(cn string, issuer *RSACertTuple) RSACertTuple { return getRSACertTupleWithTemplate(template, issuer.PrivateKey, issuer) } -func getRevokableRSAChainCertTuple(cn string, previous *RSACertTuple, index int) RSACertTuple { +func getRevokableRSAChainCertTuple(cn string, previous *RSACertTuple, index int, enabledOCSP, enabledCRL bool) RSACertTuple { template := getCertTemplate(previous == nil, true, false, cn) template.BasicConstraintsValid = true template.IsCA = true template.KeyUsage = x509.KeyUsageCertSign - template.OCSPServer = []string{fmt.Sprintf("http://example.com/chain_ocsp/%d", index)} + if enabledOCSP { + template.OCSPServer = []string{fmt.Sprintf("http://example.com/chain_ocsp/%d", index)} + } + if enabledCRL { + template.KeyUsage |= x509.KeyUsageCRLSign + template.CRLDistributionPoints = []string{fmt.Sprintf("http://example.com/chain_crl/%d", index)} + } return getRSACertTupleWithTemplate(template, previous.PrivateKey, previous) } -func getRevokableRSARootChainCertTuple(cn string, pathLen int) RSACertTuple { +func getRevokableRSARootChainCertTuple(cn string, pathLen int, enabledCRL bool) RSACertTuple { pk, _ := rsa.GenerateKey(rand.Reader, 3072) template := getCertTemplate(true, true, false, cn) template.BasicConstraintsValid = true template.IsCA = true template.KeyUsage = x509.KeyUsageCertSign + if enabledCRL { + template.KeyUsage |= x509.KeyUsageCRLSign + } template.MaxPathLen = pathLen return getRSACertTupleWithTemplate(template, pk, nil) } -func getRevokableRSALeafChainCertTuple(cn string, issuer *RSACertTuple, index int, codesign, timestamp bool) RSACertTuple { +func getRevokableRSALeafChainCertTuple(cn string, issuer *RSACertTuple, index int, codesign, timestamp, enabledOCSP, enabledCRL bool) RSACertTuple { template := getCertTemplate(false, codesign, timestamp, cn) template.BasicConstraintsValid = true template.IsCA = false template.KeyUsage = x509.KeyUsageDigitalSignature - template.OCSPServer = []string{fmt.Sprintf("http://example.com/chain_ocsp/%d", index)} + if enabledOCSP { + template.OCSPServer = []string{fmt.Sprintf("http://example.com/chain_ocsp/%d", index)} + } + if enabledCRL { + template.CRLDistributionPoints = []string{fmt.Sprintf("http://example.com/chain_crl/%d", index)} + } return getRSACertTupleWithTemplate(template, issuer.PrivateKey, issuer) }