From f0905e9dd0706398fd0201df4c32ae55838db25e Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Fri, 27 Oct 2023 15:11:56 -0400 Subject: [PATCH] Add "ModifyQuery" to CRUD interface for arbitrary SQL processing While ReadQueryModifier allows CRUD models to have static additions (for joins or complex filtering), this will allow for more dynamic additions to any query. Signed-off-by: Andrew Richardson --- mocks/crudmocks/crud.go | 24 ++++++++++++++++++---- pkg/dbsql/crud.go | 42 ++++++++++++++++++++++++++++---------- pkg/dbsql/crud_test.go | 10 +++++++++ pkg/dbsql/database.go | 11 +++++++--- pkg/dbsql/database_test.go | 21 ++++++++++++++----- pkg/dbsql/filter_sql.go | 4 +++- 6 files changed, 88 insertions(+), 24 deletions(-) diff --git a/mocks/crudmocks/crud.go b/mocks/crudmocks/crud.go index 8cc9772..2952f86 100644 --- a/mocks/crudmocks/crud.go +++ b/mocks/crudmocks/crud.go @@ -309,6 +309,22 @@ func (_m *CRUD[T]) InsertMany(ctx context.Context, instances []T, allowPartialSu return r0 } +// ModifyQuery provides a mock function with given fields: modifier +func (_m *CRUD[T]) ModifyQuery(modifier func(squirrel.SelectBuilder) squirrel.SelectBuilder) dbsql.CRUDQuery[T] { + ret := _m.Called(modifier) + + var r0 dbsql.CRUDQuery[T] + if rf, ok := ret.Get(0).(func(func(squirrel.SelectBuilder) squirrel.SelectBuilder) dbsql.CRUDQuery[T]); ok { + r0 = rf(modifier) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(dbsql.CRUDQuery[T]) + } + } + + return r0 +} + // Replace provides a mock function with given fields: ctx, inst, hooks func (_m *CRUD[T]) Replace(ctx context.Context, inst T, hooks ...dbsql.PostCompletionHook) error { _va := make([]interface{}, len(hooks)) @@ -331,15 +347,15 @@ func (_m *CRUD[T]) Replace(ctx context.Context, inst T, hooks ...dbsql.PostCompl } // Scoped provides a mock function with given fields: scope -func (_m *CRUD[T]) Scoped(scope squirrel.Eq) *dbsql.CrudBase[T] { +func (_m *CRUD[T]) Scoped(scope squirrel.Eq) dbsql.CRUD[T] { ret := _m.Called(scope) - var r0 *dbsql.CrudBase[T] - if rf, ok := ret.Get(0).(func(squirrel.Eq) *dbsql.CrudBase[T]); ok { + var r0 dbsql.CRUD[T] + if rf, ok := ret.Get(0).(func(squirrel.Eq) dbsql.CRUD[T]); ok { r0 = rf(scope) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*dbsql.CrudBase[T]) + r0 = ret.Get(0).(dbsql.CRUD[T]) } } diff --git a/pkg/dbsql/crud.go b/pkg/dbsql/crud.go index d212e7a..31d1214 100644 --- a/pkg/dbsql/crud.go +++ b/pkg/dbsql/crud.go @@ -91,12 +91,7 @@ func (r *ResourceBase) SetUpdated(t *fftypes.FFTime) { r.Updated = t } -type CRUD[T Resource] interface { - Validate() - Upsert(ctx context.Context, inst T, optimization UpsertOptimization, hooks ...PostCompletionHook) (created bool, err error) - InsertMany(ctx context.Context, instances []T, allowPartialSuccess bool, hooks ...PostCompletionHook) (err error) - Insert(ctx context.Context, inst T, hooks ...PostCompletionHook) (err error) - Replace(ctx context.Context, inst T, hooks ...PostCompletionHook) (err error) +type CRUDQuery[T Resource] interface { GetByID(ctx context.Context, id string, getOpts ...GetOption) (inst T, err error) GetByUUIDOrName(ctx context.Context, uuidOrName string, getOpts ...GetOption) (result T, err error) GetByName(ctx context.Context, name string, getOpts ...GetOption) (instance T, err error) @@ -104,12 +99,22 @@ type CRUD[T Resource] interface { GetSequenceForID(ctx context.Context, id string) (seq int64, err error) GetMany(ctx context.Context, filter ffapi.Filter) (instances []T, fr *ffapi.FilterResult, err error) Count(ctx context.Context, filter ffapi.Filter) (count int64, err error) + ModifyQuery(modifier QueryModifier) CRUDQuery[T] +} + +type CRUD[T Resource] interface { + CRUDQuery[T] + Validate() + Upsert(ctx context.Context, inst T, optimization UpsertOptimization, hooks ...PostCompletionHook) (created bool, err error) + InsertMany(ctx context.Context, instances []T, allowPartialSuccess bool, hooks ...PostCompletionHook) (err error) + Insert(ctx context.Context, inst T, hooks ...PostCompletionHook) (err error) + Replace(ctx context.Context, inst T, hooks ...PostCompletionHook) (err error) Update(ctx context.Context, id string, update ffapi.Update, hooks ...PostCompletionHook) (err error) UpdateSparse(ctx context.Context, sparseUpdate T, hooks ...PostCompletionHook) (err error) UpdateMany(ctx context.Context, filter ffapi.Filter, update ffapi.Update, hooks ...PostCompletionHook) (err error) Delete(ctx context.Context, id string, hooks ...PostCompletionHook) (err error) DeleteMany(ctx context.Context, filter ffapi.Filter, hooks ...PostCompletionHook) (err error) // no events - Scoped(scope sq.Eq) *CrudBase[T] // allows dynamic scoping to a collection + Scoped(scope sq.Eq) CRUD[T] // allows dynamic scoping to a collection } type CrudBase[T Resource] struct { @@ -133,15 +138,30 @@ type CrudBase[T Resource] struct { // Optional extensions ReadTableAlias string ReadOnlyColumns []string - ReadQueryModifier func(sq.SelectBuilder) sq.SelectBuilder + ReadQueryModifier QueryModifier } -func (c *CrudBase[T]) Scoped(scope sq.Eq) *CrudBase[T] { +func (c *CrudBase[T]) Scoped(scope sq.Eq) CRUD[T] { cScoped := *c cScoped.ScopedFilter = func() sq.Eq { return scope } return &cScoped } +func (c *CrudBase[T]) ModifyQuery(newModifier QueryModifier) CRUDQuery[T] { + cModified := *c + originalModifier := cModified.ReadQueryModifier + cModified.ReadQueryModifier = func(sb sq.SelectBuilder) sq.SelectBuilder { + if originalModifier != nil { + sb = originalModifier(sb) + } + if newModifier != nil { + sb = newModifier(sb) + } + return sb + } + return &cModified +} + func UUIDValidator(ctx context.Context, idStr string) error { _, err := fftypes.ParseUUID(ctx, idStr) return err @@ -656,7 +676,7 @@ func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *f } instances = append(instances, inst) } - return instances, c.DB.QueryRes(ctx, c.Table, tx, fop, fi), err + return instances, c.DB.QueryRes(ctx, c.Table, tx, fop, c.ReadQueryModifier, fi), err } func (c *CrudBase[T]) Count(ctx context.Context, filter ffapi.Filter) (count int64, err error) { @@ -675,7 +695,7 @@ func (c *CrudBase[T]) Count(ctx context.Context, filter ffapi.Filter) (count int fop, } } - return c.DB.CountQuery(ctx, c.Table, nil, fop, "*") + return c.DB.CountQuery(ctx, c.Table, nil, fop, c.ReadQueryModifier, "*") } func (c *CrudBase[T]) Update(ctx context.Context, id string, update ffapi.Update, hooks ...PostCompletionHook) (err error) { diff --git a/pkg/dbsql/crud_test.go b/pkg/dbsql/crud_test.go index 56512e6..0dc0d8c 100644 --- a/pkg/dbsql/crud_test.go +++ b/pkg/dbsql/crud_test.go @@ -381,6 +381,16 @@ func TestCRUDWithDBEnd2End(t *testing.T) { assert.NoError(t, err) checkEqualExceptTimes(t, *c1, *c1copy) + // Check we get it back with custom modifiers + collection.ReadQueryModifier = func(sb sq.SelectBuilder) sq.SelectBuilder { + return sb.Where(sq.Eq{"ns": "ns1"}) + } + c1copy, err = iCrud.ModifyQuery(func(sb sq.SelectBuilder) sq.SelectBuilder { + return sb.Where(sq.Eq{"field1": "hello1"}) + }).GetByName(ctx, *c1.Name) + assert.NoError(t, err) + checkEqualExceptTimes(t, *c1, *c1copy) + // Upsert the existing row optimized c1copy.Field1 = strPtr("hello again - 1") created, err := iCrud.Upsert(ctx, c1copy, UpsertOptimizationExisting) diff --git a/pkg/dbsql/database.go b/pkg/dbsql/database.go index a5a96fa..926d744 100644 --- a/pkg/dbsql/database.go +++ b/pkg/dbsql/database.go @@ -42,6 +42,8 @@ type Database struct { sequenceColumn string } +type QueryModifier = func(sq.SelectBuilder) sq.SelectBuilder + // PreCommitAccumulator is a structure that can accumulate state during // the transaction, then has a function that is called just before commit. type PreCommitAccumulator interface { @@ -225,7 +227,7 @@ func (s *Database) Query(ctx context.Context, table string, q sq.SelectBuilder) return s.QueryTx(ctx, table, nil, q) } -func (s *Database) CountQuery(ctx context.Context, table string, tx *TXWrapper, fop sq.Sqlizer, countExpr string) (count int64, err error) { +func (s *Database) CountQuery(ctx context.Context, table string, tx *TXWrapper, fop sq.Sqlizer, qm QueryModifier, countExpr string) (count int64, err error) { count = -1 l := log.L(ctx) if tx == nil { @@ -237,6 +239,9 @@ func (s *Database) CountQuery(ctx context.Context, table string, tx *TXWrapper, countExpr = "*" } q := sq.Select(fmt.Sprintf("COUNT(%s)", countExpr)).From(table).Where(fop) + if qm != nil { + q = qm(q) + } sqlQuery, args, err := q.PlaceholderFormat(s.features.PlaceholderFormat).ToSql() if err != nil { return count, i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed) @@ -263,10 +268,10 @@ func (s *Database) CountQuery(ctx context.Context, table string, tx *TXWrapper, return count, nil } -func (s *Database) QueryRes(ctx context.Context, table string, tx *TXWrapper, fop sq.Sqlizer, fi *ffapi.FilterInfo) *ffapi.FilterResult { +func (s *Database) QueryRes(ctx context.Context, table string, tx *TXWrapper, fop sq.Sqlizer, qm QueryModifier, fi *ffapi.FilterInfo) *ffapi.FilterResult { fr := &ffapi.FilterResult{} if fi.Count { - count, err := s.CountQuery(ctx, table, tx, fop, fi.CountExpr) + count, err := s.CountQuery(ctx, table, tx, fop, qm, fi.CountExpr) if err != nil { // Log, but continue log.L(ctx).Warnf("Unable to return count for query: %s", err) diff --git a/pkg/dbsql/database_test.go b/pkg/dbsql/database_test.go index 39a21e0..34711a3 100644 --- a/pkg/dbsql/database_test.go +++ b/pkg/dbsql/database_test.go @@ -499,14 +499,14 @@ func TestRollbackFail(t *testing.T) { func TestCountQueryBadSQL(t *testing.T) { s, _ := NewMockProvider().UTInit() - _, err := s.CountQuery(context.Background(), "table1", nil, sq.Insert("wrong"), "") + _, err := s.CountQuery(context.Background(), "table1", nil, sq.Insert("wrong"), nil, "") assert.Regexp(t, "FF00174", err) } func TestCountQueryQueryFailed(t *testing.T) { s, mdb := NewMockProvider().UTInit() mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnError(fmt.Errorf("pop")) - _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, "") + _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, nil, "") assert.Regexp(t, "FF00176.*pop", err) } @@ -516,21 +516,32 @@ func TestCountQueryScanFailTx(t *testing.T) { mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow("not a number")) ctx, tx, _, err := s.BeginOrUseTx(context.Background()) assert.NoError(t, err) - _, err = s.CountQuery(ctx, "table1", tx, sq.Eq{"col1": "val1"}, "") + _, err = s.CountQuery(ctx, "table1", tx, sq.Eq{"col1": "val1"}, nil, "") assert.Regexp(t, "FF00182", err) } func TestCountQueryWithExpr(t *testing.T) { s, mdb := NewMockProvider().UTInit() mdb.ExpectQuery("^SELECT COUNT\\(DISTINCT key\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10)) - _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, "DISTINCT key") + _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, nil, "DISTINCT key") + assert.NoError(t, err) + assert.NoError(t, mdb.ExpectationsWereMet()) +} + +func TestCountQueryWithModifier(t *testing.T) { + s, mdb := NewMockProvider().UTInit() + mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10)) + qm := func(sb sq.SelectBuilder) sq.SelectBuilder { + return sb.Where(sq.Eq{"col1": "val1"}) + } + _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, qm, "") assert.NoError(t, err) assert.NoError(t, mdb.ExpectationsWereMet()) } func TestQueryResSwallowError(t *testing.T) { s, _ := NewMockProvider().UTInit() - res := s.QueryRes(context.Background(), "table1", nil, sq.Insert("wrong"), &ffapi.FilterInfo{ + res := s.QueryRes(context.Background(), "table1", nil, sq.Insert("wrong"), nil, &ffapi.FilterInfo{ Count: true, }) assert.Equal(t, int64(-1), *res.TotalCount) diff --git a/pkg/dbsql/filter_sql.go b/pkg/dbsql/filter_sql.go index 26a109c..ca80d50 100644 --- a/pkg/dbsql/filter_sql.go +++ b/pkg/dbsql/filter_sql.go @@ -143,7 +143,9 @@ func (s *Database) filterSelectFinalized(ctx context.Context, tableName string, sort[i] = fmt.Sprintf("%s%s%s", s.mapFieldName(tableName, sf.Field, typeMap), direction, nulls) } sortString = strings.Join(sort, ", ") - sel = sel.OrderBy(sortString) + if sortString != "" { + sel = sel.OrderBy(sortString) + } if fi.Skip > 0 { sel = sel.Offset(fi.Skip) }