Skip to content

Commit

Permalink
coordinator: move userapi into authority (2)
Browse files Browse the repository at this point in the history
Makes the necessary adjustments to userapi and recoveryapi after moving
them from main to the authority package.
  • Loading branch information
burgerdev committed Jun 19, 2024
1 parent 20bbb81 commit cf571cd
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 488 deletions.
34 changes: 17 additions & 17 deletions coordinator/internal/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"github.com/edgelesssys/contrast/internal/ca"
"github.com/edgelesssys/contrast/internal/crypto"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/recoveryapi"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/kds"
"github.com/google/go-sev-guest/proto/sevsnp"
Expand Down Expand Up @@ -55,6 +57,9 @@ type Authority struct {
bundlesMux sync.RWMutex
logger *slog.Logger
metrics metrics

userapi.UnimplementedUserAPIServer
recoveryapi.UnimplementedRecoveryAPIServer
}

type metrics struct {
Expand Down Expand Up @@ -187,10 +192,10 @@ func (m *Authority) GetCertBundle(peerPublicKeyHashStr string) (Bundle, error) {
return bundle, nil
}

// GetManifestsAndLatestCA retrieves the manifest history and the currently active CA instance.
// getManifestsAndLatestCA retrieves the manifest history and the currently active CA instance.
//
// If no manifest is configured, it returns an error.
func (m *Authority) GetManifestsAndLatestCA() ([]*manifest.Manifest, *ca.CA, error) {
func (m *Authority) getManifestsAndLatestCA() ([]*manifest.Manifest, *ca.CA, error) {
if m.se.Load() == nil {
return nil, nil, ErrNoManifest
}
Expand Down Expand Up @@ -223,8 +228,8 @@ func (m *Authority) GetManifestsAndLatestCA() ([]*manifest.Manifest, *ca.CA, err
return manifests, state.ca, nil
}

// SetManifest updates the active manifest.
func (m *Authority) SetManifest(manifestBytes []byte, policies [][]byte) (*ca.CA, error) {
// setManifest updates the active manifest.
func (m *Authority) setManifest(manifestBytes []byte, policies [][]byte) (*ca.CA, error) {
if err := m.createSeedEngine(); err != nil {
return nil, fmt.Errorf("creating SeedEngine: %w", err)
}
Expand Down Expand Up @@ -314,8 +319,8 @@ func (m *Authority) SetManifest(manifestBytes []byte, policies [][]byte) (*ca.CA
return ca, nil
}

// LatestManifest retrieves the active manifest.
func (m *Authority) LatestManifest() (*manifest.Manifest, error) {
// latestManifest retrieves the active manifest.
func (m *Authority) latestManifest() (*manifest.Manifest, error) {
if m.se.Load() == nil {
return nil, ErrNoManifest
}
Expand All @@ -329,19 +334,14 @@ func (m *Authority) LatestManifest() (*manifest.Manifest, error) {
return c.manifest, nil
}

// Recoverable returns whether the Authority can be recovered from a persisted state.
func (m *Authority) Recoverable() (bool, error) {
return m.hist.HasLatest()
}

// Recover recovers the seed engine from a seed and salt.
func (m *Authority) Recover(seed, salt []byte) error {
// recover recovers the seed engine from a seed and salt.
func (m *Authority) recover(seed, salt []byte) error {
seedEngine, err := seedengine.New(seed, salt)
if err != nil {
return fmt.Errorf("creating seed engine: %w", err)
}
if !m.se.CompareAndSwap(nil, seedEngine) {
return errors.New("already recovered")
return ErrAlreadyRecovered
}
m.hist.ConfigureSigningKey(m.se.Load().TransactionSigningKey())
return nil
Expand All @@ -365,10 +365,10 @@ func (m *Authority) createSeedEngine() error {
if err != nil {
return fmt.Errorf("creating seed engine: %w", err)
}
// It's fine if the seedEngine has already been created by another thread.
m.se.CompareAndSwap(nil, seedEngine)

m.hist.ConfigureSigningKey(m.se.Load().TransactionSigningKey())
if m.se.CompareAndSwap(nil, seedEngine) {
m.hist.ConfigureSigningKey(seedEngine.TransactionSigningKey())
}
return nil
}

Expand Down
30 changes: 15 additions & 15 deletions coordinator/internal/authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ func TestEmptyAuthority(t *testing.T) {
a, reg := newAuthority(t)

// A fresh authority does not have a signing key, so this should fail.
manifests, ca, err := a.GetManifestsAndLatestCA()
manifests, ca, err := a.getManifestsAndLatestCA()
assert.Error(t, err)
assert.Nil(t, ca)
assert.Empty(t, manifests)

manifest, err := a.LatestManifest()
manifest, err := a.latestManifest()
assert.Error(t, err)
assert.Nil(t, manifest)

Expand All @@ -52,16 +52,16 @@ func TestSetManifest(t *testing.T) {
a, reg := newAuthority(t)
expected, mnfstBytes, policies := newManifest(t)

mnfst, err := a.LatestManifest()
mnfst, err := a.latestManifest()
require.ErrorIs(err, ErrNoManifest)
require.Nil(mnfst)

ca, err := a.SetManifest(mnfstBytes, policies)
ca, err := a.setManifest(mnfstBytes, policies)
require.NoError(err)
require.NotNil(ca)
requireGauge(t, reg, 1)

actual, err := a.LatestManifest()
actual, err := a.latestManifest()
require.NoError(err)
require.NotNil(actual)

Expand All @@ -70,7 +70,7 @@ func TestSetManifest(t *testing.T) {
// Simulate manifest updates that this instance is not aware of by deleting its state.
a.state.Store(nil)

_, err = a.SetManifest(mnfstBytes, policies)
_, err = a.setManifest(mnfstBytes, policies)
require.NoError(err)
requireGauge(t, reg, 2)
}
Expand All @@ -80,7 +80,7 @@ func TestSetManifest_TooFewPolicies(t *testing.T) {
a, reg := newAuthority(t)
_, mnfstBytes, _ := newManifest(t)

ca, err := a.SetManifest(mnfstBytes, [][]byte{})
ca, err := a.setManifest(mnfstBytes, [][]byte{})
require.Error(err)
require.Nil(ca)
requireGauge(t, reg, 0)
Expand All @@ -91,23 +91,23 @@ func TestSetManifest_BadManifest(t *testing.T) {
a, reg := newAuthority(t)
_, _, policies := newManifest(t)

ca, err := a.SetManifest([]byte(`{ "policies": 1 }`), policies)
ca, err := a.setManifest([]byte(`{ "policies": 1 }`), policies)
require.Error(err)
require.Nil(ca)
requireGauge(t, reg, 0)
}

func TestGetManifests(t *testing.T) {
func TestGetManifestsAndLatestCA(t *testing.T) {
require := require.New(t)
a, reg := newAuthority(t)
originalManifest, mnfstBytes, policies := newManifest(t)

manifests, ca, err := a.GetManifestsAndLatestCA()
manifests, ca, err := a.getManifestsAndLatestCA()
require.ErrorIs(err, ErrNoManifest)
require.Empty(manifests)
require.Nil(ca)

oldCA, err := a.SetManifest(mnfstBytes, policies)
oldCA, err := a.setManifest(mnfstBytes, policies)
require.NoError(err)
require.NotNil(oldCA)
requireGauge(t, reg, 1)
Expand All @@ -117,7 +117,7 @@ func TestGetManifests(t *testing.T) {
alteredManifestBytes, err := json.Marshal(alteredManifest)
require.NoError(err)

expectedCA, err := a.SetManifest(alteredManifestBytes, policies)
expectedCA, err := a.setManifest(alteredManifestBytes, policies)
require.NoError(err)
require.NotNil(expectedCA)
requireGauge(t, reg, 2)
Expand All @@ -126,7 +126,7 @@ func TestGetManifests(t *testing.T) {

expectedManifests := []*manifest.Manifest{originalManifest, &alteredManifest}

manifests, currentCA, err := a.GetManifestsAndLatestCA()
manifests, currentCA, err := a.getManifestsAndLatestCA()
require.NoError(err)
require.Equal(expectedCA.GetMeshCACert(), currentCA.GetMeshCACert())
require.Equal(expectedCA.GetRootCACert(), currentCA.GetRootCACert())
Expand All @@ -135,7 +135,7 @@ func TestGetManifests(t *testing.T) {
// Simulate manifest updates that this instance is not aware of by deleting its state.
a.state.Store(nil)

manifests, _, err = a.GetManifestsAndLatestCA()
manifests, _, err = a.getManifestsAndLatestCA()
require.NoError(err)
require.Equal(expectedManifests, manifests)
requireGauge(t, reg, len(expectedManifests))
Expand All @@ -152,7 +152,7 @@ func TestSNPValidateOpts(t *testing.T) {
require.Error(err)
require.Nil(opts)

_, err = a.SetManifest(mnfstBytes, policies)
_, err = a.setManifest(mnfstBytes, policies)
require.NoError(err)

opts, err = a.SNPValidateOpts(report)
Expand Down
91 changes: 18 additions & 73 deletions coordinator/internal/authority/recoveryapi.go
Original file line number Diff line number Diff line change
@@ -1,88 +1,33 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package main
package authority

import (
"context"
"fmt"
"log/slog"
"net"
"time"
"errors"

"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/recoveryapi"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type recoveryAPIServer struct {
grpc *grpc.Server
logger *slog.Logger
recoverable recoverable
recoveryDoneC chan struct{}
// ErrAlreadyRecovered is returned if recovery was requested but a seed is already set.
var ErrAlreadyRecovered = errors.New("coordinator is already recovered")

recoveryapi.UnimplementedRecoveryAPIServer
}

func newRecoveryAPIServer(recoveryTarget recoverable, reg *prometheus.Registry, log *slog.Logger) *recoveryAPIServer {
issuer := snp.NewIssuer(logger.NewNamed(log, "snp-issuer"))
credentials := atlscredentials.New(issuer, nil)
// Recover recovers the Coordinator from a seed and salt.
func (a *Authority) Recover(_ context.Context, req *recoveryapi.RecoverRequest) (*recoveryapi.RecoverResponse, error) {
a.logger.Info("Recover called")

grpcUserAPIMetrics := grpcprometheus.NewServerMetrics(
grpcprometheus.WithServerCounterOptions(
grpcprometheus.WithSubsystem("contrast_recoveryapi"),
),
grpcprometheus.WithServerHandlingTimeHistogram(
grpcprometheus.WithHistogramSubsystem("contrast_recoveryapi"),
grpcprometheus.WithHistogramBuckets([]float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5}),
),
)

grpcServer := grpc.NewServer(
grpc.Creds(credentials),
grpc.KeepaliveParams(keepalive.ServerParameters{Time: 15 * time.Second}),
grpc.ChainStreamInterceptor(
grpcUserAPIMetrics.StreamServerInterceptor(),
),
grpc.ChainUnaryInterceptor(
grpcUserAPIMetrics.UnaryServerInterceptor(),
),
)
s := &recoveryAPIServer{
grpc: grpcServer,
logger: log.WithGroup("recoveryapi"),
recoverable: recoveryTarget,
recoveryDoneC: make(chan struct{}),
}
recoveryapi.RegisterRecoveryAPIServer(s.grpc, s)

grpcUserAPIMetrics.InitializeMetrics(grpcServer)
reg.MustRegister(grpcUserAPIMetrics)

return s
}
err := a.recover(req.Seed, req.Salt)
switch {
case errors.Is(err, ErrAlreadyRecovered):
return nil, status.Error(codes.FailedPrecondition, err.Error())
case err == nil:
return &recoveryapi.RecoverResponse{}, nil
default:
// Pretty sure this failed because the seed was bad.
return nil, status.Errorf(codes.InvalidArgument, err.Error())

func (s *recoveryAPIServer) Serve(endpoint string) error {
lis, err := net.Listen("tcp", endpoint)
if err != nil {
return fmt.Errorf("listening on %s: %w", endpoint, err)
}
return s.grpc.Serve(lis)
}

func (s *recoveryAPIServer) WaitRecoveryDone() {
<-s.recoveryDoneC
}

func (s *recoveryAPIServer) Recover(_ context.Context, req *recoveryapi.RecoverRequest) (*recoveryapi.RecoverResponse, error) {
return &recoveryapi.RecoverResponse{}, s.recoverable.Recover(req.Seed, req.Salt)
}

type recoverable interface {
Recover(seed, salt []byte) error
}
Loading

0 comments on commit cf571cd

Please sign in to comment.