Skip to content

Commit

Permalink
feat: expose connection transactions with context and options
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored and aeneasr committed May 24, 2022
1 parent db86847 commit 012ea29
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
20 changes: 13 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pop

import (
"context"
"database/sql"
"errors"
"fmt"
"sync/atomic"
Expand Down Expand Up @@ -185,21 +186,26 @@ func (c *Connection) Rollback(fn func(tx *Connection)) error {

// NewTransaction starts a new transaction on the connection
func (c *Connection) NewTransaction() (*Connection, error) {
return c.NewTransactionContextOptions(c.Context(), nil)
}

// NewTransactionContext starts a new transaction on the connection using the provided context
func (c *Connection) NewTransactionContext(ctx context.Context) (*Connection, error) {
return c.NewTransactionContextOptions(ctx, nil)
}

// NewTransactionContextOptions starts a new transaction on the connection using the provided context and transaction options
func (c *Connection) NewTransactionContextOptions(ctx context.Context, options *sql.TxOptions) (*Connection, error) {
var cn *Connection
if c.TX == nil {
tx, err := c.Store.Transaction()
tx, err := c.Store.TransactionContextOptions(ctx, options)
if err != nil {
return cn, fmt.Errorf("couldn't start a new transaction: %w", err)
}
var store store = tx

// Rewrap the store if it was a context store
if cs, ok := c.Store.(contextStore); ok {
store = contextStore{store: store, ctx: cs.ctx}
}
cn = &Connection{
ID: randx.String(30),
Store: store,
Store: contextStore{store: tx, ctx: ctx},
Dialect: c.Dialect,
TX: tx,
}
Expand Down
45 changes: 45 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//go:build sqlite
// +build sqlite

package pop

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -52,3 +54,46 @@ func Test_Connection_Open_BadDriver(t *testing.T) {
err = c.Open()
r.Error(err)
}

func Test_Connection_Transaction(t *testing.T) {
r := require.New(t)
ctx := context.WithValue(context.Background(), "test", "test")

c, err := NewConnection(&ConnectionDetails{
URL: "sqlite://file::memory:?_fk=true",
})
r.NoError(err)
r.NoError(c.Open())
c = c.WithContext(ctx)

t.Run("func=NewTransaction", func(t *testing.T) {
r := require.New(t)
tx, err := c.NewTransaction()
r.NoError(err)

// has transaction and context
r.NotNil(tx.TX)
r.Nil(c.TX)
r.Equal(ctx, tx.Context())

// does not start a new transaction
ntx, err := tx.NewTransaction()
r.Equal(tx, ntx)

r.NoError(tx.TX.Rollback())
})

t.Run("func=NewTransactionContext", func(t *testing.T) {
r := require.New(t)
nctx := context.WithValue(ctx, "nested", "test")
tx, err := c.NewTransactionContext(nctx)
r.NoError(err)

// has transaction and context
r.NotNil(tx.TX)
r.Nil(c.TX)
r.Equal(nctx, tx.Context())

r.NoError(tx.TX.Rollback())
})
}

0 comments on commit 012ea29

Please sign in to comment.