Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tx methods to IDB (#587) #591

Merged
merged 1 commit into from
Jul 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package bun

import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -431,6 +433,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
type Tx struct {
ctx context.Context
db *DB
// name is the name of a savepoint
name string
*sql.Tx
}

Expand Down Expand Up @@ -479,19 +483,51 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
}

func (tx Tx) Commit() error {
if tx.name == "" {
return tx.commitTX()
}
return tx.commitSP()
}

func (tx Tx) commitTX() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil)
err := tx.Tx.Commit()
tx.db.afterQuery(ctx, event, nil, err)
return err
}

func (tx Tx) commitSP() error {
if tx.Dialect().Features().Has(feature.MSSavepoint) {
return nil
}
query := "RELEASE SAVEPOINT " + tx.name
_, err := tx.ExecContext(tx.ctx, query)
return err
}

func (tx Tx) Rollback() error {
if tx.name == "" {
return tx.rollbackTX()
}
return tx.rollbackSP()
}

func (tx Tx) rollbackTX() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil)
err := tx.Tx.Rollback()
tx.db.afterQuery(ctx, event, nil, err)
return err
}

func (tx Tx) rollbackSP() error {
query := "ROLLBACK TO SAVEPOINT " + tx.name
if tx.Dialect().Features().Has(feature.MSSavepoint) {
query = "ROLLBACK TRANSACTION " + tx.name
}
_, err := tx.ExecContext(tx.ctx, query)
return err
}

func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.TODO(), query, args...)
}
Expand Down Expand Up @@ -534,6 +570,60 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac

//------------------------------------------------------------------------------

func (tx Tx) Begin() (Tx, error) {
return tx.BeginTx(tx.ctx, nil)
}

// BeginTx will save a point in the running transaction.
func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {
// mssql savepoint names are limited to 32 characters
sp := make([]byte, 14)
_, err := rand.Read(sp)
if err != nil {
return Tx{}, err
}

qName := "SP_" + hex.EncodeToString(sp)
query := "SAVEPOINT " + qName
if tx.Dialect().Features().Has(feature.MSSavepoint) {
query = "SAVE TRANSACTION " + qName
}
_, err = tx.ExecContext(ctx, query)
if err != nil {
return Tx{}, err
}
return Tx{
ctx: ctx,
db: tx.db,
Tx: tx.Tx,
name: qName,
}, nil
}

func (tx Tx) RunInTx(
ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
) error {
sp, err := tx.BeginTx(ctx, nil)
if err != nil {
return err
}

var done bool

defer func() {
if !done {
_ = sp.Rollback()
}
}()

if err := fn(ctx, sp); err != nil {
return err
}

done = true
return sp.Commit()
}

func (tx Tx) Dialect() schema.Dialect {
return tx.db.Dialect()
}
Expand Down
1 change: 1 addition & 0 deletions dialect/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ const (
OffsetFetch
SelectExists
UpdateFromTable
MSSavepoint
)
3 changes: 2 additions & 1 deletion dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ func New() *Dialect {
feature.Identity |
feature.Output |
feature.OffsetFetch |
feature.UpdateFromTable
feature.UpdateFromTable |
feature.MSSavepoint
return d
}

Expand Down
97 changes: 97 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ func TestDB(t *testing.T) {
{testEmbedModelPointer},
{testJSONMarshaler},
{testNilDriverValue},
{testRunInTxAndSavepoint},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -1401,3 +1402,99 @@ func testNilDriverValue(t *testing.T, db *bun.DB) {
_, err = db.NewInsert().Model(&Model{Value: &DriverValue{s: "hello"}}).Exec(ctx)
require.NoError(t, err)
}

func testRunInTxAndSavepoint(t *testing.T, db *bun.DB) {
type Counter struct {
Count int64
}

err := db.ResetModel(ctx, (*Counter)(nil))
require.NoError(t, err)

_, err = db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx)
require.NoError(t, err)

err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
err := tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewUpdate().Model((*Counter)(nil)).
Set("count = count + 1").
Where("1 = 1").
Exec(ctx)
return err
})
require.NoError(t, err)
// rolling back the transaction should rollback what happened inside savepoint
return errors.New("fake error")
})
require.Error(t, err)

var count int
err = db.NewSelect().Model((*Counter)(nil)).Scan(ctx, &count)
require.NoError(t, err)
require.Equal(t, 0, count)

err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
err := tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
require.NoError(t, err)
return err
})
require.NoError(t, err)

// ignored on purpose this error
// rolling back a savepoint should not affect the transaction
// nor other savepoints on the same level
_ = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 2}).
Exec(ctx)
require.NoError(t, err)
return errors.New("fake error")
})

return err
})
require.NoError(t, err)

count, err = db.NewSelect().Model((*Counter)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 2, count)

err = db.ResetModel(ctx, (*Counter)(nil))
require.NoError(t, err)

// happy path, commit transaction, savepoints and sub-savepoints
err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
_, err := tx.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
require.NoError(t, err)

err = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
if err != nil {
return err
}

return sp.RunInTx(ctx, nil, func(ctx context.Context, subSp bun.Tx) error {
_, err := subSp.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
return err
})
})
require.NoError(t, err)

err = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 2}).
Exec(ctx)
return err
})

return err
})
require.NoError(t, err)

count, err = db.NewSelect().Model((*Counter)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 4, count)
}
3 changes: 3 additions & 0 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type IDB interface {
NewTruncateTable() *TruncateTableQuery
NewAddColumn() *AddColumnQuery
NewDropColumn() *DropColumnQuery

BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
RunInTx(ctx context.Context, opts *sql.TxOptions, f func(ctx context.Context, tx Tx) error) error
}

var (
Expand Down