Commit 0d163ce1 authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: do not bypass the driver locks with Context methods

When context methods were initially added it was attempted to unify
behavior between drivers without Context methods and those with
Context methods to always return right away when the Context expired.
However in doing so the driver call could be executed outside of the
scope of the driver connection lock and thus bypassing thread safety.

The new behavior waits until the driver operation is complete. It then
checks to see if the context has expired and if so returns that error.

Change-Id: I4a5c7c3263420c57778f36a5ed6fa0ef8cb32b20
Reviewed-on: https://go-review.googlesource.com/32422Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 3825656e
......@@ -14,40 +14,16 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
if ctx.Done() == context.Background().Done() {
return ci.Prepare(query)
}
type R struct {
err error
panic interface{}
si driver.Stmt
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.si, r.err = ci.Prepare(query)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
si, err := ci.Prepare(query)
if err == nil {
select {
default:
case <-ctx.Done():
si.Close()
return nil, ctx.Err()
}
return r.si, r.err
}
return si, err
}
func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
......@@ -58,84 +34,38 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda
if err != nil {
return nil, err
}
if ctx.Done() == context.Background().Done() {
return execer.Exec(query, dargs)
}
type R struct {
err error
panic interface{}
resi driver.Result
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.resi, r.err = execer.Exec(query, dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
resi, err := execer.Exec(query, dargs)
if err == nil {
select {
default:
case <-ctx.Done():
return resi, ctx.Err()
}
return r.resi, r.err
}
return resi, err
}
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx, is := queryer.(driver.QueryerContext); is {
return queryerCtx.QueryContext(ctx, query, nvdargs)
ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
return ret, err
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
if ctx.Done() == context.Background().Done() {
return queryer.Query(query, dargs)
}
type R struct {
err error
panic interface{}
rowsi driver.Rows
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.rowsi, r.err = queryer.Query(query, dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
rowsi, err := queryer.Query(query, dargs)
if err == nil {
select {
default:
case <-ctx.Done():
rowsi.Close()
return nil, ctx.Err()
}
return r.rowsi, r.err
}
return rowsi, err
}
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
......@@ -146,40 +76,16 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam
if err != nil {
return nil, err
}
if ctx.Done() == context.Background().Done() {
return si.Exec(dargs)
}
type R struct {
err error
panic interface{}
resi driver.Result
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.resi, r.err = si.Exec(dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
resi, err := si.Exec(dargs)
if err == nil {
select {
default:
case <-ctx.Done():
return resi, ctx.Err()
}
return r.resi, r.err
}
return resi, err
}
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
......@@ -190,40 +96,17 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na
if err != nil {
return nil, err
}
if ctx.Done() == context.Background().Done() {
return si.Query(dargs)
}
type R struct {
err error
panic interface{}
rowsi driver.Rows
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.rowsi, r.err = si.Query(dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
rowsi, err := si.Query(dargs)
if err == nil {
select {
default:
case <-ctx.Done():
rowsi.Close()
return nil, ctx.Err()
}
return r.rowsi, r.err
}
return rowsi, err
}
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
......@@ -249,35 +132,16 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
return nil, errors.New("sql: driver does not support read-only transactions")
}
type R struct {
err error
panic interface{}
txi driver.Tx
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.txi, r.err = ci.Begin()
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
txi, err := ci.Begin()
if err == nil {
select {
default:
case <-ctx.Done():
txi.Rollback()
return nil, ctx.Err()
}
return r.txi, r.err
}
return txi, err
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
......
......@@ -87,12 +87,21 @@ type Pinger interface {
// statement.
//
// Exec may return ErrSkip.
//
// Deprecated: Drivers should implement ExecerContext instead (or additionally).
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
// ExecerContext is like execer, but must honor the context timeout and return
// when the context is cancelled.
// ExecerContext is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement ExecerContext, the sql package's DB.Exec will
// first prepare a query, execute the statement, and then close the
// statement.
//
// ExecerContext may return ErrSkip.
//
// ExecerContext must honor the context timeout and return when the context is canceled.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
}
......@@ -104,12 +113,21 @@ type ExecerContext interface {
// statement.
//
// Query may return ErrSkip.
//
// Deprecated: Drivers should implement QueryerContext instead (or additionally).
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
// QueryerContext is like Queryer, but most honor the context timeout and return
// when the context is cancelled.
// QueryerContext is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement QueryerContext, the sql package's DB.Query will
// first prepare a query, execute the statement, and then close the
// statement.
//
// QueryerContext may return ErrSkip.
//
// QueryerContext must honor the context timeout and return when the context is canceled.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
}
......@@ -133,6 +151,8 @@ type Conn interface {
Close() error
// Begin starts and returns a new transaction.
//
// Deprecated: Drivers should implement ConnBeginContext instead (or additionally).
Begin() (Tx, error)
}
......@@ -167,8 +187,8 @@ func ReadOnlyFromContext(ctx context.Context) (readonly bool) {
// ConnBeginContext enhances the Conn interface with context.
type ConnBeginContext interface {
// BeginContext starts and returns a new transaction.
// The provided context should be used to roll the transaction back
// if it is cancelled.
// If the context is canceled by the user the sql package will
// call Tx.Rollback before discarding and closing the connection.
//
// This must call IsolationFromContext to determine if there is a set
// isolation level. If the driver does not support setting the isolation
......@@ -215,22 +235,32 @@ type Stmt interface {
// Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a
// SELECT.
//
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error)
}
// StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface {
// ExecContext must honor the context timeout and return when it is cancelled.
// ExecContext executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// ExecContext must honor the context timeout and return when it is canceled.
ExecContext(ctx context.Context, args []NamedValue) (Result, error)
}
// StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface {
// QueryContext must honor the context timeout and return when it is cancelled.
// QueryContext executes a query that may return rows, such as a
// SELECT.
//
// QueryContext must honor the context timeout and return when it is canceled.
QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}
......
......@@ -39,6 +39,9 @@ var _ = log.Printf
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be proceeded by WAIT|<duration>|, to cause the
// named method on fakeStmt to sleep for the specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
......@@ -119,6 +122,7 @@ type fakeStmt struct {
cmd string
table string
panic string
wait time.Duration
next *fakeStmt // used for returning multiple results.
......@@ -526,14 +530,28 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
if firstStmt == nil {
firstStmt = stmt
}
if len(parts) >= 3 && parts[0] == "PANIC" {
stmt.panic = parts[1]
parts = parts[2:]
if len(parts) >= 3 {
switch parts[0] {
case "PANIC":
stmt.panic = parts[1]
parts = parts[2:]
case "WAIT":
wait, err := time.ParseDuration(parts[1])
if err != nil {
return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
}
parts = parts[2:]
stmt.wait = wait
}
}
cmd := parts[0]
stmt.cmd = cmd
parts = parts[1:]
if stmt.wait > 0 {
time.Sleep(stmt.wait)
}
c.incrStat(&c.stmtsMade)
var err error
switch cmd {
......@@ -619,6 +637,16 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
return nil, err
}
if s.wait > 0 {
time.Sleep(s.wait)
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
db := s.c.db
switch s.cmd {
case "WIPE":
......
......@@ -875,9 +875,11 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
return nil, errDBClosed
}
// Check if the context is expired.
if err := ctx.Err(); err != nil {
select {
default:
case <-ctx.Done():
db.mu.Unlock()
return nil, err
return nil, ctx.Err()
}
lifetime := db.maxLifetime
......@@ -1288,6 +1290,11 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
// BeginContext starts a transaction.
//
// The provided context is used until the transaction is committed or rolled back.
// If the context is canceled, the sql package will roll back
// the transaction. Tx.Commit will return an error if the context provided to
// BeginContext is canceled.
//
// An isolation level may be set by setting the value in the context
// before calling this. If a non-default isolation level is used
// that the driver doesn't support an error will be returned. Different drivers
......@@ -1335,15 +1342,18 @@ func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, er
dc: dc,
txi: txi,
cancel: cancel,
ctx: ctx,
}
go func() {
go func(tx *Tx) {
select {
case <-ctx.Done():
if !tx.done {
tx.Rollback()
case <-tx.ctx.Done():
if !tx.isDone() {
// Discard and close the connection used to ensure the transaction
// is closed and the resources are released.
tx.rollback(true)
}
}
}()
}(tx)
return tx, nil
}
......@@ -1370,10 +1380,11 @@ type Tx struct {
dc *driverConn
txi driver.Tx
// done transitions from false to true exactly once, on Commit
// done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
done bool
// Use atomic operations on value when checking value.
done int32
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
......@@ -1384,6 +1395,13 @@ type Tx struct {
// cancel is called after done transitions from false to true.
cancel func()
// ctx lives for the life of the transaction.
ctx context.Context
}
func (tx *Tx) isDone() bool {
return atomic.LoadInt32(&tx.done) != 0
}
// ErrTxDone is returned by any operation that is performed on a transaction
......@@ -1391,10 +1409,9 @@ type Tx struct {
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
func (tx *Tx) close(err error) {
if tx.done {
if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
panic("double close") // internal error
}
tx.done = true
tx.db.putConn(tx.dc, err)
tx.cancel()
tx.dc = nil
......@@ -1402,7 +1419,7 @@ func (tx *Tx) close(err error) {
}
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
if tx.done {
if tx.isDone() {
return nil, ErrTxDone
}
return tx.dc, nil
......@@ -1419,7 +1436,12 @@ func (tx *Tx) closePrepared() {
// Commit commits the transaction.
func (tx *Tx) Commit() error {
if tx.done {
select {
default:
case <-tx.ctx.Done():
return tx.ctx.Err()
}
if tx.isDone() {
return ErrTxDone
}
var err error
......@@ -1433,9 +1455,10 @@ func (tx *Tx) Commit() error {
return err
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() error {
if tx.done {
// rollback aborts the transaction and optionally forces the pool to discard
// the connection.
func (tx *Tx) rollback(discardConn bool) error {
if tx.isDone() {
return ErrTxDone
}
var err error
......@@ -1445,10 +1468,18 @@ func (tx *Tx) Rollback() error {
if err != driver.ErrBadConn {
tx.closePrepared()
}
if discardConn {
err = driver.ErrBadConn
}
tx.close(err)
return err
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() error {
return tx.rollback(false)
}
// Prepare creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
......@@ -1480,7 +1511,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var si driver.Stmt
withLock(dc, func() {
si, err = dc.ci.Prepare(query)
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
......@@ -1538,7 +1569,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
}
var si driver.Stmt
withLock(dc, func() {
si, err = dc.ci.Prepare(stmt.query)
si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
})
txs := &Stmt{
db: tx.db,
......
......@@ -141,10 +141,7 @@ func closeDB(t testing.TB, db *DB) {
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
db.mu.Lock()
count := db.numOpen
db.mu.Unlock()
if count != 0 {
if count := db.numOpenConns(); count != 0 {
t.Fatalf("%d connections still open after closing DB", count)
}
}
......@@ -183,6 +180,12 @@ func (db *DB) numFreeConns() int {
return len(db.freeConn)
}
func (db *DB) numOpenConns() int {
db.mu.Lock()
defer db.mu.Unlock()
return db.numOpen
}
// clearAllConns closes all connections in db.
func (db *DB) clearAllConns(t *testing.T) {
db.SetMaxIdleConns(0)
......@@ -320,6 +323,75 @@ func TestQueryContext(t *testing.T) {
}
}
func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
deadline := time.Now().Add(waitFor)
for time.Now().Before(deadline) {
if fn() {
return true
}
time.Sleep(checkEvery)
}
return false
}
func TestQueryContextWait(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
prepares0 := numPrepares(t, db)
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
// This will trigger the *fakeConn.Prepare method which will take time
// performing the query. The ctxDriverPrepare func will check the context
// after this and close the rows and return an error.
_, err := db.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|")
if err != context.DeadlineExceeded {
t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
}
// Verify closed rows connection after error condition.
if n := db.numFreeConns(); n != 1 {
t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
func TestTxContextWait(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
tx, err := db.BeginContext(ctx)
if err != nil {
t.Fatal(err)
}
// This will trigger the *fakeConn.Prepare method which will take time
// performing the query. The ctxDriverPrepare func will check the context
// after this and close the rows and return an error.
_, err = tx.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|")
if err != context.DeadlineExceeded {
t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
}
var numFree int
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
numFree = db.numFreeConns()
return numFree == 0
}) {
t.Fatalf("free conns after hitting EOF = %d; want 0", numFree)
}
// Ensure the dropped connection allows more connections to be made.
// Checked on DB Close.
waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
return db.numOpenConns() == 0
})
}
func TestMultiResultSetQuery(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment