Skip to content

Commit

Permalink
database/sql: fix race when canceling queries immediately
Browse files Browse the repository at this point in the history
Previously the following could happen, though in practice it would
be rare.

Goroutine 1:
	(*Tx).QueryContext begins a query, passing in userContext

Goroutine 2:
	(*Tx).awaitDone starts to wait on the context derived from the passed in context

Goroutine 1:
	(*Tx).grabConn returns a valid (*driverConn)
	The (*driverConn) passes to (*DB).queryConn

Goroutine 3:
	userContext is canceled

Goroutine 2:
	(*Tx).awaitDone unblocks and calls (*Tx).rollback
	(*driverConn).finalClose obtains dc.Mutex
	(*driverConn).finalClose sets dc.ci = nil

Goroutine 1:
	(*DB).queryConn obtains dc.Mutex in withLock
	ctxDriverPrepare accepts dc.ci which is now nil
	ctxCriverPrepare panics on the nil ci

The fix for this is to guard the Tx methods with a RWLock
holding it exclusivly when closing the Tx and holding a read lock
when executing a query.

Fixes #18719

Change-Id: I37aa02c37083c9793dabd28f7f934a1c5cbc05ea
Reviewed-on: https://go-review.googlesource.com/35550
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
  • Loading branch information
kardianos authored and bradfitz committed Jan 26, 2017
1 parent 1cf0818 commit 2b283ce
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 35 deletions.
91 changes: 64 additions & 27 deletions src/database/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
cancel: cancel,
ctx: ctx,
}
go func(tx *Tx) {
select {
case <-tx.ctx.Done():
if !tx.isDone() {
// Discard and close the connection used to ensure the transaction
// is closed and the resources are released.
tx.rollback(true)
}
}
}(tx)
go tx.awaitDone()
return tx, nil
}

Expand All @@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver {
type Tx struct {
db *DB

// closemu prevents the transaction from closing while there
// is an active query. It is held for read during queries
// and exclusively during close.
closemu sync.RWMutex

// dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn.
dc *driverConn
Expand All @@ -1413,6 +1409,20 @@ type Tx struct {
ctx context.Context
}

// awaitDone blocks until the context in Tx is canceled and rolls back
// the transaction if it's not already done.
func (tx *Tx) awaitDone() {
// Wait for either the transaction to be committed or rolled
// back, or for the associated context to be closed.
<-tx.ctx.Done()

// Discard and close the connection used to ensure the
// transaction is closed and the resources are released. This
// rollback does nothing if the transaction has already been
// committed or rolled back.
tx.rollback(true)
}

func (tx *Tx) isDone() bool {
return atomic.LoadInt32(&tx.done) != 0
}
Expand All @@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
// close returns the connection to the pool and
// must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) {
tx.closemu.Lock()
defer tx.closemu.Unlock()

tx.db.putConn(tx.dc, err)
tx.cancel()
tx.dc = nil
tx.txi = nil
}

// hookTxGrabConn specifies an optional hook to be called on
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()

func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
if tx.isDone() {
return nil, ErrTxDone
}
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
return tx.dc, nil
}

Expand Down Expand Up @@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error {
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()

// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
Expand Down Expand Up @@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
tx.closemu.RLock()
defer tx.closemu.RUnlock()

// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
Expand Down Expand Up @@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()

dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {

// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()

dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -2038,25 +2075,21 @@ type Rows struct {
// closed value is 1 when the Rows is closed.
// Use atomic operations on value when checking value.
closed int32
ctxClose chan struct{} // closed when Rows is closed, may be null.
cancel func() // called when Rows is closed, may be nil.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
closeStmt *driverStmt // if non-nil, statement to Close on close
}

func (rs *Rows) initContextClose(ctx context.Context) {
if ctx.Done() == context.Background().Done() {
return
}
ctx, rs.cancel = context.WithCancel(ctx)
go rs.awaitDone(ctx)
}

rs.ctxClose = make(chan struct{})
go func() {
select {
case <-ctx.Done():
rs.Close()
case <-rs.ctxClose:
}
}()
// awaitDone blocks until the rows are closed or the context canceled.
func (rs *Rows) awaitDone(ctx context.Context) {
<-ctx.Done()
rs.Close()
}

// Next prepares the next result row for reading with the Scan method. It
Expand Down Expand Up @@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
return nil
}

var rowsCloseHook func(*Rows, *error)
// rowsCloseHook returns a function so tests may install the
// hook throug a test only mutex.
var rowsCloseHook = func() func(*Rows, *error) { return nil }

func (rs *Rows) isClosed() bool {
return atomic.LoadInt32(&rs.closed) != 0
Expand All @@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error {
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
return nil
}
if rs.ctxClose != nil {
close(rs.ctxClose)
}

err := rs.rowsi.Close()
if fn := rowsCloseHook; fn != nil {
if fn := rowsCloseHook(); fn != nil {
fn(rs, &err)
}
if rs.cancel != nil {
rs.cancel()
}

if rs.closeStmt != nil {
rs.closeStmt.Close()
}
Expand Down
91 changes: 83 additions & 8 deletions src/database/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -1135,6 +1136,24 @@ func TestQueryRowClosingStmt(t *testing.T) {
}
}

var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)

func init() {
rowsCloseHook = func() func(*Rows, *error) {
fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
return fn
}
}

func setRowsCloseHook(fn func(*Rows, *error)) {
if fn == nil {
// Can't change an atomic.Value back to nil, so set it to this
// no-op func instead.
fn = func(*Rows, *error) {}
}
atomicRowsCloseHook.Store(fn)
}

// Test issue 6651
func TestIssue6651(t *testing.T) {
db := newTestDB(t, "people")
Expand All @@ -1147,17 +1166,18 @@ func TestIssue6651(t *testing.T) {
return fmt.Errorf(want)
}
defer func() { rowsCursorNextHook = nil }()

err := db.QueryRow("SELECT|people|name|").Scan(&v)
if err == nil || err.Error() != want {
t.Errorf("error = %q; want %q", err, want)
}
rowsCursorNextHook = nil

want = "error in rows.Close"
rowsCloseHook = func(rows *Rows, err *error) {
setRowsCloseHook(func(rows *Rows, err *error) {
*err = fmt.Errorf(want)
}
defer func() { rowsCloseHook = nil }()
})
defer setRowsCloseHook(nil)
err = db.QueryRow("SELECT|people|name|").Scan(&v)
if err == nil || err.Error() != want {
t.Errorf("error = %q; want %q", err, want)
Expand Down Expand Up @@ -1830,7 +1850,9 @@ func TestStmtCloseDeps(t *testing.T) {
db.dumpDeps(t)
}

if len(stmt.css) > nquery {
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
return len(stmt.css) <= nquery
}) {
t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
}

Expand Down Expand Up @@ -2576,10 +2598,10 @@ func TestIssue6081(t *testing.T) {
if err != nil {
t.Fatal(err)
}
rowsCloseHook = func(rows *Rows, err *error) {
setRowsCloseHook(func(rows *Rows, err *error) {
*err = driver.ErrBadConn
}
defer func() { rowsCloseHook = nil }()
})
defer setRowsCloseHook(nil)
for i := 0; i < 10; i++ {
rows, err := stmt.Query()
if err != nil {
Expand Down Expand Up @@ -2642,7 +2664,10 @@ func TestIssue18429(t *testing.T) {
if err != nil {
return
}
rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
// This is expected to give a cancel error many, but not all the time.
// Test failure will happen with a panic or other race condition being
// reported.
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
if rows != nil {
rows.Close()
}
Expand All @@ -2655,6 +2680,56 @@ func TestIssue18429(t *testing.T) {
time.Sleep(milliWait * 3 * time.Millisecond)
}

// TestIssue18719 closes the context right before use. The sql.driverConn
// will nil out the ci on close in a lock, but if another process uses it right after
// it will panic with on the nil ref.
//
// See https://golang.org/cl/35550 .
func TestIssue18719(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}

hookTxGrabConn = func() {
cancel()

// Wait for the context to cancel and tx to rollback.
for tx.isDone() == false {
time.Sleep(time.Millisecond * 3)
}
}
defer func() { hookTxGrabConn = nil }()

// This call will grab the connection and cancel the context
// after it has done so. Code after must deal with the canceled state.
rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
if err != nil {
rows.Close()
t.Fatalf("expected error %v but got %v", nil, err)
}

// Rows may be ignored because it will be closed when the context is canceled.

// Do not explicitly rollback. The rollback will happen from the
// canceled context.

// Wait for connections to return to pool.
var numOpen int
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
numOpen = db.numOpenConns()
return numOpen == 0
}) {
t.Fatalf("open conns after hitting EOF = %d; want 0", numOpen)
}
}

func TestConcurrency(t *testing.T) {
doConcurrentTest(t, new(concurrentDBQueryTest))
doConcurrentTest(t, new(concurrentDBExecTest))
Expand Down

0 comments on commit 2b283ce

Please sign in to comment.