Skip to content

Commit

Permalink
feat: add BeforeAppendModelHook
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 15, 2021
1 parent b350bf0 commit 0b55de7
Show file tree
Hide file tree
Showing 13 changed files with 237 additions and 65 deletions.
10 changes: 9 additions & 1 deletion bun.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ type (

NullTime = schema.NullTime
BaseModel = schema.BaseModel
Query = schema.Query

BeforeAppendModelHook = schema.BeforeAppendModelHook

BeforeScanRowHook = schema.BeforeScanRowHook
AfterScanRowHook = schema.AfterScanRowHook

// DEPRECATED. Use BeforeScanRowHook instead.
BeforeScanHook = schema.BeforeScanHook
AfterScanHook = schema.AfterScanHook
// DEPRECATED. Use AfterScanRowHook instead.
AfterScanHook = schema.AfterScanHook
)

type BeforeSelectHook interface {
Expand Down
22 changes: 2 additions & 20 deletions hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,18 @@ package bun
import (
"context"
"database/sql"
"reflect"
"strings"
"sync/atomic"
"time"

"github.com/uptrace/bun/schema"
)

type IQuery interface {
schema.QueryAppender
Operation() string
GetModel() Model
GetTableName() string
}

type QueryEvent struct {
DB *DB

QueryAppender schema.QueryAppender // Deprecated: use IQuery instead
IQuery IQuery
IQuery Query
Query string
QueryArgs []interface{}
Model Model
Expand Down Expand Up @@ -58,7 +50,7 @@ type QueryHook interface {

func (db *DB) beforeQuery(
ctx context.Context,
iquery IQuery,
iquery Query,
query string,
queryArgs []interface{},
model Model,
Expand Down Expand Up @@ -116,13 +108,3 @@ func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIn
db.queryHooks[hookIndex].AfterQuery(ctx, event)
}
}

//------------------------------------------------------------------------------

func callBeforeScanHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx)
}

func callAfterScanHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(schema.AfterScanHook).AfterScan(ctx)
}
49 changes: 39 additions & 10 deletions internal/dbtest/model_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,14 @@ func TestModelHook(t *testing.T) {
}

func testModelHook(t *testing.T, dbName string, db *bun.DB) {
_, err := db.NewDropTable().Model((*ModelHookTest)(nil)).IfExists().Exec(ctx)
require.NoError(t, err)

_, err = db.NewCreateTable().Model((*ModelHookTest)(nil)).Exec(ctx)
err := db.ResetModel(ctx, (*ModelHookTest)(nil))
require.NoError(t, err)

{
hook := &ModelHookTest{ID: 1}
_, err := db.NewInsert().Model(hook).Exec(ctx)
require.NoError(t, err)
require.Equal(t, []string{"BeforeInsert", "AfterInsert"}, events.Flush())
require.Equal(t, []string{"BeforeInsert", "BeforeAppendModel", "AfterInsert"}, events.Flush())
}

{
Expand All @@ -58,13 +55,14 @@ func testModelHook(t *testing.T, dbName string, db *bun.DB) {
require.NoError(t, err)
require.Equal(t, []string{
"BeforeSelect",
"BeforeAppendModel",
"BeforeScan",
"AfterScan",
"AfterSelect",
}, events.Flush())
}

{
t.Run("selectEmptySlice", func(t *testing.T) {
hooks := make([]ModelHookTest, 0)
err := db.NewSelect().Model(&hooks).Scan(ctx)
require.NoError(t, err)
Expand All @@ -74,34 +72,65 @@ func testModelHook(t *testing.T, dbName string, db *bun.DB) {
"AfterScan",
"AfterSelect",
}, events.Flush())
}
})

{
hook := &ModelHookTest{ID: 1}
_, err := db.NewUpdate().Model(hook).Where("id = 1").Exec(ctx)
require.NoError(t, err)
require.Equal(t, []string{"BeforeUpdate", "AfterUpdate"}, events.Flush())
require.Equal(t, []string{"BeforeUpdate", "BeforeAppendModel", "AfterUpdate"}, events.Flush())
}

{
hook := &ModelHookTest{ID: 1}
_, err := db.NewDelete().Model(hook).Where("id = 1").Exec(ctx)
require.NoError(t, err)
require.Equal(t, []string{"BeforeDelete", "AfterDelete"}, events.Flush())
require.Equal(t, []string{"BeforeDelete", "BeforeAppendModel", "AfterDelete"}, events.Flush())
}

{
_, err := db.NewDelete().Model((*ModelHookTest)(nil)).Where("TRUE").Exec(ctx)
require.NoError(t, err)
require.Equal(t, []string{"BeforeDelete", "AfterDelete"}, events.Flush())
}

t.Run("insertSlice", func(t *testing.T) {
hooks := []ModelHookTest{{ID: 1}, {ID: 2}}
_, err := db.NewInsert().Model(&hooks).Exec(ctx)
require.NoError(t, err)
require.Equal(t, []string{
"BeforeInsert",
"BeforeAppendModel",
"BeforeAppendModel",
"AfterInsert",
}, events.Flush())
})

t.Run("insertSliceOfPtr", func(t *testing.T) {
hooks := []*ModelHookTest{{ID: 3}, {ID: 4}}
_, err := db.NewInsert().Model(&hooks).Exec(ctx)
require.NoError(t, err)
require.Equal(t, []string{
"BeforeInsert",
"BeforeAppendModel",
"BeforeAppendModel",
"AfterInsert",
}, events.Flush())
})
}

type ModelHookTest struct {
ID int
Value string
}

var _ bun.BeforeAppendModelHook = (*ModelHookTest)(nil)

func (t *ModelHookTest) BeforeAppendModel(query bun.Query) error {
events.Add("BeforeAppendModel")
return nil
}

var _ bun.BeforeScanHook = (*ModelHookTest)(nil)

func (t *ModelHookTest) BeforeScan(c context.Context) error {
Expand Down Expand Up @@ -182,7 +211,7 @@ func (t *ModelHookTest) AfterDelete(ctx context.Context, query *bun.DeleteQuery)

func assertQueryModel(query interface{ GetModel() bun.Model }) {
switch value := query.GetModel().Value(); value.(type) {
case *ModelHookTest, *[]ModelHookTest:
case *ModelHookTest, *[]ModelHookTest, *[]*ModelHookTest:
// ok
default:
panic(fmt.Errorf("unexpected: %T", value))
Expand Down
10 changes: 4 additions & 6 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ var errNilModel = errors.New("bun: Model(nil)")

var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()

type Model interface {
ScanRows(ctx context.Context, rows *sql.Rows) (int, error)
Value() interface{}
}
type Model = schema.Model

type rowScanner interface {
ScanRow(ctx context.Context, rows *sql.Rows) error
Expand All @@ -27,8 +24,9 @@ type rowScanner interface {
type TableModel interface {
Model

schema.BeforeScanHook
schema.AfterScanHook
schema.BeforeAppendModelHook
schema.BeforeScanRowHook
schema.AfterScanRowHook
ScanColumn(column string, src interface{}) error

Table() *schema.Table
Expand Down
25 changes: 23 additions & 2 deletions model_table_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,31 @@ func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, er
return n, nil
}

var _ schema.BeforeAppendModelHook = (*sliceTableModel)(nil)

func (m *sliceTableModel) BeforeAppendModel(query Query) error {
if !m.table.HasBeforeAppendModelHook() {
return nil
}

sliceLen := m.slice.Len()
for i := 0; i < sliceLen; i++ {
strct := m.slice.Index(i)
if !m.sliceOfPtr {
strct = strct.Addr()
}
err := strct.Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(query)
if err != nil {
return err
}
}
return nil
}

// Inherit these hooks from structTableModel.
var (
_ schema.BeforeScanHook = (*sliceTableModel)(nil)
_ schema.AfterScanHook = (*sliceTableModel)(nil)
_ schema.BeforeScanRowHook = (*sliceTableModel)(nil)
_ schema.AfterScanRowHook = (*sliceTableModel)(nil)
)

func (m *sliceTableModel) updateSoftDeleteField(tm time.Time) error {
Expand Down
63 changes: 45 additions & 18 deletions model_table_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,38 +100,65 @@ func (m *structTableModel) mountJoins() {
}
}

var _ schema.BeforeScanHook = (*structTableModel)(nil)
var _ schema.BeforeAppendModelHook = (*structTableModel)(nil)

func (m *structTableModel) BeforeScan(ctx context.Context) error {
if !m.table.HasBeforeScanHook() {
func (m *structTableModel) BeforeAppendModel(query Query) error {
if !m.table.HasBeforeAppendModelHook() || !m.strct.IsValid() {
return nil
}
return callBeforeScanHook(ctx, m.strct.Addr())
return m.strct.Addr().Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(query)
}

var _ schema.AfterScanHook = (*structTableModel)(nil)
var _ schema.BeforeScanRowHook = (*structTableModel)(nil)

func (m *structTableModel) AfterScan(ctx context.Context) error {
if !m.table.HasAfterScanHook() || !m.structInited {
func (m *structTableModel) BeforeScanRow(ctx context.Context) error {
if m.table.HasBeforeScanRowHook() {
return m.strct.Addr().Interface().(schema.BeforeScanRowHook).BeforeScanRow(ctx)
}
if m.table.HasBeforeScanHook() {
return m.strct.Addr().Interface().(schema.BeforeScanHook).BeforeScan(ctx)
}
return nil
}

var _ schema.AfterScanRowHook = (*structTableModel)(nil)

func (m *structTableModel) AfterScanRow(ctx context.Context) error {
if !m.structInited {
return nil
}

var firstErr error
if m.table.HasAfterScanRowHook() {
firstErr := m.strct.Addr().Interface().(schema.AfterScanRowHook).AfterScanRow(ctx)

for _, j := range m.joins {
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil {
firstErr = err
}
}
}

if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil {
firstErr = err
return firstErr
}

for _, j := range m.joins {
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil {
firstErr = err
if m.table.HasAfterScanHook() {
firstErr := m.strct.Addr().Interface().(schema.AfterScanHook).AfterScan(ctx)

for _, j := range m.joins {
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil {
firstErr = err
}
}
}

return firstErr
}

return firstErr
return nil
}

func (m *structTableModel) getJoin(name string) *relationJoin {
Expand Down Expand Up @@ -257,7 +284,7 @@ func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
}

func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error {
if err := m.BeforeScan(ctx); err != nil {
if err := m.BeforeScanRow(ctx); err != nil {
return err
}

Expand All @@ -266,7 +293,7 @@ func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []i
return err
}

if err := m.AfterScan(ctx); err != nil {
if err := m.AfterScanRow(ctx); err != nil {
return err
}

Expand Down
11 changes: 9 additions & 2 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ func (q *baseQuery) getModel(dest []interface{}) (Model, error) {
return newModel(q.db, dest)
}

func (q *baseQuery) beforeAppendModel(query Query) error {
if q.tableModel != nil {
return q.tableModel.BeforeAppendModel(query)
}
return nil
}

//------------------------------------------------------------------------------

func (q *baseQuery) checkSoftDelete() error {
Expand Down Expand Up @@ -462,7 +469,7 @@ func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) {

func (q *baseQuery) scan(
ctx context.Context,
iquery IQuery,
iquery Query,
query string,
model Model,
hasDest bool,
Expand Down Expand Up @@ -494,7 +501,7 @@ func (q *baseQuery) scan(

func (q *baseQuery) exec(
ctx context.Context,
iquery IQuery,
iquery Query,
query string,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model)
Expand Down
4 changes: 4 additions & 0 deletions query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,13 @@ func (q *DeleteQuery) Operation() string {
}

func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if err := q.beforeAppendModel(q); err != nil {
return nil, err
}
if q.err != nil {
return nil, q.err
}

fmter = formatterWithModel(fmter, q)

if q.isSoftDelete() {
Expand Down
4 changes: 4 additions & 0 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
if q.err != nil {
return nil, q.err
}
if err := q.beforeAppendModel(q); err != nil {
return nil, err
}

fmter = formatterWithModel(fmter, q)

b, err = q.appendWith(fmter, b)
Expand Down
Loading

0 comments on commit 0b55de7

Please sign in to comment.