From 27a9bafb57c2ae3090df35cd33ba2e8781862216 Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Wed, 22 Aug 2018 09:29:36 -0600 Subject: [PATCH] sql datastore cleanup - consistent use of transactions. - opens sqlite3 databases in WAL mode to improve concurrency performance - enforces a lock over write operations for sqlite3. - improved error handling. - error wrapping for improved diagnostics. - fixed Configure() to be safe in the presence of concurrent callers and to not overwrite the current database unless the new one opened successfully. - converts tests to test suite - has tests test against database on disk in order to properly exercise behavior (in-memory databases have journaling differences) Signed-off-by: Andrew Harding --- pkg/server/plugin/datastore/sql/migration.go | 30 +- pkg/server/plugin/datastore/sql/postgres.go | 2 +- pkg/server/plugin/datastore/sql/sql.go | 1222 ++++++++++-------- pkg/server/plugin/datastore/sql/sql_test.go | 660 +++++----- pkg/server/plugin/datastore/sql/sqlite.go | 11 +- 5 files changed, 1054 insertions(+), 871 deletions(-) diff --git a/pkg/server/plugin/datastore/sql/migration.go b/pkg/server/plugin/datastore/sql/migration.go index 0c969711f9..54c8f4437b 100644 --- a/pkg/server/plugin/datastore/sql/migration.go +++ b/pkg/server/plugin/datastore/sql/migration.go @@ -15,7 +15,7 @@ const ( func migrateDB(db *gorm.DB) (err error) { isNew := !db.HasTable(&Bundle{}) if err := db.Error; err != nil { - return err + return sqlError.Wrap(err) } if isNew { @@ -23,17 +23,17 @@ func migrateDB(db *gorm.DB) (err error) { } if err := db.AutoMigrate(&Migration{}).Error; err != nil { - return err + return sqlError.Wrap(err) } migration := new(Migration) if err := db.Assign(Migration{}).FirstOrCreate(migration).Error; err != nil { - return err + return sqlError.Wrap(err) } version := migration.Version if version > codeVersion { - err = fmt.Errorf("backwards migration not supported! (current=%d, code=%d)", version, codeVersion) + err = sqlError.New("backwards migration not supported! (current=%d, code=%d)", version, codeVersion) logrus.Error(err) return err } @@ -46,7 +46,7 @@ func migrateDB(db *gorm.DB) (err error) { for version < codeVersion { tx := db.Begin() if err := tx.Error; err != nil { - return err + return sqlError.Wrap(err) } version, err = migrateVersion(tx, version) if err != nil { @@ -54,7 +54,7 @@ func migrateDB(db *gorm.DB) (err error) { return err } if err := tx.Commit().Error; err != nil { - return err + return sqlError.Wrap(err) } } @@ -66,22 +66,26 @@ func initDB(db *gorm.DB) (err error) { logrus.Infof("initializing database.") tx := db.Begin() if err := tx.Error; err != nil { - return err + return sqlError.Wrap(err) } if err := tx.AutoMigrate(&Bundle{}, &CACert{}, &AttestedNodeEntry{}, &NodeResolverMapEntry{}, &RegisteredEntry{}, &JoinToken{}, &Selector{}, &Migration{}).Error; err != nil { tx.Rollback() - return err + return sqlError.Wrap(err) } if err := tx.Assign(Migration{Version: codeVersion}).FirstOrCreate(&Migration{}).Error; err != nil { tx.Rollback() - return err + return sqlError.Wrap(err) + } + + if err := tx.Commit().Error; err != nil { + return sqlError.Wrap(err) } - return tx.Commit().Error + return nil } func migrateVersion(tx *gorm.DB, version int) (versionOut int, err error) { @@ -95,7 +99,7 @@ func migrateVersion(tx *gorm.DB, version int) (versionOut int, err error) { case 0: err = migrateToV1(tx) default: - err = fmt.Errorf("no migration support for version %d", version) + err = sqlError.New("no migration support for version %d", version) } if err != nil { return version, err @@ -103,7 +107,7 @@ func migrateVersion(tx *gorm.DB, version int) (versionOut int, err error) { nextVersion := version + 1 if err := tx.Model(&Migration{}).Updates(Migration{Version: nextVersion}).Error; err != nil { - return version, err + return version, sqlError.Wrap(err) } return nextVersion, nil @@ -126,7 +130,7 @@ func migrateToV1(tx *gorm.DB) error { // sqlite3). for _, table := range v0tables { if err := tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE deleted_at IS NOT NULL;", table)).Error; err != nil { - return err + return sqlError.Wrap(err) } } return nil diff --git a/pkg/server/plugin/datastore/sql/postgres.go b/pkg/server/plugin/datastore/sql/postgres.go index 5b79628854..3b71ebf609 100644 --- a/pkg/server/plugin/datastore/sql/postgres.go +++ b/pkg/server/plugin/datastore/sql/postgres.go @@ -10,7 +10,7 @@ type postgres struct{} func (p postgres) connect(connectionString string) (*gorm.DB, error) { db, err := gorm.Open("postgres", connectionString) if err != nil { - return nil, err + return nil, sqlError.Wrap(err) } return db, nil diff --git a/pkg/server/plugin/datastore/sql/sql.go b/pkg/server/plugin/datastore/sql/sql.go index 833ebb730f..a07556fc63 100644 --- a/pkg/server/plugin/datastore/sql/sql.go +++ b/pkg/server/plugin/datastore/sql/sql.go @@ -4,8 +4,6 @@ import ( "context" "crypto/x509" "errors" - "fmt" - "net/url" "sync" "time" @@ -13,11 +11,13 @@ import ( "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/sqlite" "github.com/satori/go.uuid" + "github.com/spiffe/spire/pkg/common/idutil" "github.com/spiffe/spire/pkg/common/selector" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/proto/common" spi "github.com/spiffe/spire/proto/common/plugin" "github.com/spiffe/spire/proto/server/datastore" + "github.com/zeebo/errs" ) var ( @@ -28,101 +28,535 @@ var ( Author: "", Company: "", } + + sqlError = errs.Class("datastore-sql") ) type configuration struct { DatabaseType string `hcl:"database_type" json:"database_type"` ConnectionString string `hcl:"connection_string" json:"connection_string"` + + // Undocumented flags + LogSQL bool `hcl:"log_sql" json:"log_sql"` +} + +type sqlDB struct { + databaseType string + connectionString string + *gorm.DB + + // this lock is only required for synchronized writes with "sqlite3". see + // the withTx() implementation for details. + opMu sync.Mutex } type sqlPlugin struct { - db *gorm.DB + mu sync.Mutex + db *sqlDB +} - DatabaseType string - ConnectionString string +func newPlugin() *sqlPlugin { + return &sqlPlugin{} +} - mutex *sync.Mutex +// New creates a new sql plugin struct. Configure must be called +// in order to start the db. +func New() datastore.Plugin { + return newPlugin() } // CreateBundle stores the given bundle -func (ds *sqlPlugin) CreateBundle(ctx context.Context, req *datastore.Bundle) (*datastore.Bundle, error) { - model, err := ds.bundleToModel(req) +func (ds *sqlPlugin) CreateBundle(ctx context.Context, req *datastore.Bundle) (resp *datastore.Bundle, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = createBundle(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// UpdateBundle updates an existing bundle with the given CAs. Overwrites any +// existing certificates. +func (ds *sqlPlugin) UpdateBundle(ctx context.Context, req *datastore.Bundle) (resp *datastore.Bundle, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = updateBundle(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// AppendBundle adds the specified CA certificates to an existing bundle. If no bundle exists for the +// specified trust domain, create one. Returns the entirety. +func (ds *sqlPlugin) AppendBundle(ctx context.Context, req *datastore.Bundle) (resp *datastore.Bundle, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = appendBundle(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// DeleteBundle deletes the bundle with the matching TrustDomain. Any CACert data passed is ignored. +func (ds *sqlPlugin) DeleteBundle(ctx context.Context, req *datastore.Bundle) (resp *datastore.Bundle, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = deleteBundle(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// FetchBundle returns the bundle matching the specified Trust Domain. +func (ds *sqlPlugin) FetchBundle(ctx context.Context, req *datastore.Bundle) (resp *datastore.Bundle, err error) { + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchBundle(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// ListBundles can be used to fetch all existing bundles. +func (ds *sqlPlugin) ListBundles(ctx context.Context, req *common.Empty) (resp *datastore.Bundles, err error) { + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = listBundles(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) CreateAttestedNodeEntry(ctx context.Context, + req *datastore.CreateAttestedNodeEntryRequest) (resp *datastore.CreateAttestedNodeEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = createAttestedNodeEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) FetchAttestedNodeEntry(ctx context.Context, + req *datastore.FetchAttestedNodeEntryRequest) (resp *datastore.FetchAttestedNodeEntryResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchAttestedNodeEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) FetchStaleNodeEntries(ctx context.Context, + req *datastore.FetchStaleNodeEntriesRequest) (resp *datastore.FetchStaleNodeEntriesResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchStaleNodeEntries(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) UpdateAttestedNodeEntry(ctx context.Context, + req *datastore.UpdateAttestedNodeEntryRequest) (resp *datastore.UpdateAttestedNodeEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = updateAttestedNodeEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) DeleteAttestedNodeEntry(ctx context.Context, + req *datastore.DeleteAttestedNodeEntryRequest) (resp *datastore.DeleteAttestedNodeEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = deleteAttestedNodeEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) CreateNodeResolverMapEntry(ctx context.Context, + req *datastore.CreateNodeResolverMapEntryRequest) (resp *datastore.CreateNodeResolverMapEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = createNodeResolverMapEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) FetchNodeResolverMapEntry(ctx context.Context, + req *datastore.FetchNodeResolverMapEntryRequest) (resp *datastore.FetchNodeResolverMapEntryResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchNodeResolverMapEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) DeleteNodeResolverMapEntry(ctx context.Context, + req *datastore.DeleteNodeResolverMapEntryRequest) (resp *datastore.DeleteNodeResolverMapEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = deleteNodeResolverMapEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (sqlPlugin) RectifyNodeResolverMapEntries(ctx context.Context, + req *datastore.RectifyNodeResolverMapEntriesRequest) (*datastore.RectifyNodeResolverMapEntriesResponse, error) { + return &datastore.RectifyNodeResolverMapEntriesResponse{}, errors.New("Not Implemented") +} + +func (ds *sqlPlugin) CreateRegistrationEntry(ctx context.Context, + req *datastore.CreateRegistrationEntryRequest) (resp *datastore.CreateRegistrationEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = createRegistrationEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) FetchRegistrationEntry(ctx context.Context, + req *datastore.FetchRegistrationEntryRequest) (resp *datastore.FetchRegistrationEntryResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchRegistrationEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) FetchRegistrationEntries(ctx context.Context, + req *common.Empty) (resp *datastore.FetchRegistrationEntriesResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchRegistrationEntries(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds sqlPlugin) UpdateRegistrationEntry(ctx context.Context, + req *datastore.UpdateRegistrationEntryRequest) (resp *datastore.UpdateRegistrationEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = updateRegistrationEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) DeleteRegistrationEntry(ctx context.Context, + req *datastore.DeleteRegistrationEntryRequest) (resp *datastore.DeleteRegistrationEntryResponse, err error) { + + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = deleteRegistrationEntry(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) ListParentIDEntries(ctx context.Context, + req *datastore.ListParentIDEntriesRequest) (resp *datastore.ListParentIDEntriesResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = listParentIDEntries(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) ListSelectorEntries(ctx context.Context, + req *datastore.ListSelectorEntriesRequest) (resp *datastore.ListSelectorEntriesResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = listSelectorEntries(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) ListMatchingEntries(ctx context.Context, + req *datastore.ListSelectorEntriesRequest) (resp *datastore.ListSelectorEntriesResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = listMatchingEntries(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) ListSpiffeEntries(ctx context.Context, + req *datastore.ListSpiffeEntriesRequest) (resp *datastore.ListSpiffeEntriesResponse, err error) { + + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = listSpiffeEntries(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// RegisterToken takes a Token message and stores it +func (ds *sqlPlugin) RegisterToken(ctx context.Context, req *datastore.JoinToken) (resp *common.Empty, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = registerToken(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// FetchToken takes a Token message and returns one, populating the fields +// we have knowledge of +func (ds *sqlPlugin) FetchToken(ctx context.Context, req *datastore.JoinToken) (resp *datastore.JoinToken, err error) { + if err := ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = fetchToken(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) DeleteToken(ctx context.Context, req *datastore.JoinToken) (resp *common.Empty, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = deleteToken(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +// PruneTokens takes a Token message, and deletes all tokens which have expired +// before the date in the message +func (ds *sqlPlugin) PruneTokens(ctx context.Context, req *datastore.JoinToken) (resp *common.Empty, err error) { + if err := ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { + resp, err = pruneTokens(tx, req) + return err + }); err != nil { + return nil, err + } + return resp, nil +} + +func (ds *sqlPlugin) Configure(ctx context.Context, req *spi.ConfigureRequest) (*spi.ConfigureResponse, error) { + // Parse HCL config payload into config struct + config := &configuration{} + if err := hcl.Decode(config, req.Configuration); err != nil { + return nil, err + } + + if config.DatabaseType == "" { + return nil, errors.New("database_type must be set") + } + + if config.ConnectionString == "" { + return nil, errors.New("connection_string must be set") + } + + ds.mu.Lock() + defer ds.mu.Unlock() + + if ds.db == nil || + config.ConnectionString != ds.db.connectionString || + config.DatabaseType != ds.db.databaseType { + + db, err := openDB(config.DatabaseType, config.ConnectionString) + if err != nil { + return nil, err + } + + if ds.db != nil { + ds.db.Close() + } + + ds.db = &sqlDB{ + DB: db, + databaseType: config.DatabaseType, + connectionString: config.ConnectionString, + } + } + + ds.db.LogMode(config.LogSQL) + + return &spi.ConfigureResponse{}, nil +} + +func (sqlPlugin) GetPluginInfo(context.Context, *spi.GetPluginInfoRequest) (*spi.GetPluginInfoResponse, error) { + return &pluginInfo, nil +} + +func (ds *sqlPlugin) withWriteTx(ctx context.Context, op func(tx *gorm.DB) error) error { + return ds.withTx(ctx, op, false) +} + +func (ds *sqlPlugin) withReadTx(ctx context.Context, op func(tx *gorm.DB) error) error { + return ds.withTx(ctx, op, true) +} + +func (ds *sqlPlugin) withTx(ctx context.Context, op func(tx *gorm.DB) error, readOnly bool) error { + ds.mu.Lock() + db := ds.db + ds.mu.Unlock() + + if db.databaseType == "sqlite3" && !readOnly { + // sqlite3 can only have one writer at a time. since we're in WAL mode, + // there can be concurrent reads and writes, so no lock is necessary + // over the read operations. + db.opMu.Lock() + defer db.opMu.Unlock() + } + + // TODO: as soon as GORM supports it, attach the context + tx := db.Begin() + if err := tx.Error; err != nil { + return sqlError.Wrap(err) + } + + if err := op(tx); err != nil { + tx.Rollback() + return err + } + + if readOnly { + // rolling back makes sure that functions that are invoked with + // withReadTx, and then do writes, will not pass unit tests, since the + // writes won't be committed. + return sqlError.Wrap(tx.Rollback().Error) + } + return sqlError.Wrap(tx.Commit().Error) +} + +func openDB(databaseType, connectionString string) (*gorm.DB, error) { + var db *gorm.DB + var err error + + switch databaseType { + case "sqlite3": + db, err = sqlite{}.connect(connectionString) + case "postgres": + db, err = postgres{}.connect(connectionString) + default: + return nil, sqlError.New("unsupported database_type: %v", databaseType) + } if err != nil { return nil, err } - result := ds.db.Create(model) - if result.Error != nil { - return nil, result.Error + if err := migrateDB(db); err != nil { + db.Close() + return nil, err } - return req, nil + return db, nil } -// UpdateBundle updates an existing bundle with the given CAs. Overwrites any -// existing certificates. -func (ds *sqlPlugin) UpdateBundle(ctx context.Context, req *datastore.Bundle) (*datastore.Bundle, error) { - newModel, err := ds.bundleToModel(req) +func createBundle(tx *gorm.DB, req *datastore.Bundle) (*datastore.Bundle, error) { + model, err := bundleToModel(req) if err != nil { return nil, err } - tx := ds.db.Begin() + if err := tx.Create(model).Error; err != nil { + return nil, sqlError.Wrap(err) + } + + return req, nil +} + +func updateBundle(tx *gorm.DB, req *datastore.Bundle) (*datastore.Bundle, error) { + newModel, err := bundleToModel(req) + if err != nil { + return nil, err + } // Fetch the model to get its ID model := &Bundle{} - result := tx.Find(model, "trust_domain = ?", newModel.TrustDomain) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Find(model, "trust_domain = ?", newModel.TrustDomain).Error; err != nil { + return nil, sqlError.Wrap(err) } // Delete existing CA certs - the provided list takes precedence - result = tx.Where("bundle_id = ?", model.ID).Delete(CACert{}) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Where("bundle_id = ?", model.ID).Delete(CACert{}).Error; err != nil { + return nil, sqlError.Wrap(err) } // Set the new values model.CACerts = newModel.CACerts - result = tx.Save(model) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Save(model).Error; err != nil { + return nil, sqlError.Wrap(err) } - return req, tx.Commit().Error + return req, nil } -// AppendBundle adds the specified CA certificates to an existing bundle. If no bundle exists for the -// specified trust domain, create one. Returns the entirety. -func (ds *sqlPlugin) AppendBundle(ctx context.Context, req *datastore.Bundle) (*datastore.Bundle, error) { - newModel, err := ds.bundleToModel(req) +func appendBundle(tx *gorm.DB, req *datastore.Bundle) (*datastore.Bundle, error) { + newModel, err := bundleToModel(req) if err != nil { return nil, err } - tx := ds.db.Begin() - // First, fetch the existing model model := &Bundle{} result := tx.Find(model, "trust_domain = ?", newModel.TrustDomain) - if result.RecordNotFound() { - tx.Rollback() - return ds.CreateBundle(ctx, req) + return createBundle(tx, req) } else if result.Error != nil { - tx.Rollback() - return nil, result.Error + return nil, sqlError.Wrap(result.Error) } // Get the existing certificates so we can include them in the response var caCerts []CACert - result = tx.Model(model).Related(&caCerts) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Model(model).Related(&caCerts).Error; err != nil { + return nil, sqlError.Wrap(err) } model.CACerts = caCerts @@ -132,108 +566,83 @@ func (ds *sqlPlugin) AppendBundle(ctx context.Context, req *datastore.Bundle) (* } } - result = tx.Save(model) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Save(model).Error; err != nil { + return nil, sqlError.Wrap(err) } - resp, err := ds.modelToBundle(model) + resp, err := modelToBundle(model) if err != nil { - tx.Rollback() return nil, err } - return resp, tx.Commit().Error + return resp, nil } -// DeleteBundle deletes the bundle with the matching TrustDomain. Any CACert data passed is ignored. -func (ds *sqlPlugin) DeleteBundle(ctx context.Context, req *datastore.Bundle) (*datastore.Bundle, error) { +func deleteBundle(tx *gorm.DB, req *datastore.Bundle) (*datastore.Bundle, error) { // We don't care if cert data was sent - remove it now to prevent // further processing. req.CaCerts = []byte{} - model, err := ds.bundleToModel(req) + model, err := bundleToModel(req) if err != nil { return nil, err } - tx := ds.db.Begin() - - result := tx.Find(model, "trust_domain = ?", model.TrustDomain) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Find(model, "trust_domain = ?", model.TrustDomain).Error; err != nil { + return nil, sqlError.Wrap(err) } // Fetch related CA certs for response before we delete them var caCerts []CACert - result = tx.Model(model).Related(&caCerts) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Model(model).Related(&caCerts).Error; err != nil { + return nil, sqlError.Wrap(err) } model.CACerts = caCerts - result = tx.Where("bundle_id = ?", model.ID).Delete(CACert{}) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Where("bundle_id = ?", model.ID).Delete(CACert{}).Error; err != nil { + return nil, sqlError.Wrap(err) } - result = tx.Delete(model) - if result.Error != nil { - tx.Rollback() - return nil, result.Error + if err := tx.Delete(model).Error; err != nil { + return nil, sqlError.Wrap(err) } - resp, err := ds.modelToBundle(model) + resp, err := modelToBundle(model) if err != nil { - tx.Rollback() return nil, err } - return resp, tx.Commit().Error + return resp, nil } // FetchBundle returns the bundle matching the specified Trust Domain. -func (ds *sqlPlugin) FetchBundle(ctx context.Context, req *datastore.Bundle) (*datastore.Bundle, error) { - model, err := ds.bundleToModel(req) +func fetchBundle(tx *gorm.DB, req *datastore.Bundle) (*datastore.Bundle, error) { + model, err := bundleToModel(req) if err != nil { return nil, err } - result := ds.db.Find(model, "trust_domain = ?", model.TrustDomain) - if result.Error != nil { - return nil, result.Error + if err := tx.Find(model, "trust_domain = ?", model.TrustDomain).Error; err != nil { + return nil, sqlError.Wrap(err) } - var caCerts []CACert - result = ds.db.Model(model).Related(&caCerts) - if result.Error != nil { - return nil, result.Error + if err := tx.Model(model).Related(&model.CACerts).Error; err != nil { + return nil, sqlError.Wrap(err) } - model.CACerts = caCerts - return ds.modelToBundle(model) + return modelToBundle(model) } // ListBundles can be used to fetch all existing bundles. -func (ds *sqlPlugin) ListBundles(ctx context.Context, req *common.Empty) (*datastore.Bundles, error) { - // Get a consistent view - tx := ds.db.Begin() - defer tx.Rollback() - +func listBundles(tx *gorm.DB, req *common.Empty) (*datastore.Bundles, error) { var bundles []Bundle - result := tx.Find(&bundles) - if result.Error != nil { - return nil, result.Error + if err := tx.Find(&bundles).Error; err != nil { + return nil, sqlError.Wrap(err) } var caCerts []CACert - result = tx.Find(&caCerts) - if result.Error != nil { - return nil, result.Error + if err := tx.Find(&caCerts).Error; err != nil { + return nil, sqlError.Wrap(err) } // Index CA Certs by Bundle ID so we can reconstruct them more easily @@ -257,7 +666,7 @@ func (ds *sqlPlugin) ListBundles(ctx context.Context, req *common.Empty) (*datas model.CACerts = []CACert{} } - bundle, err := ds.modelToBundle(&model) + bundle, err := modelToBundle(&model) if err != nil { return nil, err } @@ -268,20 +677,15 @@ func (ds *sqlPlugin) ListBundles(ctx context.Context, req *common.Empty) (*datas return resp, nil } -func (ds *sqlPlugin) CreateAttestedNodeEntry(ctx context.Context, - req *datastore.CreateAttestedNodeEntryRequest) (*datastore.CreateAttestedNodeEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - +func createAttestedNodeEntry(tx *gorm.DB, req *datastore.CreateAttestedNodeEntryRequest) (*datastore.CreateAttestedNodeEntryResponse, error) { entry := req.AttestedNodeEntry if entry == nil { - return nil, errors.New("invalid request: missing attested node") + return nil, sqlError.New("invalid request: missing attested node") } expiresAt, err := time.Parse(datastore.TimeFormat, entry.CertExpirationDate) if err != nil { - return nil, errors.New("invalid request: missing expiration") + return nil, sqlError.New("invalid request: missing expiration") } model := AttestedNodeEntry{ @@ -291,8 +695,8 @@ func (ds *sqlPlugin) CreateAttestedNodeEntry(ctx context.Context, ExpiresAt: expiresAt, } - if err := ds.db.Create(&model).Error; err != nil { - return nil, err + if err := tx.Create(&model).Error; err != nil { + return nil, sqlError.Wrap(err) } return &datastore.CreateAttestedNodeEntryResponse{ @@ -305,19 +709,14 @@ func (ds *sqlPlugin) CreateAttestedNodeEntry(ctx context.Context, }, nil } -func (ds *sqlPlugin) FetchAttestedNodeEntry(ctx context.Context, - req *datastore.FetchAttestedNodeEntryRequest) (*datastore.FetchAttestedNodeEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - +func fetchAttestedNodeEntry(tx *gorm.DB, req *datastore.FetchAttestedNodeEntryRequest) (*datastore.FetchAttestedNodeEntryResponse, error) { var model AttestedNodeEntry - err := ds.db.Find(&model, "spiffe_id = ?", req.BaseSpiffeId).Error + err := tx.Find(&model, "spiffe_id = ?", req.BaseSpiffeId).Error switch { case err == gorm.ErrRecordNotFound: return &datastore.FetchAttestedNodeEntryResponse{}, nil case err != nil: - return nil, err + return nil, sqlError.Wrap(err) } return &datastore.FetchAttestedNodeEntryResponse{ AttestedNodeEntry: &datastore.AttestedNodeEntry{ @@ -329,15 +728,10 @@ func (ds *sqlPlugin) FetchAttestedNodeEntry(ctx context.Context, }, nil } -func (ds *sqlPlugin) FetchStaleNodeEntries(ctx context.Context, - req *datastore.FetchStaleNodeEntriesRequest) (*datastore.FetchStaleNodeEntriesResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - +func fetchStaleNodeEntries(tx *gorm.DB, req *datastore.FetchStaleNodeEntriesRequest) (*datastore.FetchStaleNodeEntriesResponse, error) { var models []AttestedNodeEntry - if err := ds.db.Find(&models, "expires_at < ?", time.Now()).Error; err != nil { - return nil, err + if err := tx.Find(&models, "expires_at < ?", time.Now()).Error; err != nil { + return nil, sqlError.Wrap(err) } resp := &datastore.FetchStaleNodeEntriesResponse{ @@ -355,24 +749,15 @@ func (ds *sqlPlugin) FetchStaleNodeEntries(ctx context.Context, return resp, nil } -func (ds *sqlPlugin) UpdateAttestedNodeEntry(ctx context.Context, - req *datastore.UpdateAttestedNodeEntryRequest) (*datastore.UpdateAttestedNodeEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - - var model AttestedNodeEntry - +func updateAttestedNodeEntry(tx *gorm.DB, req *datastore.UpdateAttestedNodeEntryRequest) (*datastore.UpdateAttestedNodeEntryResponse, error) { expiresAt, err := time.Parse(datastore.TimeFormat, req.CertExpirationDate) if err != nil { - return nil, err + return nil, sqlError.Wrap(err) } - db := ds.db.Begin() - - if err := db.Find(&model, "spiffe_id = ?", req.BaseSpiffeId).Error; err != nil { - db.Rollback() - return nil, err + var model AttestedNodeEntry + if err := tx.Find(&model, "spiffe_id = ?", req.BaseSpiffeId).Error; err != nil { + return nil, sqlError.Wrap(err) } updates := AttestedNodeEntry{ @@ -380,9 +765,8 @@ func (ds *sqlPlugin) UpdateAttestedNodeEntry(ctx context.Context, ExpiresAt: expiresAt, } - if err := db.Model(&model).Updates(updates).Error; err != nil { - db.Rollback() - return nil, err + if err := tx.Model(&model).Updates(updates).Error; err != nil { + return nil, sqlError.Wrap(err) } return &datastore.UpdateAttestedNodeEntryResponse{ @@ -392,27 +776,17 @@ func (ds *sqlPlugin) UpdateAttestedNodeEntry(ctx context.Context, CertSerialNumber: model.SerialNumber, CertExpirationDate: model.ExpiresAt.Format(datastore.TimeFormat), }, - }, db.Commit().Error + }, nil } -func (ds *sqlPlugin) DeleteAttestedNodeEntry(ctx context.Context, - req *datastore.DeleteAttestedNodeEntryRequest) (*datastore.DeleteAttestedNodeEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - - db := ds.db.Begin() - +func deleteAttestedNodeEntry(tx *gorm.DB, req *datastore.DeleteAttestedNodeEntryRequest) (*datastore.DeleteAttestedNodeEntryResponse, error) { var model AttestedNodeEntry - - if err := db.Find(&model, "spiffe_id = ?", req.BaseSpiffeId).Error; err != nil { - db.Rollback() - return nil, err + if err := tx.Find(&model, "spiffe_id = ?", req.BaseSpiffeId).Error; err != nil { + return nil, sqlError.Wrap(err) } - if err := db.Delete(&model).Error; err != nil { - db.Rollback() - return nil, err + if err := tx.Delete(&model).Error; err != nil { + return nil, sqlError.Wrap(err) } return &datastore.DeleteAttestedNodeEntryResponse{ @@ -422,23 +796,18 @@ func (ds *sqlPlugin) DeleteAttestedNodeEntry(ctx context.Context, CertSerialNumber: model.SerialNumber, CertExpirationDate: model.ExpiresAt.Format(datastore.TimeFormat), }, - }, db.Commit().Error + }, nil } -func (ds *sqlPlugin) CreateNodeResolverMapEntry(ctx context.Context, - req *datastore.CreateNodeResolverMapEntryRequest) (*datastore.CreateNodeResolverMapEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - +func createNodeResolverMapEntry(tx *gorm.DB, req *datastore.CreateNodeResolverMapEntryRequest) (*datastore.CreateNodeResolverMapEntryResponse, error) { entry := req.NodeResolverMapEntry if entry == nil { - return nil, errors.New("Invalid Request: no map entry") + return nil, sqlError.New("invalid request: no map entry") } selector := entry.Selector if selector == nil { - return nil, errors.New("Invalid Request: no selector") + return nil, sqlError.New("invalid request: no selector") } model := NodeResolverMapEntry{ @@ -447,8 +816,8 @@ func (ds *sqlPlugin) CreateNodeResolverMapEntry(ctx context.Context, Value: selector.Value, } - if err := ds.db.Create(&model).Error; err != nil { - return nil, err + if err := tx.Create(&model).Error; err != nil { + return nil, sqlError.Wrap(err) } return &datastore.CreateNodeResolverMapEntryResponse{ @@ -462,16 +831,10 @@ func (ds *sqlPlugin) CreateNodeResolverMapEntry(ctx context.Context, }, nil } -func (ds *sqlPlugin) FetchNodeResolverMapEntry(ctx context.Context, - req *datastore.FetchNodeResolverMapEntryRequest) (*datastore.FetchNodeResolverMapEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - +func fetchNodeResolverMapEntry(tx *gorm.DB, req *datastore.FetchNodeResolverMapEntryRequest) (*datastore.FetchNodeResolverMapEntryResponse, error) { var models []NodeResolverMapEntry - - if err := ds.db.Find(&models, "spiffe_id = ?", req.BaseSpiffeId).Error; err != nil { - return nil, err + if err := tx.Find(&models, "spiffe_id = ?", req.BaseSpiffeId).Error; err != nil { + return nil, sqlError.Wrap(err) } resp := &datastore.FetchNodeResolverMapEntryResponse{ @@ -490,21 +853,13 @@ func (ds *sqlPlugin) FetchNodeResolverMapEntry(ctx context.Context, return resp, nil } -func (ds *sqlPlugin) DeleteNodeResolverMapEntry(ctx context.Context, - req *datastore.DeleteNodeResolverMapEntryRequest) (*datastore.DeleteNodeResolverMapEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() - +func deleteNodeResolverMapEntry(tx *gorm.DB, req *datastore.DeleteNodeResolverMapEntryRequest) (*datastore.DeleteNodeResolverMapEntryResponse, error) { entry := req.NodeResolverMapEntry if entry == nil { - return nil, errors.New("Invalid Request: no map entry") + return nil, sqlError.New("invalid request: no map entry") } - tx := ds.db.Begin() - // if no selector is given, delete all entries with the given spiffe id - scope := tx.Where("spiffe_id = ?", entry.BaseSpiffeId) if selector := entry.Selector; selector != nil { @@ -515,13 +870,11 @@ func (ds *sqlPlugin) DeleteNodeResolverMapEntry(ctx context.Context, var models []NodeResolverMapEntry if err := scope.Find(&models).Error; err != nil { - tx.Rollback() - return nil, err + return nil, sqlError.Wrap(err) } if err := scope.Delete(&NodeResolverMapEntry{}).Error; err != nil { - tx.Rollback() - return nil, err + return nil, sqlError.Wrap(err) } resp := &datastore.DeleteNodeResolverMapEntryResponse{ @@ -538,84 +891,70 @@ func (ds *sqlPlugin) DeleteNodeResolverMapEntry(ctx context.Context, }) } - return resp, tx.Commit().Error -} - -func (sqlPlugin) RectifyNodeResolverMapEntries(ctx context.Context, - req *datastore.RectifyNodeResolverMapEntriesRequest) (*datastore.RectifyNodeResolverMapEntriesResponse, error) { - return &datastore.RectifyNodeResolverMapEntriesResponse{}, errors.New("Not Implemented") + return resp, nil } -func (ds *sqlPlugin) CreateRegistrationEntry(ctx context.Context, - request *datastore.CreateRegistrationEntryRequest) (*datastore.CreateRegistrationEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() +func createRegistrationEntry(tx *gorm.DB, + req *datastore.CreateRegistrationEntryRequest) (*datastore.CreateRegistrationEntryResponse, error) { // TODO: Validations should be done in the ProtoBuf level [https://github.com/spiffe/spire/issues/44] - if request.RegisteredEntry == nil { - return nil, errors.New("Invalid request: missing registered entry") + if req.RegisteredEntry == nil { + return nil, sqlError.New("invalid request: missing registered entry") } - err := ds.validateRegistrationEntry(request.RegisteredEntry) - if err != nil { - return nil, fmt.Errorf("Invalid registration entry: %v", err) + if err := validateRegistrationEntry(req.RegisteredEntry); err != nil { + return nil, err } - entryID, err := uuid.NewV4() + entryID, err := newRegistrationEntryID() if err != nil { - return nil, fmt.Errorf("could not generate entry id: %v", err) + return nil, err } newRegisteredEntry := RegisteredEntry{ - EntryID: entryID.String(), - SpiffeID: request.RegisteredEntry.SpiffeId, - ParentID: request.RegisteredEntry.ParentId, - TTL: request.RegisteredEntry.Ttl, + EntryID: entryID, + SpiffeID: req.RegisteredEntry.SpiffeId, + ParentID: req.RegisteredEntry.ParentId, + TTL: req.RegisteredEntry.Ttl, // TODO: Add support to Federated Bundles [https://github.com/spiffe/spire/issues/42] } - tx := ds.db.Begin() if err := tx.Create(&newRegisteredEntry).Error; err != nil { - tx.Rollback() - return nil, err + return nil, sqlError.Wrap(err) } - for _, registeredSelector := range request.RegisteredEntry.Selectors { + for _, registeredSelector := range req.RegisteredEntry.Selectors { newSelector := Selector{ RegisteredEntryID: newRegisteredEntry.ID, Type: registeredSelector.Type, Value: registeredSelector.Value} if err := tx.Create(&newSelector).Error; err != nil { - tx.Rollback() - return nil, err + return nil, sqlError.Wrap(err) } } return &datastore.CreateRegistrationEntryResponse{ RegisteredEntryId: newRegisteredEntry.EntryID, - }, tx.Commit().Error + }, nil } -func (ds *sqlPlugin) FetchRegistrationEntry(ctx context.Context, - request *datastore.FetchRegistrationEntryRequest) (*datastore.FetchRegistrationEntryResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() +func fetchRegistrationEntry(tx *gorm.DB, + req *datastore.FetchRegistrationEntryRequest) (*datastore.FetchRegistrationEntryResponse, error) { var fetchedRegisteredEntry RegisteredEntry - err := ds.db.Find(&fetchedRegisteredEntry, "entry_id = ?", request.RegisteredEntryId).Error - + err := tx.Find(&fetchedRegisteredEntry, "entry_id = ?", req.RegisteredEntryId).Error switch { case err == gorm.ErrRecordNotFound: return &datastore.FetchRegistrationEntryResponse{}, nil case err != nil: - return nil, err + return nil, sqlError.Wrap(err) } var fetchedSelectors []*Selector - ds.db.Model(&fetchedRegisteredEntry).Related(&fetchedSelectors) + if err := tx.Model(&fetchedRegisteredEntry).Related(&fetchedSelectors).Error; err != nil { + return nil, sqlError.Wrap(err) + } selectors := make([]*common.Selector, 0, len(fetchedSelectors)) @@ -636,17 +975,17 @@ func (ds *sqlPlugin) FetchRegistrationEntry(ctx context.Context, }, nil } -func (ds *sqlPlugin) FetchRegistrationEntries(ctx context.Context, - request *common.Empty) (*datastore.FetchRegistrationEntriesResponse, error) { +func fetchRegistrationEntries(tx *gorm.DB, + req *common.Empty) (*datastore.FetchRegistrationEntriesResponse, error) { var entries []RegisteredEntry - if err := ds.db.Find(&entries).Error; err != nil { - return nil, err + if err := tx.Find(&entries).Error; err != nil { + return nil, sqlError.Wrap(err) } var sel []Selector - if err := ds.db.Find(&sel).Error; err != nil { - return nil, err + if err := tx.Find(&sel).Error; err != nil { + return nil, sqlError.Wrap(err) } // Organize the selectors for easier access @@ -662,50 +1001,43 @@ func (ds *sqlPlugin) FetchRegistrationEntries(ctx context.Context, } } - resEntries, err := ds.convertEntries(entries) + resEntries, err := modelsToEntries(tx, entries) if err != nil { - return nil, err + return nil, sqlError.Wrap(err) } - res := &datastore.FetchRegistrationEntriesResponse{ + return &datastore.FetchRegistrationEntriesResponse{ RegisteredEntries: &common.RegistrationEntries{ Entries: resEntries, }, - } - - return res, nil + }, nil } -func (ds sqlPlugin) UpdateRegistrationEntry(ctx context.Context, - request *datastore.UpdateRegistrationEntryRequest) (*datastore.UpdateRegistrationEntryResponse, error) { +func updateRegistrationEntry(tx *gorm.DB, + req *datastore.UpdateRegistrationEntryRequest) (*datastore.UpdateRegistrationEntryResponse, error) { - if request.RegisteredEntry == nil { - return nil, errors.New("No registration entry provided") + if req.RegisteredEntry == nil { + return nil, sqlError.New("no registration entry provided") } - err := ds.validateRegistrationEntry(request.RegisteredEntry) - if err != nil { - return nil, fmt.Errorf("Invalid registration entry: %v", err) + if err := validateRegistrationEntry(req.RegisteredEntry); err != nil { + return nil, err } - tx := ds.db.Begin() - // Get the existing entry // TODO: Refactor message type to take EntryID directly from the entry - see #449 entry := RegisteredEntry{} - if err = tx.Find(&entry, "entry_id = ?", request.RegisteredEntryId).Error; err != nil { - tx.Rollback() - return nil, err + if err := tx.Find(&entry, "entry_id = ?", req.RegisteredEntryId).Error; err != nil { + return nil, sqlError.Wrap(err) } // Delete existing selectors - we will write new ones - if err = tx.Exec("DELETE FROM selectors WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - tx.Rollback() - return nil, err + if err := tx.Exec("DELETE FROM selectors WHERE registered_entry_id = ?", entry.ID).Error; err != nil { + return nil, sqlError.Wrap(err) } selectors := []Selector{} - for _, s := range request.RegisteredEntry.Selectors { + for _, s := range req.RegisteredEntry.Selectors { selector := Selector{ Type: s.Type, Value: s.Value, @@ -714,96 +1046,81 @@ func (ds sqlPlugin) UpdateRegistrationEntry(ctx context.Context, selectors = append(selectors, selector) } - entry.SpiffeID = request.RegisteredEntry.SpiffeId - entry.ParentID = request.RegisteredEntry.ParentId - entry.TTL = request.RegisteredEntry.Ttl + entry.SpiffeID = req.RegisteredEntry.SpiffeId + entry.ParentID = req.RegisteredEntry.ParentId + entry.TTL = req.RegisteredEntry.Ttl entry.Selectors = selectors - if err = tx.Save(&entry).Error; err != nil { - tx.Rollback() - return nil, err - } - - if err = tx.Commit().Error; err != nil { - tx.Rollback() - return nil, err + if err := tx.Save(&entry).Error; err != nil { + return nil, sqlError.Wrap(err) } - request.RegisteredEntry.EntryId = entry.EntryID - return &datastore.UpdateRegistrationEntryResponse{RegisteredEntry: request.RegisteredEntry}, nil + req.RegisteredEntry.EntryId = entry.EntryID + return &datastore.UpdateRegistrationEntryResponse{ + RegisteredEntry: req.RegisteredEntry, + }, nil } -func (ds *sqlPlugin) DeleteRegistrationEntry(ctx context.Context, - request *datastore.DeleteRegistrationEntryRequest) (*datastore.DeleteRegistrationEntryResponse, error) { +func deleteRegistrationEntry(tx *gorm.DB, + req *datastore.DeleteRegistrationEntryRequest) (*datastore.DeleteRegistrationEntryResponse, error) { entry := RegisteredEntry{} - if err := ds.db.Find(&entry, "entry_id = ?", request.RegisteredEntryId).Error; err != nil { - return &datastore.DeleteRegistrationEntryResponse{}, err + if err := tx.Find(&entry, "entry_id = ?", req.RegisteredEntryId).Error; err != nil { + return nil, sqlError.Wrap(err) } - if err := ds.db.Delete(&entry).Error; err != nil { - return &datastore.DeleteRegistrationEntryResponse{}, err + if err := tx.Delete(&entry).Error; err != nil { + return nil, sqlError.Wrap(err) } - respEntry, err := ds.convertEntries([]RegisteredEntry{entry}) + respEntry, err := modelToEntry(tx, entry) if err != nil { - return &datastore.DeleteRegistrationEntryResponse{}, err + return nil, err } - resp := &datastore.DeleteRegistrationEntryResponse{ - RegisteredEntry: respEntry[0], - } - return resp, nil + return &datastore.DeleteRegistrationEntryResponse{ + RegisteredEntry: respEntry, + }, nil } -func (ds *sqlPlugin) ListParentIDEntries(ctx context.Context, - request *datastore.ListParentIDEntriesRequest) (response *datastore.ListParentIDEntriesResponse, err error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() +func listParentIDEntries(tx *gorm.DB, + req *datastore.ListParentIDEntriesRequest) (*datastore.ListParentIDEntriesResponse, error) { var fetchedRegisteredEntries []RegisteredEntry - err = ds.db.Find(&fetchedRegisteredEntries, "parent_id = ?", request.ParentId).Error - + err := tx.Find(&fetchedRegisteredEntries, "parent_id = ?", req.ParentId).Error switch { case err == gorm.ErrRecordNotFound: return &datastore.ListParentIDEntriesResponse{}, nil case err != nil: - return nil, err + return nil, sqlError.Wrap(err) } - regEntryList, err := ds.convertEntries(fetchedRegisteredEntries) + regEntryList, err := modelsToEntries(tx, fetchedRegisteredEntries) if err != nil { return nil, err } return &datastore.ListParentIDEntriesResponse{RegisteredEntryList: regEntryList}, nil } -func (ds *sqlPlugin) ListSelectorEntries(ctx context.Context, - request *datastore.ListSelectorEntriesRequest) (*datastore.ListSelectorEntriesResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() +func listSelectorEntries(tx *gorm.DB, + req *datastore.ListSelectorEntriesRequest) (*datastore.ListSelectorEntriesResponse, error) { - entries, err := ds.listMatchingEntries(request.Selectors) + entries, err := listEntriesWithExactSelectorMatch(tx, req.Selectors) if err != nil { - return &datastore.ListSelectorEntriesResponse{}, err + return nil, err } util.SortRegistrationEntries(entries) return &datastore.ListSelectorEntriesResponse{RegisteredEntryList: entries}, nil } -func (ds *sqlPlugin) ListMatchingEntries(ctx context.Context, - request *datastore.ListSelectorEntriesRequest) (*datastore.ListSelectorEntriesResponse, error) { - - ds.mutex.Lock() - defer ds.mutex.Unlock() +func listMatchingEntries(tx *gorm.DB, + req *datastore.ListSelectorEntriesRequest) (*datastore.ListSelectorEntriesResponse, error) { resp := &datastore.ListSelectorEntriesResponse{} - for combination := range selector.NewSetFromRaw(request.Selectors).Power() { - entries, err := ds.listMatchingEntries(combination.Raw()) + for combination := range selector.NewSetFromRaw(req.Selectors).Power() { + entries, err := listEntriesWithExactSelectorMatch(tx, combination.Raw()) if err != nil { - return &datastore.ListSelectorEntriesResponse{}, err + return nil, err } resp.RegisteredEntryList = append(resp.RegisteredEntryList, entries...) } @@ -812,31 +1129,27 @@ func (ds *sqlPlugin) ListMatchingEntries(ctx context.Context, return resp, nil } -func (ds *sqlPlugin) ListSpiffeEntries(ctx context.Context, - request *datastore.ListSpiffeEntriesRequest) (*datastore.ListSpiffeEntriesResponse, error) { +func listSpiffeEntries(tx *gorm.DB, + req *datastore.ListSpiffeEntriesRequest) (*datastore.ListSpiffeEntriesResponse, error) { var entries []RegisteredEntry - err := ds.db.Find(&entries, "spiffe_id = ?", request.SpiffeId).Error - if err != nil { - return &datastore.ListSpiffeEntriesResponse{}, err + if err := tx.Find(&entries, "spiffe_id = ?", req.SpiffeId).Error; err != nil { + return nil, sqlError.Wrap(err) } - respEntries, err := ds.convertEntries(entries) + respEntries, err := modelsToEntries(tx, entries) if err != nil { - return &datastore.ListSpiffeEntriesResponse{}, err + return nil, err } - resp := &datastore.ListSpiffeEntriesResponse{ + return &datastore.ListSpiffeEntriesResponse{ RegisteredEntryList: respEntries, - } - return resp, nil + }, nil } -// RegisterToken takes a Token message and stores it -func (ds *sqlPlugin) RegisterToken(ctx context.Context, req *datastore.JoinToken) (*common.Empty, error) { - resp := new(common.Empty) +func registerToken(tx *gorm.DB, req *datastore.JoinToken) (*common.Empty, error) { if req.Token == "" || req.Expiry == 0 { - return resp, errors.New("token and expiry are required") + return nil, errors.New("token and expiry are required") } t := JoinToken{ @@ -844,97 +1157,88 @@ func (ds *sqlPlugin) RegisterToken(ctx context.Context, req *datastore.JoinToken Expiry: req.Expiry, } - return resp, ds.db.Create(&t).Error + if err := tx.Create(&t).Error; err != nil { + return nil, sqlError.Wrap(err) + } + + return &common.Empty{}, nil } -// FetchToken takes a Token message and returns one, populating the fields -// we have knowledge of -func (ds *sqlPlugin) FetchToken(ctx context.Context, req *datastore.JoinToken) (*datastore.JoinToken, error) { +func fetchToken(tx *gorm.DB, req *datastore.JoinToken) (*datastore.JoinToken, error) { var t JoinToken - - err := ds.db.Find(&t, "token = ?", req.Token).Error + err := tx.Find(&t, "token = ?", req.Token).Error if err == gorm.ErrRecordNotFound { return &datastore.JoinToken{}, nil + } else if err != nil { + return nil, sqlError.Wrap(err) } - resp := &datastore.JoinToken{ + return &datastore.JoinToken{ Token: t.Token, Expiry: t.Expiry, - } - return resp, err + }, nil } -func (ds *sqlPlugin) DeleteToken(ctx context.Context, req *datastore.JoinToken) (*common.Empty, error) { +func deleteToken(tx *gorm.DB, req *datastore.JoinToken) (*common.Empty, error) { var t JoinToken + if err := tx.Find(&t, "token = ?", req.Token).Error; err != nil { + return nil, sqlError.Wrap(err) + } - err := ds.db.Find(&t, "token = ?", req.Token).Error - if err != nil { - return &common.Empty{}, err + if err := tx.Delete(&t).Error; err != nil { + return nil, sqlError.Wrap(err) } - return &common.Empty{}, ds.db.Delete(&t).Error + return &common.Empty{}, nil } -// PruneTokens takes a Token message, and deletes all tokens which have expired -// before the date in the message -func (ds *sqlPlugin) PruneTokens(ctx context.Context, req *datastore.JoinToken) (*common.Empty, error) { - var staleTokens []JoinToken - resp := new(common.Empty) - - err := ds.db.Where("expiry <= ?", req.Expiry).Find(&staleTokens).Error - if err != nil { - return resp, err - } - - for _, t := range staleTokens { - err := ds.db.Delete(&t).Error - if err != nil { - return resp, err - } +func pruneTokens(tx *gorm.DB, req *datastore.JoinToken) (*common.Empty, error) { + if err := tx.Where("expiry <= ?", req.Expiry).Delete(&JoinToken{}).Error; err != nil { + return nil, sqlError.Wrap(err) } - return resp, nil + return &common.Empty{}, nil } -func (ds *sqlPlugin) Configure(ctx context.Context, req *spi.ConfigureRequest) (*spi.ConfigureResponse, error) { - resp := &spi.ConfigureResponse{} - - // Parse HCL config payload into config struct - config := &configuration{} - hclTree, err := hcl.Parse(req.Configuration) +// modelToBundle converts the given bundle model to a Protobuf bundle message. It will also +// include any embedded CACert models. +func modelToBundle(model *Bundle) (*datastore.Bundle, error) { + id, err := idutil.ParseSpiffeID(model.TrustDomain, idutil.AllowAnyTrustDomain()) if err != nil { - resp.ErrorList = []string{err.Error()} - return resp, err + return nil, sqlError.Wrap(err) } - err = hcl.DecodeObject(&config, hclTree) - if err != nil { - resp.ErrorList = []string{err.Error()} - return resp, err + + caCerts := []byte{} + for _, c := range model.CACerts { + caCerts = append(caCerts, c.Cert...) } - if config.DatabaseType == "" { - return resp, errors.New("database_type must be set") + pb := &datastore.Bundle{ + TrustDomain: id.String(), + CaCerts: caCerts, } - if config.ConnectionString == "" { - return resp, errors.New("connection_string must be set") + return pb, nil +} + +func validateRegistrationEntry(entry *common.RegistrationEntry) error { + if entry.Selectors == nil || len(entry.Selectors) == 0 { + return sqlError.New("invalid registration entry: missing selector list") } - if config.ConnectionString != ds.ConnectionString { - ds.DatabaseType = config.DatabaseType - ds.ConnectionString = config.ConnectionString - return resp, ds.restart() + if len(entry.SpiffeId) == 0 { + return sqlError.New("invalid registration entry: missing SPIFFE ID") } - return resp, nil -} + if entry.Ttl < 0 { + return sqlError.New("invalid registration entry: TTL is not set") + } -func (sqlPlugin) GetPluginInfo(context.Context, *spi.GetPluginInfoRequest) (*spi.GetPluginInfoResponse, error) { - return &pluginInfo, nil + return nil } -// listMatchingEntries finds registered entries containing exactly the specified selectors. -func (ds *sqlPlugin) listMatchingEntries(selectors []*common.Selector) ([]*common.RegistrationEntry, error) { +// listEntriesWithExactSelectorMatch finds registered entries containing exactly the specified selectors. +func listEntriesWithExactSelectorMatch(tx *gorm.DB, selectors []*common.Selector) ([]*common.RegistrationEntry, error) { if len(selectors) < 1 { return nil, nil } @@ -943,9 +1247,8 @@ func (ds *sqlPlugin) listMatchingEntries(selectors []*common.Selector) ([]*commo refCount := make(map[uint]int) for _, s := range selectors { var results []Selector - err := ds.db.Find(&results, "type = ? AND value = ?", s.Type, s.Value).Error - if err != nil { - return nil, err + if err := tx.Find(&results, "type = ? AND value = ?", s.Type, s.Value).Error; err != nil { + return nil, sqlError.Wrap(err) } for _, r := range results { @@ -969,9 +1272,8 @@ func (ds *sqlPlugin) listMatchingEntries(selectors []*common.Selector) ([]*commo var resp []RegisteredEntry for _, id := range entryIDs { var result RegisteredEntry - err := ds.db.Find(&result, "id = ?", id).Error - if err != nil { - return nil, err + if err := tx.Find(&result, "id = ?", id).Error; err != nil { + return nil, sqlError.Wrap(err) } resp = append(resp, result) @@ -979,7 +1281,7 @@ func (ds *sqlPlugin) listMatchingEntries(selectors []*common.Selector) ([]*commo // Weed out entries that have more selectors than requested, since only // EXACT matches should be returned. - convertedEntries, err := ds.convertEntriesNoSort(resp) + convertedEntries, err := modelsToUnsortedEntries(tx, resp) if err != nil { return nil, err } @@ -995,15 +1297,15 @@ func (ds *sqlPlugin) listMatchingEntries(selectors []*common.Selector) ([]*commo // bundleToModel converts the given Protobuf bundle message to a database model. It // performs validation, and fully parses certificates to form CACert embedded models. -func (ds *sqlPlugin) bundleToModel(pb *datastore.Bundle) (*Bundle, error) { - id, err := ds.validateTrustDomain(pb.TrustDomain) +func bundleToModel(pb *datastore.Bundle) (*Bundle, error) { + id, err := idutil.ParseSpiffeID(pb.TrustDomain, idutil.AllowAnyTrustDomain()) if err != nil { - return nil, err + return nil, sqlError.Wrap(err) } certs, err := x509.ParseCertificates(pb.CaCerts) if err != nil { - return nil, errors.New("could not parse CA certificates") + return nil, sqlError.New("could not parse CA certificates") } // Translate CACerts, if any @@ -1025,67 +1327,8 @@ func (ds *sqlPlugin) bundleToModel(pb *datastore.Bundle) (*Bundle, error) { return bundle, nil } -// modelToBundle converts the given bundle model to a Protobuf bundle message. It will also -// include any embedded CACert models. -func (ds *sqlPlugin) modelToBundle(model *Bundle) (*datastore.Bundle, error) { - id, err := ds.validateTrustDomain(model.TrustDomain) - if err != nil { - return nil, err - } - - caCerts := []byte{} - for _, c := range model.CACerts { - caCerts = append(caCerts, c.Cert...) - } - - pb := &datastore.Bundle{ - TrustDomain: id.String(), - CaCerts: caCerts, - } - - return pb, nil -} - -func (ds *sqlPlugin) validateRegistrationEntry(entry *common.RegistrationEntry) error { - if entry.Selectors == nil || len(entry.Selectors) == 0 { - return errors.New("missing selector list") - } - - if len(entry.SpiffeId) == 0 { - return errors.New("missing SPIFFE ID") - } - - if entry.Ttl < 0 { - return errors.New("TTL is not set") - } - - return nil -} - -// validateTrustDomain converts the given string to a URL, and ensures that it is a correctly -// formatted SPIFFE trust domain. String is taken as the argument here since neither Protobuf nor -// GORM natively support the url.URL type. -// -// A valid trust domain has the SPIFFE scheme, a non-zero host component, and no path -func (ds *sqlPlugin) validateTrustDomain(in string) (*url.URL, error) { - if in == "" { - return nil, errors.New("trust domain is required") - } - - id, err := url.Parse(in) - if err != nil { - return nil, fmt.Errorf("could not parse trust domain %v: %v", in, err) - } - - if id.Scheme != "spiffe" || id.Host == "" || (id.Path != "" && id.Path != "/") { - return nil, fmt.Errorf("%v is not a valid SPIFFE trust domain", id.String()) - } - - return id, nil -} - -func (ds *sqlPlugin) convertEntries(fetchedRegisteredEntries []RegisteredEntry) (responseEntries []*common.RegistrationEntry, err error) { - entries, err := ds.convertEntriesNoSort(fetchedRegisteredEntries) +func modelsToEntries(tx *gorm.DB, fetchedRegisteredEntries []RegisteredEntry) (responseEntries []*common.RegistrationEntry, err error) { + entries, err := modelsToUnsortedEntries(tx, fetchedRegisteredEntries) if err != nil { return nil, err } @@ -1093,71 +1336,42 @@ func (ds *sqlPlugin) convertEntries(fetchedRegisteredEntries []RegisteredEntry) return entries, nil } -func (ds *sqlPlugin) convertEntriesNoSort(fetchedRegisteredEntries []RegisteredEntry) (responseEntries []*common.RegistrationEntry, err error) { +func modelsToUnsortedEntries(tx *gorm.DB, fetchedRegisteredEntries []RegisteredEntry) (responseEntries []*common.RegistrationEntry, err error) { for _, regEntry := range fetchedRegisteredEntries { - var selectors []*common.Selector - var fetchedSelectors []*Selector - if err = ds.db.Model(®Entry).Related(&fetchedSelectors).Error; err != nil { + responseEntry, err := modelToEntry(tx, regEntry) + if err != nil { return nil, err } - - for _, selector := range fetchedSelectors { - selectors = append(selectors, &common.Selector{ - Type: selector.Type, - Value: selector.Value}) - } - responseEntries = append(responseEntries, &common.RegistrationEntry{ - EntryId: regEntry.EntryID, - Selectors: selectors, - SpiffeId: regEntry.SpiffeID, - ParentId: regEntry.ParentID, - Ttl: regEntry.TTL, - }) + responseEntries = append(responseEntries, responseEntry) } return responseEntries, nil } -// restart will close and re-open the gorm database. -func (ds *sqlPlugin) restart() error { - ds.mutex.Lock() - defer ds.mutex.Unlock() - - var db *gorm.DB - var err error - - switch ds.DatabaseType { - case "sqlite3": - db, err = sqlite{}.connect(ds.ConnectionString) - case "postgres": - db, err = postgres{}.connect(ds.ConnectionString) - default: - return fmt.Errorf("unsupported database_type: %v", ds.DatabaseType) - } - if err != nil { - return err - } - - if err := migrateDB(db); err != nil { - db.Close() - return err +func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry, error) { + var selectors []*common.Selector + var fetchedSelectors []*Selector + if err := tx.Model(&model).Related(&fetchedSelectors).Error; err != nil { + return nil, sqlError.Wrap(err) } - if ds.db != nil { - ds.db.Close() + for _, selector := range fetchedSelectors { + selectors = append(selectors, &common.Selector{ + Type: selector.Type, + Value: selector.Value}) } - - ds.db = db - return nil + return &common.RegistrationEntry{ + EntryId: model.EntryID, + Selectors: selectors, + SpiffeId: model.SpiffeID, + ParentId: model.ParentID, + Ttl: model.TTL, + }, nil } -func newPlugin() *sqlPlugin { - return &sqlPlugin{ - mutex: new(sync.Mutex), +func newRegistrationEntryID() (string, error) { + id, err := uuid.NewV4() + if err != nil { + return "", sqlError.New("unable to generate registration entry id: %v", err) } -} - -// New creates a new sql plugin struct. Configure must be called -// in order to start the db. -func New() datastore.Plugin { - return newPlugin() + return id.String(), nil } diff --git a/pkg/server/plugin/datastore/sql/sql_test.go b/pkg/server/plugin/datastore/sql/sql_test.go index 182cae7bc1..7f0c56c927 100644 --- a/pkg/server/plugin/datastore/sql/sql_test.go +++ b/pkg/server/plugin/datastore/sql/sql_test.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "os" "path/filepath" - "sync" "sync/atomic" "testing" "time" @@ -17,63 +16,94 @@ import ( spi "github.com/spiffe/spire/proto/common/plugin" "github.com/spiffe/spire/proto/server/datastore" testutil "github.com/spiffe/spire/test/util" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) var ( ctx = context.Background() - - // nextInMemoryId is atomically incremented and appended to the database - // name for in-memory databases. A unique name is required to prevent - // the in-memory database from being shared. - // - // See https://www.sqlite.org/inmemorydb.html for details. - nextInMemoryId uint64 ) -type regEntries []*common.RegistrationEntry +func TestPlugin(t *testing.T) { + suite.Run(t, new(PluginSuite)) +} -func TestInvalidPluginConfiguration(t *testing.T) { - invalidPlugin := &sqlPlugin{ - mutex: new(sync.Mutex), - DatabaseType: "wrong", - ConnectionString: "string", - } +type PluginSuite struct { + suite.Suite + dir string - if invalidPlugin.restart() == nil { - t.Errorf("Excepted error on invalid database_type: %v", invalidPlugin.DatabaseType) - } + nextId int + ds datastore.Plugin +} + +func (s *PluginSuite) SetupSuite() { + var err error + s.dir, err = ioutil.TempDir("", "spire-datastore-sql-tests") + s.Require().NoError(err) +} + +func (s *PluginSuite) SetupTest() { + s.ds = s.newPlugin() +} + +func (s *PluginSuite) TearDownSuite() { + os.RemoveAll(s.dir) +} + +func (s *PluginSuite) newPlugin() datastore.Plugin { + p := New() + + s.nextId++ + dbPath := filepath.Join(s.dir, fmt.Sprintf("db%d.sqlite3", s.nextId)) + + _, err := p.Configure(context.Background(), &spi.ConfigureRequest{ + Configuration: fmt.Sprintf(` + database_type = "sqlite3" + log_sql = true + connection_string = "file://%s" + `, dbPath), + }) + s.Require().NoError(err) + + return p } -func TestBundle_CRUD(t *testing.T) { - ds := createDefault(t) +func (s *PluginSuite) TestInvalidPluginConfiguration() { + _, err := s.ds.Configure(context.Background(), &spi.ConfigureRequest{ + Configuration: ` + database_type = "wrong" + connection_string = "bad" + `, + }) + s.Require().EqualError(err, "datastore-sql: unsupported database_type: wrong") +} +func (s *PluginSuite) TestBundleCRUD() { cert, _, err := testutil.LoadSVIDFixture() - require.NoError(t, err) + s.Require().NoError(err) bundle := &datastore.Bundle{ - TrustDomain: "spiffe://foo/", + TrustDomain: "spiffe://foo", CaCerts: cert.Raw, } // create - _, err = ds.CreateBundle(ctx, bundle) - require.NoError(t, err) + _, err = s.ds.CreateBundle(ctx, bundle) + s.Require().NoError(err) // fetch - fresp, err := ds.FetchBundle(ctx, &datastore.Bundle{TrustDomain: "spiffe://foo/"}) - require.NoError(t, err) - assert.Equal(t, bundle, fresp) + fresp, err := s.ds.FetchBundle(ctx, &datastore.Bundle{TrustDomain: "spiffe://foo"}) + s.Require().NoError(err) + s.Equal(bundle, fresp) // list - lresp, err := ds.ListBundles(ctx, &common.Empty{}) - require.NoError(t, err) - assert.Equal(t, 1, len(lresp.Bundles)) - assert.Equal(t, bundle, lresp.Bundles[0]) + lresp, err := s.ds.ListBundles(ctx, &common.Empty{}) + s.Require().NoError(err) + s.Equal(1, len(lresp.Bundles)) + s.Equal(bundle, lresp.Bundles[0]) cert, _, err = testutil.LoadCAFixture() - require.NoError(t, err) + s.Require().NoError(err) bundle2 := &datastore.Bundle{ TrustDomain: bundle.TrustDomain, @@ -81,50 +111,48 @@ func TestBundle_CRUD(t *testing.T) { } // append - aresp, err := ds.AppendBundle(ctx, bundle2) - require.NoError(t, err) + aresp, err := s.ds.AppendBundle(ctx, bundle2) + s.Require().NoError(err) certs := append(bundle.CaCerts, cert.Raw...) - assert.Equal(t, certs, aresp.CaCerts) + s.Equal(certs, aresp.CaCerts) // append identical - aresp, err = ds.AppendBundle(ctx, bundle2) - require.NoError(t, err) - assert.Equal(t, certs, aresp.CaCerts) + aresp, err = s.ds.AppendBundle(ctx, bundle2) + s.Require().NoError(err) + s.Equal(certs, aresp.CaCerts) // append on a new bundle bundle3 := &datastore.Bundle{ - TrustDomain: "spiffe://bar/", + TrustDomain: "spiffe://bar", CaCerts: cert.Raw, } - anresp, err := ds.AppendBundle(ctx, bundle3) - require.NoError(t, err) - assert.Equal(t, bundle3, anresp) + anresp, err := s.ds.AppendBundle(ctx, bundle3) + s.Require().NoError(err) + s.Equal(bundle3, anresp) // update - uresp, err := ds.UpdateBundle(ctx, bundle2) - require.NoError(t, err) - assert.Equal(t, bundle2, uresp) + uresp, err := s.ds.UpdateBundle(ctx, bundle2) + s.Require().NoError(err) + s.Equal(bundle2, uresp) - lresp, err = ds.ListBundles(ctx, &common.Empty{}) - require.NoError(t, err) - assert.Equal(t, 2, len(lresp.Bundles)) - assert.Equal(t, []*datastore.Bundle{bundle2, bundle3}, lresp.Bundles) + lresp, err = s.ds.ListBundles(ctx, &common.Empty{}) + s.Require().NoError(err) + s.Equal(2, len(lresp.Bundles)) + s.Equal([]*datastore.Bundle{bundle2, bundle3}, lresp.Bundles) // delete - dresp, err := ds.DeleteBundle(ctx, &datastore.Bundle{ + dresp, err := s.ds.DeleteBundle(ctx, &datastore.Bundle{ TrustDomain: bundle.TrustDomain, }) - require.NoError(t, err) - assert.Equal(t, bundle2, dresp) + s.Require().NoError(err) + s.Equal(bundle2, dresp) - lresp, err = ds.ListBundles(ctx, &common.Empty{}) - require.NoError(t, err) - assert.Equal(t, 1, len(lresp.Bundles)) + lresp, err = s.ds.ListBundles(ctx, &common.Empty{}) + s.Require().NoError(err) + s.Equal(1, len(lresp.Bundles)) } -func Test_CreateAttestedNodeEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestCreateAttestedNodeEntry() { entry := &datastore.AttestedNodeEntry{ BaseSpiffeId: "foo", AttestationDataType: "aws-tag", @@ -132,29 +160,26 @@ func Test_CreateAttestedNodeEntry(t *testing.T) { CertExpirationDate: time.Now().Add(time.Hour).Format(datastore.TimeFormat), } - cresp, err := ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) - require.NoError(t, err) - assert.Equal(t, entry, cresp.AttestedNodeEntry) + cresp, err := s.ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) + s.Require().NoError(err) + s.Equal(entry, cresp.AttestedNodeEntry) - fresp, err := ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err) - assert.Equal(t, entry, fresp.AttestedNodeEntry) + fresp, err := s.ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err) + s.Equal(entry, fresp.AttestedNodeEntry) - sresp, err := ds.FetchStaleNodeEntries(ctx, &datastore.FetchStaleNodeEntriesRequest{}) - require.NoError(t, err) - assert.Empty(t, sresp.AttestedNodeEntryList) + sresp, err := s.ds.FetchStaleNodeEntries(ctx, &datastore.FetchStaleNodeEntriesRequest{}) + s.Require().NoError(err) + s.Empty(sresp.AttestedNodeEntryList) } -func Test_FetchAttestedNodeEntry_missing(t *testing.T) { - ds := createDefault(t) - fresp, err := ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: "missing"}) - require.NoError(t, err) - require.Nil(t, fresp.AttestedNodeEntry) +func (s *PluginSuite) TestFetchAttestedNodeEntryMissing() { + fresp, err := s.ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: "missing"}) + s.Require().NoError(err) + s.Require().Nil(fresp.AttestedNodeEntry) } -func Test_FetchStaleNodeEntries(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestFetchStaleNodeEntries() { efuture := &datastore.AttestedNodeEntry{ BaseSpiffeId: "foo", AttestationDataType: "aws-tag", @@ -169,20 +194,18 @@ func Test_FetchStaleNodeEntries(t *testing.T) { CertExpirationDate: time.Now().Add(-time.Hour).Format(datastore.TimeFormat), } - _, err := ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: efuture}) - require.NoError(t, err) + _, err := s.ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: efuture}) + s.Require().NoError(err) - _, err = ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: epast}) - require.NoError(t, err) + _, err = s.ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: epast}) + s.Require().NoError(err) - sresp, err := ds.FetchStaleNodeEntries(ctx, &datastore.FetchStaleNodeEntriesRequest{}) - require.NoError(t, err) - assert.Equal(t, []*datastore.AttestedNodeEntry{epast}, sresp.AttestedNodeEntryList) + sresp, err := s.ds.FetchStaleNodeEntries(ctx, &datastore.FetchStaleNodeEntriesRequest{}) + s.Require().NoError(err) + s.Equal([]*datastore.AttestedNodeEntry{epast}, sresp.AttestedNodeEntryList) } -func Test_UpdateAttestedNodeEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestUpdateAttestedNodeEntry() { entry := &datastore.AttestedNodeEntry{ BaseSpiffeId: "foo", AttestationDataType: "aws-tag", @@ -193,39 +216,37 @@ func Test_UpdateAttestedNodeEntry(t *testing.T) { userial := "deadbeef" uexpires := time.Now().Add(time.Hour * 2).Format(datastore.TimeFormat) - _, err := ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) - require.NoError(t, err) + _, err := s.ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) + s.Require().NoError(err) - uresp, err := ds.UpdateAttestedNodeEntry(ctx, &datastore.UpdateAttestedNodeEntryRequest{ + uresp, err := s.ds.UpdateAttestedNodeEntry(ctx, &datastore.UpdateAttestedNodeEntryRequest{ BaseSpiffeId: entry.BaseSpiffeId, CertSerialNumber: userial, CertExpirationDate: uexpires, }) - require.NoError(t, err) + s.Require().NoError(err) uentry := uresp.AttestedNodeEntry - require.NotNil(t, uentry) + s.Require().NotNil(uentry) - assert.Equal(t, entry.BaseSpiffeId, uentry.BaseSpiffeId) - assert.Equal(t, entry.AttestationDataType, uentry.AttestationDataType) - assert.Equal(t, userial, uentry.CertSerialNumber) - assert.Equal(t, uexpires, uentry.CertExpirationDate) + s.Equal(entry.BaseSpiffeId, uentry.BaseSpiffeId) + s.Equal(entry.AttestationDataType, uentry.AttestationDataType) + s.Equal(userial, uentry.CertSerialNumber) + s.Equal(uexpires, uentry.CertExpirationDate) - fresp, err := ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err) + fresp, err := s.ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err) fentry := fresp.AttestedNodeEntry - require.NotNil(t, fentry) + s.Require().NotNil(fentry) - assert.Equal(t, entry.BaseSpiffeId, fentry.BaseSpiffeId) - assert.Equal(t, entry.AttestationDataType, fentry.AttestationDataType) - assert.Equal(t, userial, fentry.CertSerialNumber) - assert.Equal(t, uexpires, fentry.CertExpirationDate) + s.Equal(entry.BaseSpiffeId, fentry.BaseSpiffeId) + s.Equal(entry.AttestationDataType, fentry.AttestationDataType) + s.Equal(userial, fentry.CertSerialNumber) + s.Equal(uexpires, fentry.CertExpirationDate) } -func Test_DeleteAttestedNodeEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestDeleteAttestedNodeEntry() { entry := &datastore.AttestedNodeEntry{ BaseSpiffeId: "foo", AttestationDataType: "aws-tag", @@ -233,21 +254,19 @@ func Test_DeleteAttestedNodeEntry(t *testing.T) { CertExpirationDate: time.Now().Add(time.Hour).Format(datastore.TimeFormat), } - _, err := ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) - require.NoError(t, err) + _, err := s.ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) + s.Require().NoError(err) - dresp, err := ds.DeleteAttestedNodeEntry(ctx, &datastore.DeleteAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err) - assert.Equal(t, entry, dresp.AttestedNodeEntry) + dresp, err := s.ds.DeleteAttestedNodeEntry(ctx, &datastore.DeleteAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err) + s.Equal(entry, dresp.AttestedNodeEntry) - fresp, err := ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err) - assert.Nil(t, fresp.AttestedNodeEntry) + fresp, err := s.ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err) + s.Nil(fresp.AttestedNodeEntry) } -func Test_CreateNodeResolverMapEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestCreateNodeResolverMapEntry() { entry := &datastore.NodeResolverMapEntry{ BaseSpiffeId: "main", Selector: &common.Selector{ @@ -256,26 +275,23 @@ func Test_CreateNodeResolverMapEntry(t *testing.T) { }, } - cresp, err := ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) - require.NoError(t, err) + cresp, err := s.ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) + s.Require().NoError(err) centry := cresp.NodeResolverMapEntry - assert.Equal(t, entry, centry) + s.Equal(entry, centry) } -func Test_CreateNodeResolverMapEntry_dupe(t *testing.T) { - ds := createDefault(t) - entries := createNodeResolverMapEntries(t, ds) +func (s *PluginSuite) TestCreateNodeResolverMapEntryDuplicate() { + entries := s.createNodeResolverMapEntries(s.ds) entry := entries[0] - cresp, err := ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) - assert.Error(t, err) - require.Nil(t, cresp) + cresp, err := s.ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) + s.Error(err) + s.Require().Nil(cresp) } -func Test_FetchNodeResolverMapEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestFetchNodeResolverMapEntry() { entry := &datastore.NodeResolverMapEntry{ BaseSpiffeId: "main", Selector: &common.Selector{ @@ -284,68 +300,61 @@ func Test_FetchNodeResolverMapEntry(t *testing.T) { }, } - cresp, err := ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) - require.NoError(t, err) + cresp, err := s.ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) + s.Require().NoError(err) centry := cresp.NodeResolverMapEntry - assert.Equal(t, entry, centry) + s.Equal(entry, centry) } -func Test_DeleteNodeResolverMapEntry_specific(t *testing.T) { +func (s *PluginSuite) TestDeleteNodeResolverMapEntry() { // remove entries for the specific (spiffe_id,type,value) - - ds := createDefault(t) - entries := createNodeResolverMapEntries(t, ds) + entries := s.createNodeResolverMapEntries(s.ds) entry_removed := entries[0] - dresp, err := ds.DeleteNodeResolverMapEntry(ctx, &datastore.DeleteNodeResolverMapEntryRequest{NodeResolverMapEntry: entry_removed}) - require.NoError(t, err) + dresp, err := s.ds.DeleteNodeResolverMapEntry(ctx, &datastore.DeleteNodeResolverMapEntryRequest{NodeResolverMapEntry: entry_removed}) + s.Require().NoError(err) - assert.Equal(t, entries[0:1], dresp.NodeResolverMapEntryList) + s.Equal(entries[0:1], dresp.NodeResolverMapEntryList) for idx, entry := range entries[1:] { - fresp, err := ds.FetchNodeResolverMapEntry(ctx, &datastore.FetchNodeResolverMapEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err, idx) - require.Len(t, fresp.NodeResolverMapEntryList, 1, "%v", idx) - assert.Equal(t, entry, fresp.NodeResolverMapEntryList[0], "%v", idx) + fresp, err := s.ds.FetchNodeResolverMapEntry(ctx, &datastore.FetchNodeResolverMapEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err, idx) + s.Require().Len(fresp.NodeResolverMapEntryList, 1, "%v", idx) + s.Equal(entry, fresp.NodeResolverMapEntryList[0], "%v", idx) } } -func Test_DeleteNodeResolverMapEntry_all(t *testing.T) { +func (s *PluginSuite) TestDeleteNodeResolverMapEntryAll() { // remove all entries for the spiffe_id - - ds := createDefault(t) - entries := createNodeResolverMapEntries(t, ds) + entries := s.createNodeResolverMapEntries(s.ds) entry_removed := &datastore.NodeResolverMapEntry{ BaseSpiffeId: entries[0].BaseSpiffeId, } - dresp, err := ds.DeleteNodeResolverMapEntry(ctx, &datastore.DeleteNodeResolverMapEntryRequest{NodeResolverMapEntry: entry_removed}) - require.NoError(t, err) + dresp, err := s.ds.DeleteNodeResolverMapEntry(ctx, &datastore.DeleteNodeResolverMapEntryRequest{NodeResolverMapEntry: entry_removed}) + s.Require().NoError(err) - assert.Equal(t, entries[0:2], dresp.NodeResolverMapEntryList) + s.Equal(entries[0:2], dresp.NodeResolverMapEntryList) { entry := entry_removed - fresp, err := ds.FetchNodeResolverMapEntry(ctx, &datastore.FetchNodeResolverMapEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err) - assert.Empty(t, fresp.NodeResolverMapEntryList) + fresp, err := s.ds.FetchNodeResolverMapEntry(ctx, &datastore.FetchNodeResolverMapEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err) + s.Empty(fresp.NodeResolverMapEntryList) } { entry := entries[2] - fresp, err := ds.FetchNodeResolverMapEntry(ctx, &datastore.FetchNodeResolverMapEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) - require.NoError(t, err) - assert.NotEmpty(t, fresp.NodeResolverMapEntryList) + fresp, err := s.ds.FetchNodeResolverMapEntry(ctx, &datastore.FetchNodeResolverMapEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + s.Require().NoError(err) + s.NotEmpty(fresp.NodeResolverMapEntryList) } } -func Test_RectifyNodeResolverMapEntries(t *testing.T) { -} - -func createNodeResolverMapEntries(t *testing.T, ds datastore.DataStore) []*datastore.NodeResolverMapEntry { +func (s *PluginSuite) createNodeResolverMapEntries(ds datastore.DataStore) []*datastore.NodeResolverMapEntry { entries := []*datastore.NodeResolverMapEntry{ { BaseSpiffeId: "main", @@ -372,46 +381,38 @@ func createNodeResolverMapEntries(t *testing.T, ds datastore.DataStore) []*datas for idx, entry := range entries { _, err := ds.CreateNodeResolverMapEntry(ctx, &datastore.CreateNodeResolverMapEntryRequest{NodeResolverMapEntry: entry}) - require.NoError(t, err, "%v", idx) + s.Require().NoError(err, "%v", idx) } return entries } -func Test_CreateRegistrationEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestCreateRegistrationEntry() { var validRegistrationEntries []*common.RegistrationEntry - err := getTestDataFromJsonFile(t, filepath.Join("testdata", "valid_registration_entries.json"), &validRegistrationEntries) - require.NoError(t, err) + s.getTestDataFromJsonFile(filepath.Join("testdata", "valid_registration_entries.json"), &validRegistrationEntries) for _, validRegistrationEntry := range validRegistrationEntries { - createRegistrationEntryResponse, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: validRegistrationEntry}) - require.NoError(t, err) - assert.NotNil(t, createRegistrationEntryResponse) - assert.NotEmpty(t, createRegistrationEntryResponse.RegisteredEntryId) + createRegistrationEntryResponse, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: validRegistrationEntry}) + s.Require().NoError(err) + s.NotNil(createRegistrationEntryResponse) + s.NotEmpty(createRegistrationEntryResponse.RegisteredEntryId) } } -func Test_CreateInvalidRegistrationEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestCreateInvalidRegistrationEntry() { var invalidRegistrationEntries []*common.RegistrationEntry - err := getTestDataFromJsonFile(t, filepath.Join("testdata", "invalid_registration_entries.json"), &invalidRegistrationEntries) - require.NoError(t, err) + s.getTestDataFromJsonFile(filepath.Join("testdata", "invalid_registration_entries.json"), &invalidRegistrationEntries) for _, invalidRegisteredEntry := range invalidRegistrationEntries { - createRegistrationEntryResponse, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: invalidRegisteredEntry}) - require.Error(t, err) - require.Nil(t, createRegistrationEntryResponse) + createRegistrationEntryResponse, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: invalidRegisteredEntry}) + s.Require().Error(err) + s.Require().Nil(createRegistrationEntryResponse) } // TODO: Check that no entries have been created } -func Test_FetchRegistrationEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestFetchRegistrationEntry() { registeredEntry := &common.RegistrationEntry{ Selectors: []*common.Selector{ {Type: "Type1", Value: "Value1"}, @@ -423,28 +424,24 @@ func Test_FetchRegistrationEntry(t *testing.T) { Ttl: 1, } - createRegistrationEntryResponse, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: registeredEntry}) - require.NoError(t, err) - require.NotNil(t, createRegistrationEntryResponse) + createRegistrationEntryResponse, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: registeredEntry}) + s.Require().NoError(err) + s.Require().NotNil(createRegistrationEntryResponse) registeredEntry.EntryId = createRegistrationEntryResponse.RegisteredEntryId - fetchRegistrationEntryResponse, err := ds.FetchRegistrationEntry(ctx, &datastore.FetchRegistrationEntryRequest{RegisteredEntryId: createRegistrationEntryResponse.RegisteredEntryId}) - require.NoError(t, err) - require.NotNil(t, fetchRegistrationEntryResponse) - assert.Equal(t, registeredEntry, fetchRegistrationEntryResponse.RegisteredEntry) + fetchRegistrationEntryResponse, err := s.ds.FetchRegistrationEntry(ctx, &datastore.FetchRegistrationEntryRequest{RegisteredEntryId: createRegistrationEntryResponse.RegisteredEntryId}) + s.Require().NoError(err) + s.Require().NotNil(fetchRegistrationEntryResponse) + s.Equal(registeredEntry, fetchRegistrationEntryResponse.RegisteredEntry) } -func Test_FetchInexistentRegistrationEntry(t *testing.T) { - ds := createDefault(t) - - fetchRegistrationEntryResponse, err := ds.FetchRegistrationEntry(ctx, &datastore.FetchRegistrationEntryRequest{RegisteredEntryId: "INEXISTENT"}) - require.NoError(t, err) - require.Nil(t, fetchRegistrationEntryResponse.RegisteredEntry) +func (s *PluginSuite) TestFetchInexistentRegistrationEntry() { + fetchRegistrationEntryResponse, err := s.ds.FetchRegistrationEntry(ctx, &datastore.FetchRegistrationEntryRequest{RegisteredEntryId: "INEXISTENT"}) + s.Require().NoError(err) + s.Require().Nil(fetchRegistrationEntryResponse.RegisteredEntry) } -func Test_FetchRegistrationEntries(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestFetchRegistrationEntries() { entry1 := &common.RegistrationEntry{ Selectors: []*common.Selector{ {Type: "Type1", Value: "Value1"}, @@ -467,31 +464,29 @@ func Test_FetchRegistrationEntries(t *testing.T) { Ttl: 2, } - createRegistrationEntryResponse, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry1}) - require.NoError(t, err) - require.NotNil(t, createRegistrationEntryResponse) + createRegistrationEntryResponse, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry1}) + s.Require().NoError(err) + s.Require().NotNil(createRegistrationEntryResponse) entry1.EntryId = createRegistrationEntryResponse.RegisteredEntryId - createRegistrationEntryResponse, err = ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry2}) - require.NoError(t, err) - require.NotNil(t, createRegistrationEntryResponse) + createRegistrationEntryResponse, err = s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry2}) + s.Require().NoError(err) + s.Require().NotNil(createRegistrationEntryResponse) entry2.EntryId = createRegistrationEntryResponse.RegisteredEntryId - fetchRegistrationEntriesResponse, err := ds.FetchRegistrationEntries(ctx, &common.Empty{}) - require.NoError(t, err) - require.NotNil(t, fetchRegistrationEntriesResponse) + fetchRegistrationEntriesResponse, err := s.ds.FetchRegistrationEntries(ctx, &common.Empty{}) + s.Require().NoError(err) + s.Require().NotNil(fetchRegistrationEntriesResponse) expectedResponse := &datastore.FetchRegistrationEntriesResponse{ RegisteredEntries: &common.RegistrationEntries{ Entries: []*common.RegistrationEntry{entry2, entry1}, }, } - assert.Equal(t, expectedResponse, fetchRegistrationEntriesResponse) + s.Equal(expectedResponse, fetchRegistrationEntriesResponse) } -func Test_UpdateRegistrationEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestUpdateRegistrationEntry() { entry1 := &common.RegistrationEntry{ Selectors: []*common.Selector{ {Type: "Type1", Value: "Value1"}, @@ -503,9 +498,9 @@ func Test_UpdateRegistrationEntry(t *testing.T) { Ttl: 1, } - createRegistrationEntryResponse, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry1}) - require.NoError(t, err) - require.NotNil(t, createRegistrationEntryResponse) + createRegistrationEntryResponse, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry1}) + s.Require().NoError(err) + s.Require().NotNil(createRegistrationEntryResponse) // TODO: Refactor message type to take EntryID directly from the entry - see #449 entry1.Ttl = 2 @@ -513,21 +508,19 @@ func Test_UpdateRegistrationEntry(t *testing.T) { RegisteredEntryId: createRegistrationEntryResponse.RegisteredEntryId, RegisteredEntry: entry1, } - updateRegistrationEntryResponse, err := ds.UpdateRegistrationEntry(ctx, updReq) - require.NoError(t, err) - require.NotNil(t, updateRegistrationEntryResponse) + updateRegistrationEntryResponse, err := s.ds.UpdateRegistrationEntry(ctx, updReq) + s.Require().NoError(err) + s.Require().NotNil(updateRegistrationEntryResponse) - fetchRegistrationEntryResponse, err := ds.FetchRegistrationEntry(ctx, &datastore.FetchRegistrationEntryRequest{RegisteredEntryId: updReq.RegisteredEntryId}) - require.NoError(t, err) - require.NotNil(t, fetchRegistrationEntryResponse) + fetchRegistrationEntryResponse, err := s.ds.FetchRegistrationEntry(ctx, &datastore.FetchRegistrationEntryRequest{RegisteredEntryId: updReq.RegisteredEntryId}) + s.Require().NoError(err) + s.Require().NotNil(fetchRegistrationEntryResponse) expectedResponse := &datastore.FetchRegistrationEntryResponse{RegisteredEntry: entry1} - assert.Equal(t, expectedResponse, fetchRegistrationEntryResponse) + s.Equal(expectedResponse, fetchRegistrationEntryResponse) } -func Test_DeleteRegistrationEntry(t *testing.T) { - ds := createDefault(t) - +func (s *PluginSuite) TestDeleteRegistrationEntry() { entry1 := &common.RegistrationEntry{ Selectors: []*common.Selector{ {Type: "Type1", Value: "Value1"}, @@ -550,23 +543,23 @@ func Test_DeleteRegistrationEntry(t *testing.T) { Ttl: 2, } - res1, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry1}) - require.NoError(t, err) - require.NotNil(t, res1) + res1, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry1}) + s.Require().NoError(err) + s.Require().NotNil(res1) entry1.EntryId = res1.RegisteredEntryId - res2, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry2}) - require.NoError(t, err) - require.NotNil(t, res2) + res2, err := s.ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry2}) + s.Require().NoError(err) + s.Require().NotNil(res2) entry2.EntryId = res2.RegisteredEntryId // Make sure we deleted the right one - delRes, err := ds.DeleteRegistrationEntry(ctx, &datastore.DeleteRegistrationEntryRequest{RegisteredEntryId: res1.RegisteredEntryId}) - require.NoError(t, err) - require.Equal(t, entry1, delRes.RegisteredEntry) + delRes, err := s.ds.DeleteRegistrationEntry(ctx, &datastore.DeleteRegistrationEntryRequest{RegisteredEntryId: res1.RegisteredEntryId}) + s.Require().NoError(err) + s.Require().Equal(entry1, delRes.RegisteredEntry) } -func TestgormPlugin_ListParentIDEntries(t *testing.T) { +func (s *PluginSuite) TestListParentIDEntries() { allEntries := testutil.GetRegistrationEntries("entries.json") tests := []struct { name string @@ -589,21 +582,21 @@ func TestgormPlugin_ListParentIDEntries(t *testing.T) { }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ds := createDefault(t) + s.T().Run(test.name, func(t *testing.T) { + ds := s.newPlugin() for _, entry := range test.registrationEntries { r, _ := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry}) entry.EntryId = r.RegisteredEntryId } result, err := ds.ListParentIDEntries(ctx, &datastore.ListParentIDEntriesRequest{ ParentId: test.parentID}) - require.NoError(t, err) - assert.Equal(t, test.expectedList, result.RegisteredEntryList) + s.Require().NoError(err) + s.Equal(test.expectedList, result.RegisteredEntryList) }) } } -func Test_ListSelectorEntries(t *testing.T) { +func (s *PluginSuite) TestListSelectorEntries() { allEntries := testutil.GetRegistrationEntries("entries.json") tests := []struct { name string @@ -619,7 +612,7 @@ func Test_ListSelectorEntries(t *testing.T) { {Type: "b", Value: "2"}, {Type: "c", Value: "3"}, }, - expectedList: regEntries{allEntries[0]}, + expectedList: []*common.RegistrationEntry{allEntries[0]}, }, { name: "entries_by_selector_not_found", @@ -631,22 +624,22 @@ func Test_ListSelectorEntries(t *testing.T) { }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ds := createDefault(t) + s.T().Run(test.name, func(t *testing.T) { + ds := s.newPlugin() for _, entry := range test.registrationEntries { r, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry}) - require.NoError(t, err) + s.Require().NoError(err) entry.EntryId = r.RegisteredEntryId } result, err := ds.ListSelectorEntries(ctx, &datastore.ListSelectorEntriesRequest{ Selectors: test.selectors}) - require.NoError(t, err) - assert.Equal(t, test.expectedList, result.RegisteredEntryList) + s.Require().NoError(err) + s.Equal(test.expectedList, result.RegisteredEntryList) }) } } -func Test_ListMatchingEntries(t *testing.T) { +func (s *PluginSuite) TestListMatchingEntries() { allEntries := testutil.GetRegistrationEntries("entries.json") tests := []struct { name string @@ -678,152 +671,131 @@ func Test_ListMatchingEntries(t *testing.T) { }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ds := createDefault(t) + s.T().Run(test.name, func(t *testing.T) { + ds := s.newPlugin() for _, entry := range test.registrationEntries { r, err := ds.CreateRegistrationEntry(ctx, &datastore.CreateRegistrationEntryRequest{RegisteredEntry: entry}) - require.NoError(t, err) + s.Require().NoError(err) entry.EntryId = r.RegisteredEntryId } result, err := ds.ListMatchingEntries(ctx, &datastore.ListSelectorEntriesRequest{ Selectors: test.selectors}) - require.NoError(t, err) - assert.Equal(t, test.expectedList, result.RegisteredEntryList) + s.Require().NoError(err) + s.Equal(test.expectedList, result.RegisteredEntryList) }) } } -func Test_ListSpiffeEntriesEntry(t *testing.T) { - t.Skipf("TODO") -} - -func Test_RegisterToken(t *testing.T) { - ds := createDefault(t) +func (s *PluginSuite) TestRegisterToken() { now := time.Now().Unix() req := &datastore.JoinToken{ Token: "foobar", Expiry: now, } - _, err := ds.RegisterToken(ctx, req) - require.NoError(t, err) + _, err := s.ds.RegisterToken(ctx, req) + s.Require().NoError(err) // Make sure we can't re-register - _, err = ds.RegisterToken(ctx, req) - assert.NotNil(t, err) + _, err = s.ds.RegisterToken(ctx, req) + s.NotNil(err) } -func Test_RegisterAndFetchToken(t *testing.T) { - ds := createDefault(t) +func (s *PluginSuite) TestRegisterAndFetchToken() { now := time.Now().Unix() req := &datastore.JoinToken{ Token: "foobar", Expiry: now, } - _, err := ds.RegisterToken(ctx, req) - require.NoError(t, err) + _, err := s.ds.RegisterToken(ctx, req) + s.Require().NoError(err) // Don't need expiry for fetch req.Expiry = 0 - res, err := ds.FetchToken(ctx, req) - require.NoError(t, err) - assert.Equal(t, "foobar", res.Token) - assert.Equal(t, now, res.Expiry) + res, err := s.ds.FetchToken(ctx, req) + s.Require().NoError(err) + s.Equal("foobar", res.Token) + s.Equal(now, res.Expiry) } -func Test_DeleteToken(t *testing.T) { - ds := createDefault(t) +func (s *PluginSuite) TestDeleteToken() { now := time.Now().Unix() req1 := &datastore.JoinToken{ Token: "foobar", Expiry: now, } - _, err := ds.RegisterToken(ctx, req1) - require.NoError(t, err) + _, err := s.ds.RegisterToken(ctx, req1) + s.Require().NoError(err) req2 := &datastore.JoinToken{ Token: "batbaz", Expiry: now, } - _, err = ds.RegisterToken(ctx, req2) - require.NoError(t, err) + _, err = s.ds.RegisterToken(ctx, req2) + s.Require().NoError(err) // Don't need expiry for delete req1.Expiry = 0 - _, err = ds.DeleteToken(ctx, req1) - require.NoError(t, err) + _, err = s.ds.DeleteToken(ctx, req1) + s.Require().NoError(err) // Should not be able to fetch after delete - resp, err := ds.FetchToken(ctx, req1) - require.NoError(t, err) - assert.Equal(t, "", resp.Token) + resp, err := s.ds.FetchToken(ctx, req1) + s.Require().NoError(err) + s.Equal("", resp.Token) // Second token should still be present - resp, err = ds.FetchToken(ctx, req2) - require.NoError(t, err) - assert.Equal(t, req2.Token, resp.Token) + resp, err = s.ds.FetchToken(ctx, req2) + s.Require().NoError(err) + s.Equal(req2.Token, resp.Token) } -func Test_PruneTokens(t *testing.T) { - ds := createDefault(t) +func (s *PluginSuite) TestPruneTokens() { now := time.Now().Unix() req := &datastore.JoinToken{ Token: "foobar", Expiry: now, } - _, err := ds.RegisterToken(ctx, req) - require.NoError(t, err) + _, err := s.ds.RegisterToken(ctx, req) + s.Require().NoError(err) // Ensure we don't prune valid tokens, wind clock back 10s req.Expiry = (now - 10) - _, err = ds.PruneTokens(ctx, req) - require.NoError(t, err) - resp, err := ds.FetchToken(ctx, req) - require.NoError(t, err) - assert.Equal(t, "foobar", resp.Token) + _, err = s.ds.PruneTokens(ctx, req) + s.Require().NoError(err) + resp, err := s.ds.FetchToken(ctx, req) + s.Require().NoError(err) + s.Equal("foobar", resp.Token) // Ensure we prune old tokens req.Expiry = (now + 10) - _, err = ds.PruneTokens(ctx, req) - require.NoError(t, err) - resp, err = ds.FetchToken(ctx, req) - require.NoError(t, err) - assert.Equal(t, "", resp.Token) -} - -func Test_Configure(t *testing.T) { - t.Skipf("TODO") + _, err = s.ds.PruneTokens(ctx, req) + s.Require().NoError(err) + resp, err = s.ds.FetchToken(ctx, req) + s.Require().NoError(err) + s.Equal("", resp.Token) } -func Test_GetPluginInfo(t *testing.T) { - ds := createDefault(t) - resp, err := ds.GetPluginInfo(ctx, &spi.GetPluginInfoRequest{}) - require.NoError(t, err) - require.NotNil(t, resp) +func (s *PluginSuite) TestGetPluginInfo() { + resp, err := s.ds.GetPluginInfo(ctx, &spi.GetPluginInfoRequest{}) + s.Require().NoError(err) + s.Require().NotNil(resp) } -func Test_Migration(t *testing.T) { - require := require.New(t) - - tmpDir, err := ioutil.TempDir("", "spire-sql-datastore") - require.NoError(err) - defer os.RemoveAll(tmpDir) - - ds := New() - +func (s *PluginSuite) TestMigration() { for i := 0; i < codeVersion; i++ { dbName := fmt.Sprintf("v%d.sqlite3", i) - dbPath := filepath.Join(tmpDir, dbName) + dbPath := filepath.Join(s.dir, "migration-"+dbName) // copy the database file from the test data - require.NoError(copyFile(filepath.Join("testdata", "migration", dbName), dbPath)) + s.Require().NoError(copyFile(filepath.Join("testdata", "migration", dbName), dbPath)) // configure the datastore to use the new database - _, err = ds.Configure(context.Background(), &spi.ConfigureRequest{ + _, err := s.ds.Configure(context.Background(), &spi.ConfigureRequest{ Configuration: fmt.Sprintf(` database_type = "sqlite3" connection_string = "file://%s" `, dbPath), }) - require.NoError(err) + s.Require().NoError(err) switch i { case 0: @@ -832,55 +804,41 @@ func Test_Migration(t *testing.T) { // exist. if we try and create a bundle with the same id, it should // fail if the migration did not run, due to uniqueness // constraints. - _, err := ds.CreateBundle(context.Background(), &datastore.Bundle{ + _, err := s.ds.CreateBundle(context.Background(), &datastore.Bundle{ TrustDomain: "spiffe://otherdomain.org", }) - require.NoError(err) + s.Require().NoError(err) default: - t.Fatalf("no migration test added for version %d", i) + s.T().Fatalf("no migration test added for version %d", i) } } } -func Test_race(t *testing.T) { - ds := createDefault(t) +func (s *PluginSuite) TestRace() { + next := int64(0) + exp := time.Now().Add(time.Hour).Format(datastore.TimeFormat) - entry := &datastore.AttestedNodeEntry{ - BaseSpiffeId: "foo", - AttestationDataType: "aws-tag", - CertSerialNumber: "badcafe", - CertExpirationDate: time.Now().Add(time.Hour).Format(datastore.TimeFormat), - } + testutil.RaceTest(s.T(), func(t *testing.T) { + entry := &datastore.AttestedNodeEntry{ + BaseSpiffeId: fmt.Sprintf("foo%d", atomic.AddInt64(&next, 1)), + AttestationDataType: "aws-tag", + CertSerialNumber: "badcafe", + CertExpirationDate: exp, + } - testutil.RaceTest(t, func(t *testing.T) { - ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) - ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + _, err := s.ds.CreateAttestedNodeEntry(ctx, &datastore.CreateAttestedNodeEntryRequest{AttestedNodeEntry: entry}) + require.NoError(t, err) + _, err = s.ds.FetchAttestedNodeEntry(ctx, &datastore.FetchAttestedNodeEntryRequest{BaseSpiffeId: entry.BaseSpiffeId}) + require.NoError(t, err) }) } -func createDefault(t *testing.T) datastore.Plugin { - p := newPlugin() - p.DatabaseType = "sqlite3" - p.ConnectionString = fmt.Sprintf("file:memdb%d?mode=memory&cache=shared", atomic.AddUint64(&nextInMemoryId, 1)) - - require.NoError(t, p.restart()) - - p.db.LogMode(true) - return p -} - -func getTestDataFromJsonFile(t *testing.T, filePath string, jsonValue interface{}) error { +func (s *PluginSuite) getTestDataFromJsonFile(filePath string, jsonValue interface{}) { invalidRegistrationEntriesJson, err := ioutil.ReadFile(filePath) - if err != nil { - return err - } + s.Require().NoError(err) err = json.Unmarshal(invalidRegistrationEntriesJson, &jsonValue) - if err != nil { - return err - } - - return nil + s.Require().NoError(err) } func copyFile(src, dst string) error { diff --git a/pkg/server/plugin/datastore/sql/sqlite.go b/pkg/server/plugin/datastore/sql/sqlite.go index 11f6d75f9c..577587be40 100644 --- a/pkg/server/plugin/datastore/sql/sqlite.go +++ b/pkg/server/plugin/datastore/sql/sqlite.go @@ -10,9 +10,16 @@ type sqlite struct{} func (s sqlite) connect(connectionString string) (*gorm.DB, error) { db, err := gorm.Open("sqlite3", connectionString) if err != nil { - return nil, err + return nil, sqlError.Wrap(err) + } + if err := db.Exec("PRAGMA journal_mode = WAL").Error; err != nil { + db.Close() + return nil, sqlError.Wrap(err) + } + if err := db.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + db.Close() + return nil, sqlError.Wrap(err) } - db.Exec("PRAGMA foreign_keys = ON") return db, nil }