Skip to content

Commit

Permalink
Merge pull request #133 from o1egl/fix_race
Browse files Browse the repository at this point in the history
Fix data race in  query factory initialisation
  • Loading branch information
doug-martin authored Aug 22, 2019
2 parents 37cfb20 + cc41938 commit 9bf6e8d
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 5 deletions.
11 changes: 7 additions & 4 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"context"
"database/sql"
"sync"

"github.com/doug-martin/goqu/v8/exec"
)
Expand All @@ -28,6 +29,7 @@ type (
dialect string
Db SQLDatabase
qf exec.QueryFactory
qfOnce sync.Once
}
)

Expand Down Expand Up @@ -324,9 +326,9 @@ func (d *Database) QueryRowContext(ctx context.Context, query string, args ...in
}

func (d *Database) queryFactory() exec.QueryFactory {
if d.qf == nil {
d.qfOnce.Do(func() {
d.qf = exec.NewQueryFactory(d)
}
})
return d.qf
}

Expand Down Expand Up @@ -443,6 +445,7 @@ type (
dialect string
Tx SQLTx
qf exec.QueryFactory
qfOnce sync.Once
}
)

Expand Down Expand Up @@ -545,9 +548,9 @@ func (td *TxDatabase) QueryRowContext(ctx context.Context, query string, args ..
}

func (td *TxDatabase) queryFactory() exec.QueryFactory {
if td.qf == nil {
td.qfOnce.Do(func() {
td.qf = exec.NewQueryFactory(td)
}
})
return td.qf
}

Expand Down
76 changes: 76 additions & 0 deletions database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package goqu
import (
"context"
"fmt"
"sync"
"testing"

"github.com/stretchr/testify/assert"

"github.com/DATA-DOG/go-sqlmock"
"github.com/doug-martin/goqu/v8/internal/errors"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -313,6 +316,40 @@ func (ds *databaseSuite) TestWithTx() {
}
}

func (ds *databaseSuite) TestDataRace() {
t := ds.T()
mDb, mock, err := sqlmock.New()
assert.NoError(t, err)
db := newDatabase("mock", mDb)

const concurrency = 10

for i := 0; i < concurrency; i++ {
mock.ExpectQuery(`SELECT "address", "name" FROM "items"`).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"address", "name"}).
FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2"))
}

wg := sync.WaitGroup{}
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()

sql := db.From("items").Limit(1)
var item testActionItem
found, err := sql.ScanStruct(&item)
assert.NoError(t, err)
assert.True(t, found)
assert.Equal(t, item.Address, "111 Test Addr")
assert.Equal(t, item.Name, "Test1")
}()
}

wg.Wait()
}

func TestDatabaseSuite(t *testing.T) {
suite.Run(t, new(databaseSuite))
}
Expand Down Expand Up @@ -623,6 +660,45 @@ func (tds *txdatabaseSuite) TestWrap() {
}), "goqu: tx error")
}

func (tds *txdatabaseSuite) TestDataRace() {
t := tds.T()
mDb, mock, err := sqlmock.New()
assert.NoError(t, err)
mock.ExpectBegin()
db := newDatabase("mock", mDb)
tx, err := db.Begin()
assert.NoError(t, err)

const concurrency = 10

for i := 0; i < concurrency; i++ {
mock.ExpectQuery(`SELECT "address", "name" FROM "items"`).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"address", "name"}).
FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2"))
}

wg := sync.WaitGroup{}
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()

sql := tx.From("items").Limit(1)
var item testActionItem
found, err := sql.ScanStruct(&item)
assert.NoError(t, err)
assert.True(t, found)
assert.Equal(t, item.Address, "111 Test Addr")
assert.Equal(t, item.Name, "Test1")
}()
}

wg.Wait()
mock.ExpectCommit()
assert.NoError(t, tx.Commit())
}

func TestTxDatabaseSuite(t *testing.T) {
suite.Run(t, new(txdatabaseSuite))
}
2 changes: 1 addition & 1 deletion go.test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
set -e
echo "" > coverage.txt

go test -coverprofile=coverage.txt -coverpkg=./... ./...
go test -race -coverprofile=coverage.txt -coverpkg=./... ./...

0 comments on commit 9bf6e8d

Please sign in to comment.