Skip to content

Commit

Permalink
Merge pull request #259 from tri-adam/expose-ctx
Browse files Browse the repository at this point in the history
Expose Context in Sign/Verify
  • Loading branch information
tri-adam authored Feb 24, 2023
2 parents 7f906be + 1a4b1df commit 6394eaa
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 38 deletions.
7 changes: 4 additions & 3 deletions pkg/integrity/clearsign.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved.
// Copyright (c) 2020-2023, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file
// distributed with the sources of this project regarding your rights to use or distribute this
// software.
Expand All @@ -7,6 +7,7 @@ package integrity

import (
"bytes"
"context"
"crypto"
"errors"
"io"
Expand Down Expand Up @@ -37,7 +38,7 @@ func newClearsignEncoder(e *openpgp.Entity, timeFunc func() time.Time) *clearsig

// signMessage signs the message from r in clear-sign format, and writes the result to w. On
// success, the hash function is returned.
func (en *clearsignEncoder) signMessage(w io.Writer, r io.Reader) (crypto.Hash, error) {
func (en *clearsignEncoder) signMessage(ctx context.Context, w io.Writer, r io.Reader) (crypto.Hash, error) {
plaintext, err := clearsign.Encode(w, en.e.PrivateKey, en.config)
if err != nil {
return 0, err
Expand All @@ -62,7 +63,7 @@ func newClearsignDecoder(kr openpgp.KeyRing) *clearsignDecoder {

// verifyMessage reads a message from r, verifies its signature, and returns the message contents.
// On success, the signing entity is set in vr.
func (de *clearsignDecoder) verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) {
func (de *clearsignDecoder) verifyMessage(ctx context.Context, r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) { //nolint:lll
data, err := io.ReadAll(r)
if err != nil {
return nil, err
Expand Down
7 changes: 4 additions & 3 deletions pkg/integrity/clearsign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package integrity
import (
"bufio"
"bytes"
"context"
"crypto"
"errors"
"io"
Expand Down Expand Up @@ -53,7 +54,7 @@ func Test_clearsignEncoder_signMessage(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
b := bytes.Buffer{}

ht, err := tt.en.signMessage(&b, strings.NewReader(testMessage))
ht, err := tt.en.signMessage(context.Background(), &b, strings.NewReader(testMessage))
if got, want := err, tt.wantErr; (got != nil) != want {
t.Fatalf("got error %v, wantErr %v", got, want)
}
Expand Down Expand Up @@ -173,7 +174,7 @@ func Test_clearsignDecoder_verifyMessage(t *testing.T) {
Time: fixedTime,
},
}
h, err := en.signMessage(&b, strings.NewReader(testMessage))
h, err := en.signMessage(context.Background(), &b, strings.NewReader(testMessage))
if err != nil {
t.Fatal(err)
}
Expand All @@ -189,7 +190,7 @@ func Test_clearsignDecoder_verifyMessage(t *testing.T) {

// Decode and verify message.
var vr VerifyResult
message, err := tt.de.verifyMessage(bytes.NewReader(b.Bytes()), h, &vr)
message, err := tt.de.verifyMessage(context.Background(), bytes.NewReader(b.Bytes()), h, &vr)

if got, want := err, tt.wantErr; !errors.Is(got, want) {
t.Fatalf("got error %v, want %v", got, want)
Expand Down
8 changes: 4 additions & 4 deletions pkg/integrity/dsse.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ func newDSSEEncoder(ss []signature.Signer, opts ...signature.SignOption) (*dsseE

// signMessage signs the message from r in DSSE format, and writes the result to w. On success, the
// hash function is returned.
func (en *dsseEncoder) signMessage(w io.Writer, r io.Reader) (crypto.Hash, error) {
func (en *dsseEncoder) signMessage(ctx context.Context, w io.Writer, r io.Reader) (crypto.Hash, error) {
body, err := io.ReadAll(r)
if err != nil {
return 0, err
}

e, err := en.es.SignPayload(context.TODO(), en.payloadType, body)
e, err := en.es.SignPayload(ctx, en.payloadType, body)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -102,7 +102,7 @@ var (

// verifyMessage reads a message from r, verifies its signature(s), and returns the message
// contents. On success, the accepted public keys are set in vr.
func (de *dsseDecoder) verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) {
func (de *dsseDecoder) verifyMessage(ctx context.Context, r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error) { //nolint:lll
vs := make([]dsse.Verifier, 0, len(de.vs))
for _, v := range de.vs {
dv, err := newDSSEVerifier(v, options.WithCryptoSignerOpts(h))
Expand All @@ -123,7 +123,7 @@ func (de *dsseDecoder) verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResul
return nil, err
}

vr.aks, err = v.Verify(context.TODO(), &e)
vr.aks, err = v.Verify(ctx, &e)
if err != nil {
//nolint:errorlint // Go 1.19 compatibility
return nil, fmt.Errorf("%w: %v", errDSSEVerifyEnvelopeFailed, err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/integrity/dsse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func Test_dsseEncoder_signMessage(t *testing.T) {
t.Fatal(err)
}

ht, err := en.signMessage(&b, strings.NewReader(testMessage))
ht, err := en.signMessage(context.Background(), &b, strings.NewReader(testMessage))
if got, want := err, tt.wantErr; (got != nil) != want {
t.Fatalf("got error %v, wantErr %v", got, want)
}
Expand Down Expand Up @@ -323,7 +323,7 @@ func Test_dsseDecoder_verifyMessage(t *testing.T) {
}

// Sign and encode message.
h, err := en.signMessage(&b, strings.NewReader(testMessage))
h, err := en.signMessage(context.Background(), &b, strings.NewReader(testMessage))
if err != nil {
t.Fatal(err)
}
Expand All @@ -345,7 +345,7 @@ func Test_dsseDecoder_verifyMessage(t *testing.T) {

// Decode and verify message.
var vr VerifyResult
message, err := tt.de.verifyMessage(bytes.NewReader(b.Bytes()), h, &vr)
message, err := tt.de.verifyMessage(context.Background(), bytes.NewReader(b.Bytes()), h, &vr)

if got, want := err, tt.wantErr; !errors.Is(got, want) {
t.Errorf("got error %v, want %v", got, want)
Expand Down
21 changes: 16 additions & 5 deletions pkg/integrity/sign.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved.
// Copyright (c) 2020-2023, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file
// distributed with the sources of this project regarding your rights to use or distribute this
// software.
Expand All @@ -7,6 +7,7 @@ package integrity

import (
"bytes"
"context"
"crypto"
"encoding/json"
"errors"
Expand All @@ -32,7 +33,7 @@ var ErrNoKeyMaterial = errors.New("key material not provided")
type encoder interface {
// signMessage signs the message from r, and writes the result to w. On success, the signature
// hash function is returned.
signMessage(w io.Writer, r io.Reader) (ht crypto.Hash, err error)
signMessage(ctx context.Context, w io.Writer, r io.Reader) (ht crypto.Hash, err error)
}

type groupSigner struct {
Expand Down Expand Up @@ -149,7 +150,7 @@ func (gs *groupSigner) addObject(od sif.Descriptor) error {
}

// sign creates a digital signature as specified by gs.
func (gs *groupSigner) sign() (sif.DescriptorInput, error) {
func (gs *groupSigner) sign(ctx context.Context) (sif.DescriptorInput, error) {
// Get minimum object ID in group. Object IDs in the image metadata will be relative to this.
minID, err := getGroupMinObjectID(gs.f, gs.id)
if err != nil {
Expand All @@ -170,7 +171,7 @@ func (gs *groupSigner) sign() (sif.DescriptorInput, error) {

// Sign image metadata.
b := bytes.Buffer{}
ht, err := gs.en.signMessage(&b, bytes.NewReader(enc))
ht, err := gs.en.signMessage(ctx, &b, bytes.NewReader(enc))
if err != nil {
return sif.DescriptorInput{}, fmt.Errorf("failed to sign message: %w", err)
}
Expand All @@ -190,6 +191,7 @@ type signOpts struct {
objectIDs [][]uint32
timeFunc func() time.Time
deterministic bool
ctx context.Context //nolint:containedctx
}

// SignerOpt are used to configure so.
Expand Down Expand Up @@ -253,6 +255,14 @@ func OptSignDeterministic() SignerOpt {
}
}

// OptSignWithContext specifies that the given context should be used in RPC to external services.
func OptSignWithContext(ctx context.Context) SignerOpt {
return func(so *signOpts) error {
so.ctx = ctx
return nil
}
}

// withGroupedObjects splits the objects represented by ids into object groups, and calls fn once
// per object group.
func withGroupedObjects(f *sif.FileImage, ids []uint32, fn func(uint32, []uint32) error) error {
Expand Down Expand Up @@ -309,6 +319,7 @@ func NewSigner(f *sif.FileImage, opts ...SignerOpt) (*Signer, error) {

so := signOpts{
timeFunc: time.Now,
ctx: context.Background(),
}

// Apply options.
Expand Down Expand Up @@ -391,7 +402,7 @@ func NewSigner(f *sif.FileImage, opts ...SignerOpt) (*Signer, error) {
// Sign adds digital signatures as specified by s.
func (s *Signer) Sign() error {
for _, gs := range s.signers {
di, err := gs.sign()
di, err := gs.sign(s.opts.ctx)
if err != nil {
return fmt.Errorf("integrity: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/integrity/sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package integrity

import (
"bytes"
"context"
"crypto"
"errors"
"os"
Expand Down Expand Up @@ -341,7 +342,7 @@ func TestGroupSigner_Sign(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
di, err := tt.gs.sign()
di, err := tt.gs.sign(context.Background())
if (err != nil) != tt.wantErr {
t.Fatalf("got error %v, want %v", err, tt.wantErr)
}
Expand Down
43 changes: 28 additions & 15 deletions pkg/integrity/verify.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, Sylabs Inc. All rights reserved.
// Copyright (c) 2020-2023, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file
// distributed with the sources of this project regarding your rights to use or distribute this
// software.
Expand All @@ -7,6 +7,7 @@ package integrity

import (
"bytes"
"context"
"crypto"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -110,14 +111,14 @@ func (v *groupVerifier) signatures() ([]sif.Descriptor, error) {
// If verification of the SIF global header fails, ErrHeaderIntegrity is returned. If verification
// of a data object descriptor fails, a DescriptorIntegrityError is returned. If verification of a
// data object fails, a ObjectIntegrityError is returned.
func (v *groupVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error {
func (v *groupVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error {
ht, fp, err := sig.SignatureMetadata()
if err != nil {
return err
}

// Verify signature and decode message.
b, err := de.verifyMessage(sig.GetReader(), ht, vr)
b, err := de.verifyMessage(ctx, sig.GetReader(), ht, vr)
if err != nil {
return &SignatureNotValidError{ID: sig.ID(), Err: err}
}
Expand Down Expand Up @@ -181,9 +182,9 @@ func (v *legacyGroupVerifier) signatures() ([]sif.Descriptor, error) {
// If an invalid signature is encountered, a SignatureNotValidError is returned.
//
// If verification of a data object fails, a ObjectIntegrityError is returned.
func (v *legacyGroupVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error {
func (v *legacyGroupVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error { //nolint:lll
// Verify signature and decode message.
b, err := de.verifyMessage(sig.GetReader(), crypto.SHA256, vr)
b, err := de.verifyMessage(ctx, sig.GetReader(), crypto.SHA256, vr)
if err != nil {
return &SignatureNotValidError{ID: sig.ID(), Err: err}
}
Expand Down Expand Up @@ -246,9 +247,9 @@ func (v *legacyObjectVerifier) signatures() ([]sif.Descriptor, error) {
// If an invalid signature is encountered, a SignatureNotValidError is returned.
//
// If verification of a data object fails, a ObjectIntegrityError is returned.
func (v *legacyObjectVerifier) verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error {
func (v *legacyObjectVerifier) verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error { //nolint:lll
// Verify signature and decode message.
b, err := de.verifyMessage(sig.GetReader(), crypto.SHA256, vr)
b, err := de.verifyMessage(ctx, sig.GetReader(), crypto.SHA256, vr)
if err != nil {
return &SignatureNotValidError{ID: sig.ID(), Err: err}
}
Expand Down Expand Up @@ -285,7 +286,7 @@ func (v *legacyObjectVerifier) verifySignature(sig sif.Descriptor, de decoder, v
type decoder interface {
// verifyMessage reads a message from r, verifies its signature, and returns the message
// contents.
verifyMessage(r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error)
verifyMessage(ctx context.Context, r io.Reader, h crypto.Hash, vr *VerifyResult) ([]byte, error)
}

type verifyTask interface {
Expand All @@ -301,7 +302,7 @@ type verifyTask interface {
// If verification of the SIF global header fails, ErrHeaderIntegrity is returned. If
// verification of a data object descriptor fails, a DescriptorIntegrityError is returned. If
// verification of a data object fails, a ObjectIntegrityError is returned.
verifySignature(sig sif.Descriptor, de decoder, vr *VerifyResult) error
verifySignature(ctx context.Context, sig sif.Descriptor, de decoder, vr *VerifyResult) error
}

type verifyOpts struct {
Expand All @@ -311,6 +312,7 @@ type verifyOpts struct {
objects []uint32
isLegacy bool
isLegacyAll bool
ctx context.Context //nolint:containedctx
cb VerifyCallback
}

Expand Down Expand Up @@ -385,6 +387,15 @@ func OptVerifyLegacyAll() VerifierOpt {
}
}

// OptVerifyWithContext specifies that the given context should be used in RPC to external
// services.
func OptVerifyWithContext(ctx context.Context) VerifierOpt {
return func(vo *verifyOpts) error {
vo.ctx = ctx
return nil
}
}

// OptVerifyCallback registers cb as the verification callback, which is called after each
// signature is verified.
func OptVerifyCallback(cb VerifyCallback) VerifierOpt {
Expand Down Expand Up @@ -449,10 +460,10 @@ func getLegacyTasks(f *sif.FileImage, groupIDs, objectIDs []uint32) ([]verifyTas
// Verifier describes a SIF image verifier.
type Verifier struct {
f *sif.FileImage
opts verifyOpts
tasks []verifyTask
dsse decoder
cs decoder
cb VerifyCallback
}

// NewVerifier returns a Verifier to examine and/or verify digital signatures(s) in f according to
Expand All @@ -470,7 +481,9 @@ func NewVerifier(f *sif.FileImage, opts ...VerifierOpt) (*Verifier, error) {
return nil, fmt.Errorf("integrity: %w", errNilFileImage)
}

vo := verifyOpts{}
vo := verifyOpts{
ctx: context.Background(),
}

// Apply options.
for _, o := range opts {
Expand Down Expand Up @@ -510,8 +523,8 @@ func NewVerifier(f *sif.FileImage, opts ...VerifierOpt) (*Verifier, error) {

v := Verifier{
f: f,
opts: vo,
tasks: t,
cb: vo.cb,
}

if vo.vs != nil {
Expand Down Expand Up @@ -645,12 +658,12 @@ func (v *Verifier) Verify() error {
vr := VerifyResult{sig: sig}

// Verify signature.
err := t.verifySignature(sig, de, &vr)
err := t.verifySignature(v.opts.ctx, sig, de, &vr)

// Call verify callback, if applicable.
if v.cb != nil {
if v.opts.cb != nil {
vr.err = err
if ignoreError := v.cb(vr); ignoreError {
if ignoreError := v.opts.cb(vr); ignoreError {
err = nil
}
}
Expand Down
Loading

0 comments on commit 6394eaa

Please sign in to comment.