Commit e13df02e authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: add context methods

Add context methods to sql and sql/driver methods. If
the driver doesn't implement context methods the connection
pool will still handle timeouts when a query fails to return
in time or when a connection is not available from the pool
in time.

There will be a follow-up CL that will add support for
context values that specify transaction levels and modes
that a driver can use.

Fixes #15123

Change-Id: Ia99f3957aa3f177b23044dd99d4ec217491a30a7
Reviewed-on: https://go-review.googlesource.com/29381Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 54a72d90
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sql
import (
"context"
"database/sql/driver"
"errors"
)
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
type R struct {
err error
panic interface{}
si driver.Stmt
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.si, r.err = ci.Prepare(query)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
}
return r.si, r.err
}
}
func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) {
if execerCtx, is := execer.(driver.ExecerContext); is {
return execerCtx.ExecContext(ctx, query, dargs)
}
type R struct {
err error
panic interface{}
resi driver.Result
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.resi, r.err = execer.Exec(query, dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
}
return r.resi, r.err
}
}
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) {
if queryerCtx, is := queryer.(driver.QueryerContext); is {
return queryerCtx.QueryContext(ctx, query, dargs)
}
type R struct {
err error
panic interface{}
rowsi driver.Rows
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.rowsi, r.err = queryer.Query(query, dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
}
return r.rowsi, r.err
}
}
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, dargs)
}
type R struct {
err error
panic interface{}
resi driver.Result
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.resi, r.err = si.Exec(dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
}
return r.resi, r.err
}
}
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, dargs)
}
type R struct {
err error
panic interface{}
rowsi driver.Rows
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.rowsi, r.err = si.Query(dargs)
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
}
return r.rowsi, r.err
}
}
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
if ciCtx, is := ci.(driver.ConnBeginContext); is {
return ciCtx.BeginContext(ctx)
}
// TODO(kardianos): check the transaction level in ctx. If set and non-default
// then return an error here as the BeginContext driver value is not supported.
type R struct {
err error
panic interface{}
txi driver.Tx
}
rc := make(chan R, 1)
go func() {
r := R{}
defer func() {
if v := recover(); v != nil {
r.panic = v
}
rc <- r
}()
r.txi, r.err = ci.Begin()
}()
select {
case <-ctx.Done():
go func() {
<-rc
close(rc)
}()
return nil, ctx.Err()
case r := <-rc:
if r.panic != nil {
panic(r.panic)
}
return r.txi, r.err
}
}
......@@ -8,7 +8,10 @@
// Most code should use package sql.
package driver
import "errors"
import (
"context"
"errors"
)
// Value is a value that drivers must be able to handle.
// It is either nil or an instance of one of these types:
......@@ -65,6 +68,12 @@ type Execer interface {
Exec(query string, args []Value) (Result, error)
}
// ExecerContext is like execer, but must honor the context timeout and return
// when the context is cancelled.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []Value) (Result, error)
}
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Queryer, the sql package's DB.Query will
......@@ -76,6 +85,12 @@ type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
// QueryerContext is like Queryer, but most honor the context timeout and return
// when the context is cancelled.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []Value) (Rows, error)
}
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
......@@ -98,6 +113,23 @@ type Conn interface {
Begin() (Tx, error)
}
// ConnPrepareContext enhances the Conn interface with context.
type ConnPrepareContext interface {
// PrepareContext returns a prepared statement, bound to this connection.
// context is for the preparation of the statement,
// it must not store the context within the statement itself.
PrepareContext(ctx context.Context, query string) (Stmt, error)
}
// ConnBeginContext enhances the Conn interface with context.
type ConnBeginContext interface {
// BeginContext starts and returns a new transaction.
// the provided context should be used to roll the transaction back
// if it is cancelled. If there is an isolation level in context
// that is not supported by the driver an error must be returned.
BeginContext(ctx context.Context) (Tx, error)
}
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
......@@ -139,6 +171,18 @@ type Stmt interface {
Query(args []Value) (Rows, error)
}
// StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface {
// ExecContext must honor the context timeout and return when it is cancelled.
ExecContext(ctx context.Context, args []Value) (Result, error)
}
// StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface {
// QueryContext must honor the context timeout and return when it is cancelled.
QueryContext(ctx context.Context, args []Value) (Rows, error)
}
// ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from
// any type to a driver Value.
......
......@@ -13,6 +13,7 @@
package sql
import (
"context"
"database/sql/driver"
"errors"
"fmt"
......@@ -297,8 +298,8 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
si, err := dc.ci.Prepare(query)
func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err == nil {
// Track each driverConn's open statements, so we can close them
// before closing the conn.
......@@ -494,13 +495,13 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return db, nil
}
// Ping verifies a connection to the database is still alive,
// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (db *DB) Ping() error {
func (db *DB) PingContext(ctx context.Context) error {
// TODO(bradfitz): give drivers an optional hook to implement
// this in a more efficient or more reliable way, if they
// have one.
dc, err := db.conn(cachedOrNewConn)
dc, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
return err
}
......@@ -508,6 +509,12 @@ func (db *DB) Ping() error {
return nil
}
// Ping verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (db *DB) Ping() error {
return db.PingContext(context.Background())
}
// Close closes the database, releasing any open resources.
//
// It is rare to Close a DB, as the DB handle is meant to be
......@@ -777,12 +784,16 @@ type connRequest struct {
var errDBClosed = errors.New("sql: database is closed")
// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
// Check if the context is expired.
if err := ctx.Err(); err != nil {
return nil, err
}
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
......@@ -808,7 +819,12 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
req := make(chan connRequest, 1)
db.connRequests = append(db.connRequests, req)
db.mu.Unlock()
ret, ok := <-req
// Timeout the connection request with the context.
select {
case <-ctx.Done():
return nil, ctx.Err()
case ret, ok := <-req:
if !ok {
return nil, errDBClosed
}
......@@ -818,6 +834,7 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
}
return ret.conn, ret.err
}
}
db.numOpen++ // optimistically
db.mu.Unlock()
......@@ -952,40 +969,51 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
// connection to be opened.
const maxBadConnRetries = 2
// Prepare creates a prepared statement for later queries or executions.
// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
func (db *DB) Prepare(query string) (*Stmt, error) {
// Context is for the preparation of the statment, not for the execution of
// the statement.
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
var err error
for i := 0; i < maxBadConnRetries; i++ {
stmt, err = db.prepare(query, cachedOrNewConn)
stmt, err = db.prepare(ctx, query, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
return db.prepare(query, alwaysNewConn)
return db.prepare(ctx, query, alwaysNewConn)
}
return stmt, err
}
func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.PrepareContext(context.Background(), query)
}
func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
dc, err := db.conn(strategy)
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
var si driver.Stmt
withLock(dc, func() {
si, err = dc.prepareLocked(query)
si, err = dc.prepareLocked(ctx, query)
})
if err != nil {
db.putConn(dc, err)
......@@ -1002,25 +1030,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
return stmt, nil
}
// Exec executes a query without returning any rows.
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
var res Result
var err error
for i := 0; i < maxBadConnRetries; i++ {
res, err = db.exec(query, args, cachedOrNewConn)
res, err = db.exec(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
return db.exec(query, args, alwaysNewConn)
return db.exec(ctx, query, args, alwaysNewConn)
}
return res, err
}
func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
dc, err := db.conn(strategy)
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
return db.ExecContext(context.Background(), query, args...)
}
func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
......@@ -1036,7 +1070,7 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
}
var resi driver.Result
withLock(dc, func() {
resi, err = execer.Exec(query, dargs)
resi, err = ctxDriverExec(ctx, execer, query, dargs)
})
if err != driver.ErrSkip {
if err != nil {
......@@ -1048,44 +1082,50 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
var si driver.Stmt
withLock(dc, func() {
si, err = dc.ci.Prepare(query)
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
}
defer withLock(dc, func() { si.Close() })
return resultFromStatement(driverStmt{dc, si}, args...)
return resultFromStatement(ctx, driverStmt{dc, si}, args...)
}
// Query executes a query that returns rows, typically a SELECT.
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
var rows *Rows
var err error
for i := 0; i < maxBadConnRetries; i++ {
rows, err = db.query(query, args, cachedOrNewConn)
rows, err = db.query(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
return db.query(query, args, alwaysNewConn)
return db.query(ctx, query, args, alwaysNewConn)
}
return rows, err
}
func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
ci, err := db.conn(strategy)
// Query executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
ci, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.queryConn(ci, ci.releaseConn, query, args)
return db.queryConn(ctx, ci, ci.releaseConn, query, args)
}
// queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function.
func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
......@@ -1094,7 +1134,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
}
var rowsi driver.Rows
withLock(dc, func() {
rowsi, err = queryer.Query(query, dargs)
rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
})
if err != driver.ErrSkip {
if err != nil {
......@@ -1115,7 +1155,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
var si driver.Stmt
var err error
withLock(dc, func() {
si, err = dc.ci.Prepare(query)
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
releaseConn(err)
......@@ -1123,7 +1163,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
}
ds := driverStmt{dc, si}
rowsi, err := rowsiFromStatement(ds, args...)
rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
withLock(dc, func() {
si.Close()
......@@ -1143,49 +1183,77 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
return rows, nil
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
rows, err := db.Query(query, args...)
return &Row{rows: rows, err: err}
return db.QueryRowContext(context.Background(), query, args...)
}
// Begin starts a transaction. The isolation level is dependent on
// the driver.
func (db *DB) Begin() (*Tx, error) {
// BeginContext starts a transaction. If a non-default isolation level is used
// that the driver doesn't support an error will be returned. Different drivers
// may have slightly different meanings for the same isolation level.
func (db *DB) BeginContext(ctx context.Context) (*Tx, error) {
var tx *Tx
var err error
for i := 0; i < maxBadConnRetries; i++ {
tx, err = db.begin(cachedOrNewConn)
tx, err = db.begin(ctx, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
return db.begin(alwaysNewConn)
return db.begin(ctx, alwaysNewConn)
}
return tx, err
}
func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
dc, err := db.conn(strategy)
// Begin starts a transaction. The default isolation level is dependent on
// the driver.
func (db *DB) Begin() (*Tx, error) {
return db.BeginContext(context.Background())
}
func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, err error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
var txi driver.Tx
withLock(dc, func() {
txi, err = dc.ci.Begin()
txi, err = ctxDriverBegin(ctx, dc.ci)
})
if err != nil {
db.putConn(dc, err)
return nil, err
}
return &Tx{
// Schedule the transaction to rollback when the context is cancelled.
// The cancel function in Tx will be called after done is set to true.
ctx, cancel := context.WithCancel(ctx)
tx = &Tx{
db: db,
dc: dc,
txi: txi,
}, nil
cancel: cancel,
}
go func() {
select {
case <-ctx.Done():
if !tx.done {
tx.Rollback()
}
}
}()
return tx, nil
}
// Driver returns the database's underlying driver.
......@@ -1222,6 +1290,9 @@ type Tx struct {
sync.Mutex
v []*Stmt
}
// cancel is called after done transitions from false to true.
cancel func()
}
// ErrTxDone is returned by any operation that is performed on a transaction
......@@ -1234,11 +1305,12 @@ func (tx *Tx) close(err error) {
}
tx.done = true
tx.db.putConn(tx.dc, err)
tx.cancel()
tx.dc = nil
tx.txi = nil
}
func (tx *Tx) grabConn() (*driverConn, error) {
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
if tx.done {
return nil, ErrTxDone
}
......@@ -1292,7 +1364,10 @@ func (tx *Tx) Rollback() error {
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
// Context will be used for the preparation of the context, not
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
......@@ -1306,7 +1381,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
dc, err := tx.grabConn()
dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
......@@ -1334,7 +1409,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return stmt, nil
}
// Stmt returns a transaction-specific prepared statement from
// Prepare creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
return tx.PrepareContext(context.Background(), query)
}
// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
......@@ -1342,11 +1427,11 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
// The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
......@@ -1355,7 +1440,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
dc, err := tx.grabConn()
dc, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
......@@ -1379,10 +1464,26 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return txs
}
// Exec executes a query that doesn't return rows.
// Stmt returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
//
// The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return tx.StmtContext(context.Background(), stmt)
}
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
dc, err := tx.grabConn()
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
......@@ -1413,25 +1514,43 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
}
defer withLock(dc, func() { si.Close() })
return resultFromStatement(driverStmt{dc, si}, args...)
return resultFromStatement(ctx, driverStmt{dc, si}, args...)
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
dc, err := tx.grabConn()
// Exec executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
return tx.ExecContext(context.Background(), query, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
releaseConn := func(error) {}
return tx.db.queryConn(dc, releaseConn, query, args)
return tx.db.queryConn(ctx, dc, releaseConn, query, args)
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
return tx.QueryContext(context.Background(), query, args...)
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
rows, err := tx.Query(query, args...)
return &Row{rows: rows, err: err}
return tx.QueryRowContext(context.Background(), query, args...)
}
// connStmt is a prepared statement on a particular connection.
......@@ -1468,15 +1587,15 @@ type Stmt struct {
lastNumClosed uint64
}
// Exec executes a prepared statement with the given arguments and
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
for i := 0; i < maxBadConnRetries; i++ {
dc, releaseConn, si, err := s.connStmt()
dc, releaseConn, si, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
......@@ -1484,7 +1603,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, err
}
res, err = resultFromStatement(driverStmt{dc, si}, args...)
res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
......@@ -1493,13 +1612,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, driver.ErrBadConn
}
// Exec executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return s.ExecContext(context.Background(), args...)
}
func driverNumInput(ds driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
......@@ -1516,7 +1641,8 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
ds.Lock()
defer ds.Unlock()
resi, err := ds.si.Exec(dargs)
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
......@@ -1552,7 +1678,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
if err = s.stickyErr; err != nil {
return
}
......@@ -1567,7 +1693,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
// transaction was created on.
if s.tx != nil {
s.mu.Unlock()
ci, err = s.tx.grabConn() // blocks, waiting for the connection.
ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
......@@ -1578,8 +1704,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
s.removeClosedStmtLocked()
s.mu.Unlock()
// TODO(bradfitz): or always wait for one? make configurable later?
dc, err := s.db.conn(cachedOrNewConn)
dc, err := s.db.conn(ctx, cachedOrNewConn)
if err != nil {
return nil, nil, nil, err
}
......@@ -1595,7 +1720,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
si, err = dc.prepareLocked(s.query)
si, err = dc.prepareLocked(ctx, s.query)
})
if err != nil {
s.db.putConn(dc, err)
......@@ -1609,15 +1734,15 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
return dc, dc.releaseConn, si, nil
}
// Query executes a prepared query statement with the given arguments
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
dc, releaseConn, si, err := s.connStmt()
dc, releaseConn, si, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
......@@ -1625,7 +1750,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return nil, err
}
rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
......@@ -1650,7 +1775,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return nil, driver.ErrBadConn
}
func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
// Query executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
......@@ -1670,14 +1801,15 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
ds.Lock()
defer ds.Unlock()
rowsi, err := ds.si.Query(dargs)
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return rowsi, nil
}
// QueryRow executes a prepared query statement with the given arguments.
// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
......@@ -1687,15 +1819,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
// Example usage:
//
// var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
func (s *Stmt) QueryRow(args ...interface{}) *Row {
rows, err := s.Query(args...)
// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
rows, err := s.QueryContext(ctx, args...)
if err != nil {
return &Row{err: err}
}
return &Row{rows: rows}
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
//
// Example usage:
//
// var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
func (s *Stmt) QueryRow(args ...interface{}) *Row {
return s.QueryRowContext(context.Background(), args...)
}
// Close closes the statement.
func (s *Stmt) Close() error {
s.closemu.Lock()
......
......@@ -5,6 +5,7 @@
package sql
import (
"context"
"database/sql/driver"
"errors"
"fmt"
......@@ -1159,17 +1160,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) {
db.SetMaxOpenConns(3)
conn0, err := db.conn(cachedOrNewConn)
ctx := context.Background()
conn0, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
conn1, err := db.conn(cachedOrNewConn)
conn1, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
conn2, err := db.conn(cachedOrNewConn)
conn2, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
......
......@@ -228,8 +228,8 @@ var pkgDeps = map[string][]string{
"compress/lzw": {"L4"},
"compress/zlib": {"L4", "compress/flate"},
"context": {"errors", "fmt", "reflect", "sync", "time"},
"database/sql": {"L4", "container/list", "database/sql/driver"},
"database/sql/driver": {"L4", "time"},
"database/sql": {"L4", "container/list", "context", "database/sql/driver"},
"database/sql/driver": {"L4", "context", "time"},
"debug/dwarf": {"L4"},
"debug/elf": {"L4", "OS", "debug/dwarf", "compress/zlib"},
"debug/gosym": {"L4"},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment