Commit 2b283ced authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: fix race when canceling queries immediately

Previously the following could happen, though in practice it would
be rare.

Goroutine 1:
	(*Tx).QueryContext begins a query, passing in userContext

Goroutine 2:
	(*Tx).awaitDone starts to wait on the context derived from the passed in context

Goroutine 1:
	(*Tx).grabConn returns a valid (*driverConn)
	The (*driverConn) passes to (*DB).queryConn

Goroutine 3:
	userContext is canceled

Goroutine 2:
	(*Tx).awaitDone unblocks and calls (*Tx).rollback
	(*driverConn).finalClose obtains dc.Mutex
	(*driverConn).finalClose sets dc.ci = nil

Goroutine 1:
	(*DB).queryConn obtains dc.Mutex in withLock
	ctxDriverPrepare accepts dc.ci which is now nil
	ctxCriverPrepare panics on the nil ci

The fix for this is to guard the Tx methods with a RWLock
holding it exclusivly when closing the Tx and holding a read lock
when executing a query.

Fixes #18719

Change-Id: I37aa02c37083c9793dabd28f7f934a1c5cbc05ea
Reviewed-on: https://go-review.googlesource.com/35550
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 1cf08182
...@@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra ...@@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
cancel: cancel, cancel: cancel,
ctx: ctx, ctx: ctx,
} }
go func(tx *Tx) { go tx.awaitDone()
select {
case <-tx.ctx.Done():
if !tx.isDone() {
// Discard and close the connection used to ensure the transaction
// is closed and the resources are released.
tx.rollback(true)
}
}
}(tx)
return tx, nil return tx, nil
} }
...@@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver { ...@@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver {
type Tx struct { type Tx struct {
db *DB db *DB
// closemu prevents the transaction from closing while there
// is an active query. It is held for read during queries
// and exclusively during close.
closemu sync.RWMutex
// dc is owned exclusively until Commit or Rollback, at which point // dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn. // it's returned with putConn.
dc *driverConn dc *driverConn
...@@ -1413,6 +1409,20 @@ type Tx struct { ...@@ -1413,6 +1409,20 @@ type Tx struct {
ctx context.Context ctx context.Context
} }
// awaitDone blocks until the context in Tx is canceled and rolls back
// the transaction if it's not already done.
func (tx *Tx) awaitDone() {
// Wait for either the transaction to be committed or rolled
// back, or for the associated context to be closed.
<-tx.ctx.Done()
// Discard and close the connection used to ensure the
// transaction is closed and the resources are released. This
// rollback does nothing if the transaction has already been
// committed or rolled back.
tx.rollback(true)
}
func (tx *Tx) isDone() bool { func (tx *Tx) isDone() bool {
return atomic.LoadInt32(&tx.done) != 0 return atomic.LoadInt32(&tx.done) != 0
} }
...@@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle ...@@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
// close returns the connection to the pool and // close returns the connection to the pool and
// must only be called by Tx.rollback or Tx.Commit. // must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) { func (tx *Tx) close(err error) {
tx.closemu.Lock()
defer tx.closemu.Unlock()
tx.db.putConn(tx.dc, err) tx.db.putConn(tx.dc, err)
tx.cancel() tx.cancel()
tx.dc = nil tx.dc = nil
tx.txi = nil tx.txi = nil
} }
// hookTxGrabConn specifies an optional hook to be called on
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
if tx.isDone() { if tx.isDone() {
return nil, ErrTxDone return nil, ErrTxDone
} }
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
return tx.dc, nil return tx.dc, nil
} }
...@@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error { ...@@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error {
// for the execution of the returned statement. The returned statement // for the execution of the returned statement. The returned statement
// will run in the transaction context. // will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
// TODO(bradfitz): We could be more efficient here and either // TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on // provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if // perhaps a different Conn), and re-create it on this Conn if
...@@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { ...@@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// The returned statement operates within the transaction and will be closed // The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back. // when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
// TODO(bradfitz): optimize this. Currently this re-prepares // TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but // each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements // we should really cache already-prepared statements
...@@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { ...@@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// ExecContext executes a query that doesn't return rows. // ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE. // For example: an INSERT and UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
dc, err := tx.grabConn(ctx) dc, err := tx.grabConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { ...@@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
// QueryContext executes a query that returns rows, typically a SELECT. // QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
dc, err := tx.grabConn(ctx) dc, err := tx.grabConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -2038,25 +2075,21 @@ type Rows struct { ...@@ -2038,25 +2075,21 @@ type Rows struct {
// closed value is 1 when the Rows is closed. // closed value is 1 when the Rows is closed.
// Use atomic operations on value when checking value. // Use atomic operations on value when checking value.
closed int32 closed int32
ctxClose chan struct{} // closed when Rows is closed, may be null. cancel func() // called when Rows is closed, may be nil.
lastcols []driver.Value lastcols []driver.Value
lasterr error // non-nil only if closed is true lasterr error // non-nil only if closed is true
closeStmt *driverStmt // if non-nil, statement to Close on close closeStmt *driverStmt // if non-nil, statement to Close on close
} }
func (rs *Rows) initContextClose(ctx context.Context) { func (rs *Rows) initContextClose(ctx context.Context) {
if ctx.Done() == context.Background().Done() { ctx, rs.cancel = context.WithCancel(ctx)
return go rs.awaitDone(ctx)
} }
rs.ctxClose = make(chan struct{}) // awaitDone blocks until the rows are closed or the context canceled.
go func() { func (rs *Rows) awaitDone(ctx context.Context) {
select { <-ctx.Done()
case <-ctx.Done(): rs.Close()
rs.Close()
case <-rs.ctxClose:
}
}()
} }
// Next prepares the next result row for reading with the Scan method. It // Next prepares the next result row for reading with the Scan method. It
...@@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error { ...@@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
return nil return nil
} }
var rowsCloseHook func(*Rows, *error) // rowsCloseHook returns a function so tests may install the
// hook throug a test only mutex.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
func (rs *Rows) isClosed() bool { func (rs *Rows) isClosed() bool {
return atomic.LoadInt32(&rs.closed) != 0 return atomic.LoadInt32(&rs.closed) != 0
...@@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error { ...@@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error {
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) { if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
return nil return nil
} }
if rs.ctxClose != nil {
close(rs.ctxClose)
}
err := rs.rowsi.Close() err := rs.rowsi.Close()
if fn := rowsCloseHook; fn != nil { if fn := rowsCloseHook(); fn != nil {
fn(rs, &err) fn(rs, &err)
} }
if rs.cancel != nil {
rs.cancel()
}
if rs.closeStmt != nil { if rs.closeStmt != nil {
rs.closeStmt.Close() rs.closeStmt.Close()
} }
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
...@@ -1135,6 +1136,24 @@ func TestQueryRowClosingStmt(t *testing.T) { ...@@ -1135,6 +1136,24 @@ func TestQueryRowClosingStmt(t *testing.T) {
} }
} }
var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
func init() {
rowsCloseHook = func() func(*Rows, *error) {
fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
return fn
}
}
func setRowsCloseHook(fn func(*Rows, *error)) {
if fn == nil {
// Can't change an atomic.Value back to nil, so set it to this
// no-op func instead.
fn = func(*Rows, *error) {}
}
atomicRowsCloseHook.Store(fn)
}
// Test issue 6651 // Test issue 6651
func TestIssue6651(t *testing.T) { func TestIssue6651(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
...@@ -1147,6 +1166,7 @@ func TestIssue6651(t *testing.T) { ...@@ -1147,6 +1166,7 @@ func TestIssue6651(t *testing.T) {
return fmt.Errorf(want) return fmt.Errorf(want)
} }
defer func() { rowsCursorNextHook = nil }() defer func() { rowsCursorNextHook = nil }()
err := db.QueryRow("SELECT|people|name|").Scan(&v) err := db.QueryRow("SELECT|people|name|").Scan(&v)
if err == nil || err.Error() != want { if err == nil || err.Error() != want {
t.Errorf("error = %q; want %q", err, want) t.Errorf("error = %q; want %q", err, want)
...@@ -1154,10 +1174,10 @@ func TestIssue6651(t *testing.T) { ...@@ -1154,10 +1174,10 @@ func TestIssue6651(t *testing.T) {
rowsCursorNextHook = nil rowsCursorNextHook = nil
want = "error in rows.Close" want = "error in rows.Close"
rowsCloseHook = func(rows *Rows, err *error) { setRowsCloseHook(func(rows *Rows, err *error) {
*err = fmt.Errorf(want) *err = fmt.Errorf(want)
} })
defer func() { rowsCloseHook = nil }() defer setRowsCloseHook(nil)
err = db.QueryRow("SELECT|people|name|").Scan(&v) err = db.QueryRow("SELECT|people|name|").Scan(&v)
if err == nil || err.Error() != want { if err == nil || err.Error() != want {
t.Errorf("error = %q; want %q", err, want) t.Errorf("error = %q; want %q", err, want)
...@@ -1830,7 +1850,9 @@ func TestStmtCloseDeps(t *testing.T) { ...@@ -1830,7 +1850,9 @@ func TestStmtCloseDeps(t *testing.T) {
db.dumpDeps(t) db.dumpDeps(t)
} }
if len(stmt.css) > nquery { if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
return len(stmt.css) <= nquery
}) {
t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery) t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
} }
...@@ -2576,10 +2598,10 @@ func TestIssue6081(t *testing.T) { ...@@ -2576,10 +2598,10 @@ func TestIssue6081(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
rowsCloseHook = func(rows *Rows, err *error) { setRowsCloseHook(func(rows *Rows, err *error) {
*err = driver.ErrBadConn *err = driver.ErrBadConn
} })
defer func() { rowsCloseHook = nil }() defer setRowsCloseHook(nil)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
rows, err := stmt.Query() rows, err := stmt.Query()
if err != nil { if err != nil {
...@@ -2642,7 +2664,10 @@ func TestIssue18429(t *testing.T) { ...@@ -2642,7 +2664,10 @@ func TestIssue18429(t *testing.T) {
if err != nil { if err != nil {
return return
} }
rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") // This is expected to give a cancel error many, but not all the time.
// Test failure will happen with a panic or other race condition being
// reported.
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
if rows != nil { if rows != nil {
rows.Close() rows.Close()
} }
...@@ -2655,6 +2680,56 @@ func TestIssue18429(t *testing.T) { ...@@ -2655,6 +2680,56 @@ func TestIssue18429(t *testing.T) {
time.Sleep(milliWait * 3 * time.Millisecond) time.Sleep(milliWait * 3 * time.Millisecond)
} }
// TestIssue18719 closes the context right before use. The sql.driverConn
// will nil out the ci on close in a lock, but if another process uses it right after
// it will panic with on the nil ref.
//
// See https://golang.org/cl/35550 .
func TestIssue18719(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}
hookTxGrabConn = func() {
cancel()
// Wait for the context to cancel and tx to rollback.
for tx.isDone() == false {
time.Sleep(time.Millisecond * 3)
}
}
defer func() { hookTxGrabConn = nil }()
// This call will grab the connection and cancel the context
// after it has done so. Code after must deal with the canceled state.
rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
if err != nil {
rows.Close()
t.Fatalf("expected error %v but got %v", nil, err)
}
// Rows may be ignored because it will be closed when the context is canceled.
// Do not explicitly rollback. The rollback will happen from the
// canceled context.
// Wait for connections to return to pool.
var numOpen int
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
numOpen = db.numOpenConns()
return numOpen == 0
}) {
t.Fatalf("open conns after hitting EOF = %d; want 0", numOpen)
}
}
func TestConcurrency(t *testing.T) { func TestConcurrency(t *testing.T) {
doConcurrentTest(t, new(concurrentDBQueryTest)) doConcurrentTest(t, new(concurrentDBQueryTest))
doConcurrentTest(t, new(concurrentDBExecTest)) doConcurrentTest(t, new(concurrentDBExecTest))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment