From 1a4b1df187debea5d7b5074acebae36e1fa1a5f9 Mon Sep 17 00:00:00 2001 From: Adam Hughes <9903835+tri-adam@users.noreply.github.com> Date: Thu, 16 Feb 2023 21:47:06 +0000 Subject: [PATCH] feat: add OptSignWithContext/OptVerifyWithContext --- pkg/integrity/clearsign.go | 7 +++--- pkg/integrity/clearsign_test.go | 7 +++--- pkg/integrity/dsse.go | 8 +++--- pkg/integrity/dsse_test.go | 6 ++--- pkg/integrity/sign.go | 21 ++++++++++++---- pkg/integrity/sign_test.go | 3 ++- pkg/integrity/verify.go | 43 +++++++++++++++++++++------------ pkg/integrity/verify_test.go | 9 ++++--- 8 files changed, 66 insertions(+), 38 deletions(-) diff --git a/pkg/integrity/clearsign.go b/pkg/integrity/clearsign.go index f79d3485..c1173ddd 100644 --- a/pkg/integrity/clearsign.go +++ b/pkg/integrity/clearsign.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved. +// Copyright (c) 2020-2023, Sylabs Inc. All rights reserved. // This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file // distributed with the sources of this project regarding your rights to use or distribute this // software. @@ -7,6 +7,7 @@ package integrity import ( "bytes" + "context" "crypto" "errors" "io" @@ -37,7 +38,7 @@ func newClearsignEncoder(e *openpgp.Entity, timeFunc func() time.Time) *clearsig // signMessage signs the message from r in clear-sign format, and writes the result to w. On // success, the hash function is returned. -func (en *clearsignEncoder) signMessage(w io.Writer, r io.Reader) (crypto.Hash, error) { +func (en *clearsignEncoder) signMessage(ctx context.Context, w io.Writer, r io.Reader) (crypto.Hash, error) { plaintext, err := clearsign.Encode(w, en.e.PrivateKey, en.config) if err != nil { return 0, err @@ -62,7 +63,7 @@ func newClearsignDecoder(kr openpgp.KeyRing) *clearsignDecoder { // verifyMessage reads a message from r, verifies its signature, and returns the message contents. // On success, the signing entity is set in vr. -func (de *clearsignDecoder) verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) { +func (de *clearsignDecoder) verifyMessage(ctx context.Context, r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) { //nolint:lll data, err := io.ReadAll(r) if err != nil { return nil, err diff --git a/pkg/integrity/clearsign_test.go b/pkg/integrity/clearsign_test.go index e3c5f29a..fe7188cb 100644 --- a/pkg/integrity/clearsign_test.go +++ b/pkg/integrity/clearsign_test.go @@ -8,6 +8,7 @@ package integrity import ( "bufio" "bytes" + "context" "crypto" "errors" "io" @@ -53,7 +54,7 @@ func Test_clearsignEncoder_signMessage(t *testing.T) { t.Run(tt.name, func(t *testing.T) { b := bytes.Buffer{} - ht, err := tt.en.signMessage(&b, strings.NewReader(testMessage)) + ht, err := tt.en.signMessage(context.Background(), &b, strings.NewReader(testMessage)) if got, want := err, tt.wantErr; (got != nil) != want { t.Fatalf("got error %v, wantErr %v", got, want) } @@ -173,7 +174,7 @@ func Test_clearsignDecoder_verifyMessage(t *testing.T) { Time: fixedTime, }, } - h, err := en.signMessage(&b, strings.NewReader(testMessage)) + h, err := en.signMessage(context.Background(), &b, strings.NewReader(testMessage)) if err != nil { t.Fatal(err) } @@ -189,7 +190,7 @@ func Test_clearsignDecoder_verifyMessage(t *testing.T) { // Decode and verify message. var vr VerifyResult - message, err := tt.de.verifyMessage(bytes.NewReader(b.Bytes()), h, &vr) + message, err := tt.de.verifyMessage(context.Background(), bytes.NewReader(b.Bytes()), h, &vr) if got, want := err, tt.wantErr; !errors.Is(got, want) { t.Fatalf("got error %v, want %v", got, want) diff --git a/pkg/integrity/dsse.go b/pkg/integrity/dsse.go index f0430044..777b2d7a 100644 --- a/pkg/integrity/dsse.go +++ b/pkg/integrity/dsse.go @@ -65,13 +65,13 @@ func newDSSEEncoder(ss []signature.Signer, opts ...signature.SignOption) (*dsseE // signMessage signs the message from r in DSSE format, and writes the result to w. On success, the // hash function is returned. -func (en *dsseEncoder) signMessage(w io.Writer, r io.Reader) (crypto.Hash, error) { +func (en *dsseEncoder) signMessage(ctx context.Context, w io.Writer, r io.Reader) (crypto.Hash, error) { body, err := io.ReadAll(r) if err != nil { return 0, err } - e, err := en.es.SignPayload(context.TODO(), en.payloadType, body) + e, err := en.es.SignPayload(ctx, en.payloadType, body) if err != nil { return 0, err } @@ -102,7 +102,7 @@ var ( // verifyMessage reads a message from r, verifies its signature(s), and returns the message // contents. On success, the accepted public keys are set in vr. -func (de *dsseDecoder) verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) { +func (de *dsseDecoder) verifyMessage(ctx context.Context, r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) { //nolint:lll vs := make([]dsse.Verifier, 0, len(de.vs)) for _, v := range de.vs { dv, err := newDSSEVerifier(v, options.WithCryptoSignerOpts(h)) @@ -123,7 +123,7 @@ func (de *dsseDecoder) verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResul return nil, err } - vr.aks, err = v.Verify(context.TODO(), &e) + vr.aks, err = v.Verify(ctx, &e) if err != nil { //nolint:errorlint // Go 1.19 compatibility return nil, fmt.Errorf("%w: %v", errDSSEVerifyEnvelopeFailed, err) diff --git a/pkg/integrity/dsse_test.go b/pkg/integrity/dsse_test.go index 14462c84..46d43f8c 100644 --- a/pkg/integrity/dsse_test.go +++ b/pkg/integrity/dsse_test.go @@ -87,7 +87,7 @@ func Test_dsseEncoder_signMessage(t *testing.T) { t.Fatal(err) } - ht, err := en.signMessage(&b, strings.NewReader(testMessage)) + ht, err := en.signMessage(context.Background(), &b, strings.NewReader(testMessage)) if got, want := err, tt.wantErr; (got != nil) != want { t.Fatalf("got error %v, wantErr %v", got, want) } @@ -323,7 +323,7 @@ func Test_dsseDecoder_verifyMessage(t *testing.T) { } // Sign and encode message. - h, err := en.signMessage(&b, strings.NewReader(testMessage)) + h, err := en.signMessage(context.Background(), &b, strings.NewReader(testMessage)) if err != nil { t.Fatal(err) } @@ -345,7 +345,7 @@ func Test_dsseDecoder_verifyMessage(t *testing.T) { // Decode and verify message. var vr VerifyResult - message, err := tt.de.verifyMessage(bytes.NewReader(b.Bytes()), h, &vr) + message, err := tt.de.verifyMessage(context.Background(), bytes.NewReader(b.Bytes()), h, &vr) if got, want := err, tt.wantErr; !errors.Is(got, want) { t.Errorf("got error %v, want %v", got, want) diff --git a/pkg/integrity/sign.go b/pkg/integrity/sign.go index 384df489..d38d74f7 100644 --- a/pkg/integrity/sign.go +++ b/pkg/integrity/sign.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved. +// Copyright (c) 2020-2023, Sylabs Inc. All rights reserved. // This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file // distributed with the sources of this project regarding your rights to use or distribute this // software. @@ -7,6 +7,7 @@ package integrity import ( "bytes" + "context" "crypto" "encoding/json" "errors" @@ -32,7 +33,7 @@ var ErrNoKeyMaterial = errors.New("key material not provided") type encoder interface { // signMessage signs the message from r, and writes the result to w. On success, the signature // hash function is returned. - signMessage(w io.Writer, r io.Reader) (ht crypto.Hash, err error) + signMessage(ctx context.Context, w io.Writer, r io.Reader) (ht crypto.Hash, err error) } type groupSigner struct { @@ -149,7 +150,7 @@ func (gs *groupSigner) addObject(od sif.Descriptor) error { } // sign creates a digital signature as specified by gs. -func (gs *groupSigner) sign() (sif.DescriptorInput, error) { +func (gs *groupSigner) sign(ctx context.Context) (sif.DescriptorInput, error) { // Get minimum object ID in group. Object IDs in the image metadata will be relative to this. minID, err := getGroupMinObjectID(gs.f, gs.id) if err != nil { @@ -170,7 +171,7 @@ func (gs *groupSigner) sign() (sif.DescriptorInput, error) { // Sign image metadata. b := bytes.Buffer{} - ht, err := gs.en.signMessage(&b, bytes.NewReader(enc)) + ht, err := gs.en.signMessage(ctx, &b, bytes.NewReader(enc)) if err != nil { return sif.DescriptorInput{}, fmt.Errorf("failed to sign message: %w", err) } @@ -190,6 +191,7 @@ type signOpts struct { objectIDs [][]uint32 timeFunc func() time.Time deterministic bool + ctx context.Context //nolint:containedctx } // SignerOpt are used to configure so. @@ -253,6 +255,14 @@ func OptSignDeterministic() SignerOpt { } } +// OptSignWithContext specifies that the given context should be used in RPC to external services. +func OptSignWithContext(ctx context.Context) SignerOpt { + return func(so *signOpts) error { + so.ctx = ctx + return nil + } +} + // withGroupedObjects splits the objects represented by ids into object groups, and calls fn once // per object group. func withGroupedObjects(f *sif.FileImage, ids []uint32, fn func(uint32, []uint32) error) error { @@ -309,6 +319,7 @@ func NewSigner(f *sif.FileImage, opts ...SignerOpt) (*Signer, error) { so := signOpts{ timeFunc: time.Now, + ctx: context.Background(), } // Apply options. @@ -391,7 +402,7 @@ func NewSigner(f *sif.FileImage, opts ...SignerOpt) (*Signer, error) { // Sign adds digital signatures as specified by s. func (s *Signer) Sign() error { for _, gs := range s.signers { - di, err := gs.sign() + di, err := gs.sign(s.opts.ctx) if err != nil { return fmt.Errorf("integrity: %w", err) } diff --git a/pkg/integrity/sign_test.go b/pkg/integrity/sign_test.go index 782389d3..b1a15f10 100644 --- a/pkg/integrity/sign_test.go +++ b/pkg/integrity/sign_test.go @@ -7,6 +7,7 @@ package integrity import ( "bytes" + "context" "crypto" "errors" "os" @@ -341,7 +342,7 @@ func TestGroupSigner_Sign(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - di, err := tt.gs.sign() + di, err := tt.gs.sign(context.Background()) if (err != nil) != tt.wantErr { t.Fatalf("got error %v, want %v", err, tt.wantErr) } diff --git a/pkg/integrity/verify.go b/pkg/integrity/verify.go index cc882669..e64e6c99 100644 --- a/pkg/integrity/verify.go +++ b/pkg/integrity/verify.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved. +// Copyright (c) 2020-2023, Sylabs Inc. All rights reserved. // This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file // distributed with the sources of this project regarding your rights to use or distribute this // software. @@ -7,6 +7,7 @@ package integrity import ( "bytes" + "context" "crypto" "encoding/hex" "encoding/json" @@ -110,14 +111,14 @@ func (v *groupVerifier) signatures() ([]sif.Descriptor, error) { // If verification of the SIF global header fails, ErrHeaderIntegrity is returned. If verification // of a data object descriptor fails, a DescriptorIntegrityError is returned. If verification of a // data object fails, a ObjectIntegrityError is returned. -func (v *groupVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error { +func (v *groupVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error { ht, fp, err := sig.SignatureMetadata() if err != nil { return err } // Verify signature and decode message. - b, err := de.verifyMessage(sig.GetReader(), ht, vr) + b, err := de.verifyMessage(ctx, sig.GetReader(), ht, vr) if err != nil { return &SignatureNotValidError{ID: sig.ID(), Err: err} } @@ -181,9 +182,9 @@ func (v *legacyGroupVerifier) signatures() ([]sif.Descriptor, error) { // If an invalid signature is encountered, a SignatureNotValidError is returned. // // If verification of a data object fails, a ObjectIntegrityError is returned. -func (v *legacyGroupVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error { +func (v *legacyGroupVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error { //nolint:lll // Verify signature and decode message. - b, err := de.verifyMessage(sig.GetReader(), crypto.SHA256, vr) + b, err := de.verifyMessage(ctx, sig.GetReader(), crypto.SHA256, vr) if err != nil { return &SignatureNotValidError{ID: sig.ID(), Err: err} } @@ -246,9 +247,9 @@ func (v *legacyObjectVerifier) signatures() ([]sif.Descriptor, error) { // If an invalid signature is encountered, a SignatureNotValidError is returned. // // If verification of a data object fails, a ObjectIntegrityError is returned. -func (v *legacyObjectVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error { +func (v *legacyObjectVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error { //nolint:lll // Verify signature and decode message. - b, err := de.verifyMessage(sig.GetReader(), crypto.SHA256, vr) + b, err := de.verifyMessage(ctx, sig.GetReader(), crypto.SHA256, vr) if err != nil { return &SignatureNotValidError{ID: sig.ID(), Err: err} } @@ -285,7 +286,7 @@ func (v *legacyObjectVerifier) verifySignature(sig sif.Descriptor, de decoder, v type decoder interface { // verifyMessage reads a message from r, verifies its signature, and returns the message // contents. - verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) + verifyMessage(ctx context.Context, r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) } type verifyTask interface { @@ -301,7 +302,7 @@ type verifyTask interface { // If verification of the SIF global header fails, ErrHeaderIntegrity is returned. If // verification of a data object descriptor fails, a DescriptorIntegrityError is returned. If // verification of a data object fails, a ObjectIntegrityError is returned. - verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error + verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error } type verifyOpts struct { @@ -311,6 +312,7 @@ type verifyOpts struct { objects []uint32 isLegacy bool isLegacyAll bool + ctx context.Context //nolint:containedctx cb VerifyCallback } @@ -385,6 +387,15 @@ func OptVerifyLegacyAll() VerifierOpt { } } +// OptVerifyWithContext specifies that the given context should be used in RPC to external +// services. +func OptVerifyWithContext(ctx context.Context) VerifierOpt { + return func(vo *verifyOpts) error { + vo.ctx = ctx + return nil + } +} + // OptVerifyCallback registers cb as the verification callback, which is called after each // signature is verified. func OptVerifyCallback(cb VerifyCallback) VerifierOpt { @@ -449,10 +460,10 @@ func getLegacyTasks(f *sif.FileImage, groupIDs, objectIDs []uint32) ([]verifyTas // Verifier describes a SIF image verifier. type Verifier struct { f *sif.FileImage + opts verifyOpts tasks []verifyTask dsse decoder cs decoder - cb VerifyCallback } // NewVerifier returns a Verifier to examine and/or verify digital signatures(s) in f according to @@ -470,7 +481,9 @@ func NewVerifier(f *sif.FileImage, opts ...VerifierOpt) (*Verifier, error) { return nil, fmt.Errorf("integrity: %w", errNilFileImage) } - vo := verifyOpts{} + vo := verifyOpts{ + ctx: context.Background(), + } // Apply options. for _, o := range opts { @@ -510,8 +523,8 @@ func NewVerifier(f *sif.FileImage, opts ...VerifierOpt) (*Verifier, error) { v := Verifier{ f: f, + opts: vo, tasks: t, - cb: vo.cb, } if vo.vs != nil { @@ -645,12 +658,12 @@ func (v *Verifier) Verify() error { vr := VerifyResult{sig: sig} // Verify signature. - err := t.verifySignature(sig, de, &vr) + err := t.verifySignature(v.opts.ctx, sig, de, &vr) // Call verify callback, if applicable. - if v.cb != nil { + if v.opts.cb != nil { vr.err = err - if ignoreError := v.cb(vr); ignoreError { + if ignoreError := v.opts.cb(vr); ignoreError { err = nil } } diff --git a/pkg/integrity/verify_test.go b/pkg/integrity/verify_test.go index 28a041bd..64479448 100644 --- a/pkg/integrity/verify_test.go +++ b/pkg/integrity/verify_test.go @@ -6,6 +6,7 @@ package integrity import ( + "context" "crypto" "errors" "io" @@ -162,7 +163,7 @@ func TestGroupVerifier_verify(t *testing.T) { } var vr VerifyResult - err := v.verifySignature(tt.sig, tt.de, &vr) + err := v.verifySignature(context.Background(), tt.sig, tt.de, &vr) if got, want := err, tt.wantErr; !errors.Is(got, want) { t.Errorf("got error %v, want %v", got, want) @@ -292,7 +293,7 @@ func TestLegacyGroupVerifier_verify(t *testing.T) { } var vr VerifyResult - err = v.verifySignature(tt.sig, tt.de, &vr) + err = v.verifySignature(context.Background(), tt.sig, tt.de, &vr) if got, want := err, tt.wantErr; !errors.Is(got, want) { t.Errorf("got error %v, want %v", got, want) @@ -432,7 +433,7 @@ func TestLegacyObjectVerifier_verify(t *testing.T) { } var vr VerifyResult - err = v.verifySignature(tt.sig, tt.de, &vr) + err = v.verifySignature(context.Background(), tt.sig, tt.de, &vr) if got, want := err, tt.wantErr; !errors.Is(got, want) { t.Errorf("got error %v, want %v", got, want) @@ -717,7 +718,7 @@ func (v mockVerifier) signatures() ([]sif.Descriptor, error) { return v.sigs, v.sigsErr } -func (v mockVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error { +func (v mockVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error { vr.verified = v.verified vr.e = v.e return v.verifyErr