Commit 0d163ce1 authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: do not bypass the driver locks with Context methods

When context methods were initially added it was attempted to unify
behavior between drivers without Context methods and those with
Context methods to always return right away when the Context expired.
However in doing so the driver call could be executed outside of the
scope of the driver connection lock and thus bypassing thread safety.

The new behavior waits until the driver operation is complete. It then
checks to see if the context has expired and if so returns that error.

Change-Id: I4a5c7c3263420c57778f36a5ed6fa0ef8cb32b20
Reviewed-on: https://go-review.googlesource.com/32422Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 3825656e
...@@ -14,40 +14,16 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver ...@@ -14,40 +14,16 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
if ciCtx, is := ci.(driver.ConnPrepareContext); is { if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query) return ciCtx.PrepareContext(ctx, query)
} }
if ctx.Done() == context.Background().Done() { si, err := ci.Prepare(query)
return ci.Prepare(query) if err == nil {
} select {
default:
type R struct { case <-ctx.Done():
err error si.Close()
panic interface{} return nil, ctx.Err()
si driver.Stmt
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.si, r.err = ci.Prepare(query)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
} }
return r.si, r.err
} }
return si, err
} }
func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
...@@ -58,84 +34,38 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda ...@@ -58,84 +34,38 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ctx.Done() == context.Background().Done() {
return execer.Exec(query, dargs)
}
type R struct {
err error
panic interface{}
resi driver.Result
}
rc := make(chan R, 1) resi, err := execer.Exec(query, dargs)
go func() { if err == nil {
r := R{} select {
defer func() { default:
if v := recover(); v != nil { case <-ctx.Done():
r.panic = v return resi, ctx.Err()
}
rc <- r
}()
r.resi, r.err = execer.Exec(query, dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
} }
return r.resi, r.err
} }
return resi, err
} }
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx, is := queryer.(driver.QueryerContext); is { if queryerCtx, is := queryer.(driver.QueryerContext); is {
return queryerCtx.QueryContext(ctx, query, nvdargs) ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
return ret, err
} }
dargs, err := namedValueToValue(nvdargs) dargs, err := namedValueToValue(nvdargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ctx.Done() == context.Background().Done() {
return queryer.Query(query, dargs)
}
type R struct { rowsi, err := queryer.Query(query, dargs)
err error if err == nil {
panic interface{} select {
rowsi driver.Rows default:
} case <-ctx.Done():
rowsi.Close()
rc := make(chan R, 1) return nil, ctx.Err()
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.rowsi, r.err = queryer.Query(query, dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
} }
return r.rowsi, r.err
} }
return rowsi, err
} }
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
...@@ -146,40 +76,16 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam ...@@ -146,40 +76,16 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ctx.Done() == context.Background().Done() {
return si.Exec(dargs)
}
type R struct {
err error
panic interface{}
resi driver.Result
}
rc := make(chan R, 1) resi, err := si.Exec(dargs)
go func() { if err == nil {
r := R{} select {
defer func() { default:
if v := recover(); v != nil { case <-ctx.Done():
r.panic = v return resi, ctx.Err()
}
rc <- r
}()
r.resi, r.err = si.Exec(dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
} }
return r.resi, r.err
} }
return resi, err
} }
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
...@@ -190,40 +96,17 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na ...@@ -190,40 +96,17 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ctx.Done() == context.Background().Done() {
return si.Query(dargs)
}
type R struct { rowsi, err := si.Query(dargs)
err error if err == nil {
panic interface{} select {
rowsi driver.Rows default:
} case <-ctx.Done():
rowsi.Close()
rc := make(chan R, 1) return nil, ctx.Err()
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.rowsi, r.err = si.Query(dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
} }
return r.rowsi, r.err
} }
return rowsi, err
} }
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported") var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
...@@ -249,35 +132,16 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { ...@@ -249,35 +132,16 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
return nil, errors.New("sql: driver does not support read-only transactions") return nil, errors.New("sql: driver does not support read-only transactions")
} }
type R struct { txi, err := ci.Begin()
err error if err == nil {
panic interface{} select {
txi driver.Tx default:
} case <-ctx.Done():
rc := make(chan R, 1) txi.Rollback()
go func() { return nil, ctx.Err()
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.txi, r.err = ci.Begin()
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
} }
return r.txi, r.err
} }
return txi, err
} }
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
......
...@@ -87,12 +87,21 @@ type Pinger interface { ...@@ -87,12 +87,21 @@ type Pinger interface {
// statement. // statement.
// //
// Exec may return ErrSkip. // Exec may return ErrSkip.
//
// Deprecated: Drivers should implement ExecerContext instead (or additionally).
type Execer interface { type Execer interface {
Exec(query string, args []Value) (Result, error) Exec(query string, args []Value) (Result, error)
} }
// ExecerContext is like execer, but must honor the context timeout and return // ExecerContext is an optional interface that may be implemented by a Conn.
// when the context is cancelled. //
// If a Conn does not implement ExecerContext, the sql package's DB.Exec will
// first prepare a query, execute the statement, and then close the
// statement.
//
// ExecerContext may return ErrSkip.
//
// ExecerContext must honor the context timeout and return when the context is canceled.
type ExecerContext interface { type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error) ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
} }
...@@ -104,12 +113,21 @@ type ExecerContext interface { ...@@ -104,12 +113,21 @@ type ExecerContext interface {
// statement. // statement.
// //
// Query may return ErrSkip. // Query may return ErrSkip.
//
// Deprecated: Drivers should implement QueryerContext instead (or additionally).
type Queryer interface { type Queryer interface {
Query(query string, args []Value) (Rows, error) Query(query string, args []Value) (Rows, error)
} }
// QueryerContext is like Queryer, but most honor the context timeout and return // QueryerContext is an optional interface that may be implemented by a Conn.
// when the context is cancelled. //
// If a Conn does not implement QueryerContext, the sql package's DB.Query will
// first prepare a query, execute the statement, and then close the
// statement.
//
// QueryerContext may return ErrSkip.
//
// QueryerContext must honor the context timeout and return when the context is canceled.
type QueryerContext interface { type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error) QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
} }
...@@ -133,6 +151,8 @@ type Conn interface { ...@@ -133,6 +151,8 @@ type Conn interface {
Close() error Close() error
// Begin starts and returns a new transaction. // Begin starts and returns a new transaction.
//
// Deprecated: Drivers should implement ConnBeginContext instead (or additionally).
Begin() (Tx, error) Begin() (Tx, error)
} }
...@@ -167,8 +187,8 @@ func ReadOnlyFromContext(ctx context.Context) (readonly bool) { ...@@ -167,8 +187,8 @@ func ReadOnlyFromContext(ctx context.Context) (readonly bool) {
// ConnBeginContext enhances the Conn interface with context. // ConnBeginContext enhances the Conn interface with context.
type ConnBeginContext interface { type ConnBeginContext interface {
// BeginContext starts and returns a new transaction. // BeginContext starts and returns a new transaction.
// The provided context should be used to roll the transaction back // If the context is canceled by the user the sql package will
// if it is cancelled. // call Tx.Rollback before discarding and closing the connection.
// //
// This must call IsolationFromContext to determine if there is a set // This must call IsolationFromContext to determine if there is a set
// isolation level. If the driver does not support setting the isolation // isolation level. If the driver does not support setting the isolation
...@@ -215,22 +235,32 @@ type Stmt interface { ...@@ -215,22 +235,32 @@ type Stmt interface {
// Exec executes a query that doesn't return rows, such // Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE. // as an INSERT or UPDATE.
//
// Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error) Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a // Query executes a query that may return rows, such as a
// SELECT. // SELECT.
//
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error) Query(args []Value) (Rows, error)
} }
// StmtExecContext enhances the Stmt interface by providing Exec with context. // StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface { type StmtExecContext interface {
// ExecContext must honor the context timeout and return when it is cancelled. // ExecContext executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// ExecContext must honor the context timeout and return when it is canceled.
ExecContext(ctx context.Context, args []NamedValue) (Result, error) ExecContext(ctx context.Context, args []NamedValue) (Result, error)
} }
// StmtQueryContext enhances the Stmt interface by providing Query with context. // StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface { type StmtQueryContext interface {
// QueryContext must honor the context timeout and return when it is cancelled. // QueryContext executes a query that may return rows, such as a
// SELECT.
//
// QueryContext must honor the context timeout and return when it is canceled.
QueryContext(ctx context.Context, args []NamedValue) (Rows, error) QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
} }
......
...@@ -39,6 +39,9 @@ var _ = log.Printf ...@@ -39,6 +39,9 @@ var _ = log.Printf
// Any of these can be preceded by PANIC|<method>|, to cause the // Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic. // named method on fakeStmt to panic.
// //
// Any of these can be proceeded by WAIT|<duration>|, to cause the
// named method on fakeStmt to sleep for the specified duration.
//
// Multiple of these can be combined when separated with a semicolon. // Multiple of these can be combined when separated with a semicolon.
// //
// When opening a fakeDriver's database, it starts empty with no // When opening a fakeDriver's database, it starts empty with no
...@@ -119,6 +122,7 @@ type fakeStmt struct { ...@@ -119,6 +122,7 @@ type fakeStmt struct {
cmd string cmd string
table string table string
panic string panic string
wait time.Duration
next *fakeStmt // used for returning multiple results. next *fakeStmt // used for returning multiple results.
...@@ -526,14 +530,28 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { ...@@ -526,14 +530,28 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
if firstStmt == nil { if firstStmt == nil {
firstStmt = stmt firstStmt = stmt
} }
if len(parts) >= 3 && parts[0] == "PANIC" { if len(parts) >= 3 {
stmt.panic = parts[1] switch parts[0] {
parts = parts[2:] case "PANIC":
stmt.panic = parts[1]
parts = parts[2:]
case "WAIT":
wait, err := time.ParseDuration(parts[1])
if err != nil {
return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
}
parts = parts[2:]
stmt.wait = wait
}
} }
cmd := parts[0] cmd := parts[0]
stmt.cmd = cmd stmt.cmd = cmd
parts = parts[1:] parts = parts[1:]
if stmt.wait > 0 {
time.Sleep(stmt.wait)
}
c.incrStat(&c.stmtsMade) c.incrStat(&c.stmtsMade)
var err error var err error
switch cmd { switch cmd {
...@@ -619,6 +637,16 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d ...@@ -619,6 +637,16 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
return nil, err return nil, err
} }
if s.wait > 0 {
time.Sleep(s.wait)
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
db := s.c.db db := s.c.db
switch s.cmd { switch s.cmd {
case "WIPE": case "WIPE":
......
...@@ -875,9 +875,11 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn ...@@ -875,9 +875,11 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
return nil, errDBClosed return nil, errDBClosed
} }
// Check if the context is expired. // Check if the context is expired.
if err := ctx.Err(); err != nil { select {
default:
case <-ctx.Done():
db.mu.Unlock() db.mu.Unlock()
return nil, err return nil, ctx.Err()
} }
lifetime := db.maxLifetime lifetime := db.maxLifetime
...@@ -1288,6 +1290,11 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row { ...@@ -1288,6 +1290,11 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
// BeginContext starts a transaction. // BeginContext starts a transaction.
// //
// The provided context is used until the transaction is committed or rolled back.
// If the context is canceled, the sql package will roll back
// the transaction. Tx.Commit will return an error if the context provided to
// BeginContext is canceled.
//
// An isolation level may be set by setting the value in the context // An isolation level may be set by setting the value in the context
// before calling this. If a non-default isolation level is used // before calling this. If a non-default isolation level is used
// that the driver doesn't support an error will be returned. Different drivers // that the driver doesn't support an error will be returned. Different drivers
...@@ -1335,15 +1342,18 @@ func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, er ...@@ -1335,15 +1342,18 @@ func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, er
dc: dc, dc: dc,
txi: txi, txi: txi,
cancel: cancel, cancel: cancel,
ctx: ctx,
} }
go func() { go func(tx *Tx) {
select { select {
case <-ctx.Done(): case <-tx.ctx.Done():
if !tx.done { if !tx.isDone() {
tx.Rollback() // Discard and close the connection used to ensure the transaction
// is closed and the resources are released.
tx.rollback(true)
} }
} }
}() }(tx)
return tx, nil return tx, nil
} }
...@@ -1370,10 +1380,11 @@ type Tx struct { ...@@ -1370,10 +1380,11 @@ type Tx struct {
dc *driverConn dc *driverConn
txi driver.Tx txi driver.Tx
// done transitions from false to true exactly once, on Commit // done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with // or Rollback. once done, all operations fail with
// ErrTxDone. // ErrTxDone.
done bool // Use atomic operations on value when checking value.
done int32
// All Stmts prepared for this transaction. These will be closed after the // All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back. // transaction has been committed or rolled back.
...@@ -1384,6 +1395,13 @@ type Tx struct { ...@@ -1384,6 +1395,13 @@ type Tx struct {
// cancel is called after done transitions from false to true. // cancel is called after done transitions from false to true.
cancel func() cancel func()
// ctx lives for the life of the transaction.
ctx context.Context
}
func (tx *Tx) isDone() bool {
return atomic.LoadInt32(&tx.done) != 0
} }
// ErrTxDone is returned by any operation that is performed on a transaction // ErrTxDone is returned by any operation that is performed on a transaction
...@@ -1391,10 +1409,9 @@ type Tx struct { ...@@ -1391,10 +1409,9 @@ type Tx struct {
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
func (tx *Tx) close(err error) { func (tx *Tx) close(err error) {
if tx.done { if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
panic("double close") // internal error panic("double close") // internal error
} }
tx.done = true
tx.db.putConn(tx.dc, err) tx.db.putConn(tx.dc, err)
tx.cancel() tx.cancel()
tx.dc = nil tx.dc = nil
...@@ -1402,7 +1419,7 @@ func (tx *Tx) close(err error) { ...@@ -1402,7 +1419,7 @@ func (tx *Tx) close(err error) {
} }
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
if tx.done { if tx.isDone() {
return nil, ErrTxDone return nil, ErrTxDone
} }
return tx.dc, nil return tx.dc, nil
...@@ -1419,7 +1436,12 @@ func (tx *Tx) closePrepared() { ...@@ -1419,7 +1436,12 @@ func (tx *Tx) closePrepared() {
// Commit commits the transaction. // Commit commits the transaction.
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
if tx.done { select {
default:
case <-tx.ctx.Done():
return tx.ctx.Err()
}
if tx.isDone() {
return ErrTxDone return ErrTxDone
} }
var err error var err error
...@@ -1433,9 +1455,10 @@ func (tx *Tx) Commit() error { ...@@ -1433,9 +1455,10 @@ func (tx *Tx) Commit() error {
return err return err
} }
// Rollback aborts the transaction. // rollback aborts the transaction and optionally forces the pool to discard
func (tx *Tx) Rollback() error { // the connection.
if tx.done { func (tx *Tx) rollback(discardConn bool) error {
if tx.isDone() {
return ErrTxDone return ErrTxDone
} }
var err error var err error
...@@ -1445,10 +1468,18 @@ func (tx *Tx) Rollback() error { ...@@ -1445,10 +1468,18 @@ func (tx *Tx) Rollback() error {
if err != driver.ErrBadConn { if err != driver.ErrBadConn {
tx.closePrepared() tx.closePrepared()
} }
if discardConn {
err = driver.ErrBadConn
}
tx.close(err) tx.close(err)
return err return err
} }
// Rollback aborts the transaction.
func (tx *Tx) Rollback() error {
return tx.rollback(false)
}
// Prepare creates a prepared statement for use within a transaction. // Prepare creates a prepared statement for use within a transaction.
// //
// The returned statement operates within the transaction and will be closed // The returned statement operates within the transaction and will be closed
...@@ -1480,7 +1511,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { ...@@ -1480,7 +1511,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var si driver.Stmt var si driver.Stmt
withLock(dc, func() { withLock(dc, func() {
si, err = dc.ci.Prepare(query) si, err = ctxDriverPrepare(ctx, dc.ci, query)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1538,7 +1569,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { ...@@ -1538,7 +1569,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
} }
var si driver.Stmt var si driver.Stmt
withLock(dc, func() { withLock(dc, func() {
si, err = dc.ci.Prepare(stmt.query) si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
}) })
txs := &Stmt{ txs := &Stmt{
db: tx.db, db: tx.db,
......
...@@ -141,10 +141,7 @@ func closeDB(t testing.TB, db *DB) { ...@@ -141,10 +141,7 @@ func closeDB(t testing.TB, db *DB) {
if err != nil { if err != nil {
t.Fatalf("error closing DB: %v", err) t.Fatalf("error closing DB: %v", err)
} }
db.mu.Lock() if count := db.numOpenConns(); count != 0 {
count := db.numOpen
db.mu.Unlock()
if count != 0 {
t.Fatalf("%d connections still open after closing DB", count) t.Fatalf("%d connections still open after closing DB", count)
} }
} }
...@@ -183,6 +180,12 @@ func (db *DB) numFreeConns() int { ...@@ -183,6 +180,12 @@ func (db *DB) numFreeConns() int {
return len(db.freeConn) return len(db.freeConn)
} }
func (db *DB) numOpenConns() int {
db.mu.Lock()
defer db.mu.Unlock()
return db.numOpen
}
// clearAllConns closes all connections in db. // clearAllConns closes all connections in db.
func (db *DB) clearAllConns(t *testing.T) { func (db *DB) clearAllConns(t *testing.T) {
db.SetMaxIdleConns(0) db.SetMaxIdleConns(0)
...@@ -320,6 +323,75 @@ func TestQueryContext(t *testing.T) { ...@@ -320,6 +323,75 @@ func TestQueryContext(t *testing.T) {
} }
} }
func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
deadline := time.Now().Add(waitFor)
for time.Now().Before(deadline) {
if fn() {
return true
}
time.Sleep(checkEvery)
}
return false
}
func TestQueryContextWait(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
prepares0 := numPrepares(t, db)
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
// This will trigger the *fakeConn.Prepare method which will take time
// performing the query. The ctxDriverPrepare func will check the context
// after this and close the rows and return an error.
_, err := db.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|")
if err != context.DeadlineExceeded {
t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
}
// Verify closed rows connection after error condition.
if n := db.numFreeConns(); n != 1 {
t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
func TestTxContextWait(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
tx, err := db.BeginContext(ctx)
if err != nil {
t.Fatal(err)
}
// This will trigger the *fakeConn.Prepare method which will take time
// performing the query. The ctxDriverPrepare func will check the context
// after this and close the rows and return an error.
_, err = tx.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|")
if err != context.DeadlineExceeded {
t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
}
var numFree int
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
numFree = db.numFreeConns()
return numFree == 0
}) {
t.Fatalf("free conns after hitting EOF = %d; want 0", numFree)
}
// Ensure the dropped connection allows more connections to be made.
// Checked on DB Close.
waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
return db.numOpenConns() == 0
})
}
func TestMultiResultSetQuery(t *testing.T) { func TestMultiResultSetQuery(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment