Commit f28c8fba authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

database/sql: associate a mutex with each driver interface

The database/sql/driver docs make this promise:

   "Conn is a connection to a database. It is not used
   concurrently by multiple goroutines."

That promises exists as part of database/sql's overall
goal of making drivers relatively easy to write.

So far this promise has been kept without the use of locks by
being careful in the database/sql package, but sometimes too
careful. (cf. golang.org/issue/3857)

The CL associates a Mutex with each driver.Conn, and with the
interface value progeny thereof. (e.g. each driver.Tx,
driver.Stmt, driver.Rows, driver.Result, etc) Then whenever
those interface values are used, the Locker is locked.

This CL should be a no-op (aside from some new Lock/Unlock
pairs) and doesn't attempt to fix Issue 3857 or Issue 4459,
but should make it much easier in a subsequent CL.

Update #3857

R=golang-dev, adg
CC=golang-dev
https://golang.org/cl/7803043
parent eb80431b
...@@ -19,9 +19,13 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript ...@@ -19,9 +19,13 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript
// driverArgs converts arguments from callers of Stmt.Exec and // driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values. // Stmt.Query into driver Values.
// //
// The statement si may be nil, if no statement is available. // The statement ds may be nil, if no statement is available.
func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) { func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
dargs := make([]driver.Value, len(args)) dargs := make([]driver.Value, len(args))
var si driver.Stmt
if ds != nil {
si = ds.si
}
cc, ok := si.(driver.ColumnConverter) cc, ok := si.(driver.ColumnConverter)
// Normal path, for a driver.Stmt that is not a ColumnConverter. // Normal path, for a driver.Stmt that is not a ColumnConverter.
...@@ -60,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) { ...@@ -60,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
// column before going across the network to get the // column before going across the network to get the
// same error. // same error.
var err error var err error
ds.Lock()
dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock()
if err != nil { if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err) return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
} }
......
...@@ -191,12 +191,35 @@ type DB struct { ...@@ -191,12 +191,35 @@ type DB struct {
dsn string dsn string
mu sync.Mutex // protects following fields mu sync.Mutex // protects following fields
outConn map[driver.Conn]bool // whether the conn is in use outConn map[*driverConn]bool // whether the conn is in use
freeConn []driver.Conn freeConn []*driverConn
closed bool closed bool
dep map[finalCloser]depSet dep map[finalCloser]depSet
onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned onConnPut map[*driverConn][]func() // code (with mu held) run when conn is next returned
lastPut map[driver.Conn]string // stacktrace of last conn's put; debug only lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
}
// driverConn wraps a driver.Conn with a mutex, to
// be held during all calls into the Conn. (including any calls onto
// interfaces returned via that Conn, such as calls on Tx, Stmt,
// Result, Rows)
type driverConn struct {
sync.Mutex
ci driver.Conn
}
// driverStmt associates a driver.Stmt with the
// *driverConn from which it came, so the driverConn's lock can be
// held during calls.
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
}
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
return ds.si.Close()
} }
// depSet is a finalCloser's outstanding dependencies // depSet is a finalCloser's outstanding dependencies
...@@ -270,9 +293,9 @@ func Open(driverName, dataSourceName string) (*DB, error) { ...@@ -270,9 +293,9 @@ func Open(driverName, dataSourceName string) (*DB, error) {
db := &DB{ db := &DB{
driver: driveri, driver: driveri,
dsn: dataSourceName, dsn: dataSourceName,
outConn: make(map[driver.Conn]bool), outConn: make(map[*driverConn]bool),
lastPut: make(map[driver.Conn]string), lastPut: make(map[*driverConn]string),
onConnPut: make(map[driver.Conn][]func()), onConnPut: make(map[*driverConn][]func()),
} }
return db, nil return db, nil
} }
...@@ -283,11 +306,11 @@ func (db *DB) Ping() error { ...@@ -283,11 +306,11 @@ func (db *DB) Ping() error {
// TODO(bradfitz): give drivers an optional hook to implement // TODO(bradfitz): give drivers an optional hook to implement
// this in a more efficient or more reliable way, if they // this in a more efficient or more reliable way, if they
// have one. // have one.
c, err := db.conn() dc, err := db.conn()
if err != nil { if err != nil {
return err return err
} }
db.putConn(c, nil) db.putConn(dc, nil)
return nil return nil
} }
...@@ -296,8 +319,10 @@ func (db *DB) Close() error { ...@@ -296,8 +319,10 @@ func (db *DB) Close() error {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
var err error var err error
for _, c := range db.freeConn { for _, dc := range db.freeConn {
err1 := c.Close() dc.Lock()
err1 := dc.ci.Close()
dc.Unlock()
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
...@@ -314,8 +339,8 @@ func (db *DB) maxIdleConns() int { ...@@ -314,8 +339,8 @@ func (db *DB) maxIdleConns() int {
return defaultMaxIdleConns return defaultMaxIdleConns
} }
// conn returns a newly-opened or cached driver.Conn // conn returns a newly-opened or cached *driverConn
func (db *DB) conn() (driver.Conn, error) { func (db *DB) conn() (*driverConn, error) {
db.mu.Lock() db.mu.Lock()
if db.closed { if db.closed {
db.mu.Unlock() db.mu.Unlock()
...@@ -329,13 +354,16 @@ func (db *DB) conn() (driver.Conn, error) { ...@@ -329,13 +354,16 @@ func (db *DB) conn() (driver.Conn, error) {
return conn, nil return conn, nil
} }
db.mu.Unlock() db.mu.Unlock()
conn, err := db.driver.Open(db.dsn)
if err == nil { ci, err := db.driver.Open(db.dsn)
db.mu.Lock() if err != nil {
db.outConn[conn] = true return nil, err
db.mu.Unlock()
} }
return conn, err dc := &driverConn{ci: ci}
db.mu.Lock()
db.outConn[dc] = true
db.mu.Unlock()
return dc, nil
} }
// connIfFree returns (wanted, true) if wanted is still a valid conn and // connIfFree returns (wanted, true) if wanted is still a valid conn and
...@@ -343,7 +371,7 @@ func (db *DB) conn() (driver.Conn, error) { ...@@ -343,7 +371,7 @@ func (db *DB) conn() (driver.Conn, error) {
// //
// If wanted is valid but in use, connIfFree returns (wanted, false). // If wanted is valid but in use, connIfFree returns (wanted, false).
// If wanted is invalid, connIfFre returns (nil, false). // If wanted is invalid, connIfFre returns (nil, false).
func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { func (db *DB) connIfFree(wanted *driverConn) (conn *driverConn, ok bool) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
if db.outConn[wanted] { if db.outConn[wanted] {
...@@ -362,12 +390,12 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { ...@@ -362,12 +390,12 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
} }
// putConnHook is a hook for testing. // putConnHook is a hook for testing.
var putConnHook func(*DB, driver.Conn) var putConnHook func(*DB, *driverConn)
// noteUnusedDriverStatement notes that si is no longer used and should // noteUnusedDriverStatement notes that si is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is // be closed whenever possible (when c is next not in use), unless c is
// already closed. // already closed.
func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) { func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
if db.outConn[c] { if db.outConn[c] {
...@@ -385,24 +413,24 @@ const debugGetPut = false ...@@ -385,24 +413,24 @@ const debugGetPut = false
// putConn adds a connection to the db's free pool. // putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection. // err is optionally the last error that occurred on this connection.
func (db *DB) putConn(c driver.Conn, err error) { func (db *DB) putConn(dc *driverConn, err error) {
db.mu.Lock() db.mu.Lock()
if !db.outConn[c] { if !db.outConn[dc] {
if debugGetPut { if debugGetPut {
fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c]) fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
} }
panic("sql: connection returned that was never out") panic("sql: connection returned that was never out")
} }
if debugGetPut { if debugGetPut {
db.lastPut[c] = stack() db.lastPut[dc] = stack()
} }
delete(db.outConn, c) delete(db.outConn, dc)
if fns, ok := db.onConnPut[c]; ok { if fns, ok := db.onConnPut[dc]; ok {
for _, fn := range fns { for _, fn := range fns {
fn() fn()
} }
delete(db.onConnPut, c) delete(db.onConnPut, dc)
} }
if err == driver.ErrBadConn { if err == driver.ErrBadConn {
...@@ -411,17 +439,20 @@ func (db *DB) putConn(c driver.Conn, err error) { ...@@ -411,17 +439,20 @@ func (db *DB) putConn(c driver.Conn, err error) {
return return
} }
if putConnHook != nil { if putConnHook != nil {
putConnHook(db, c) putConnHook(db, dc)
} }
if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
db.freeConn = append(db.freeConn, c) db.freeConn = append(db.freeConn, dc)
db.mu.Unlock() db.mu.Unlock()
return return
} }
// TODO: check to see if we need this Conn for any prepared // TODO: check to see if we need this Conn for any prepared
// statements which are still active? // statements which are still active?
db.mu.Unlock() db.mu.Unlock()
c.Close()
dc.Lock()
dc.ci.Close()
dc.Unlock()
} }
// Prepare creates a prepared statement for later queries or executions. // Prepare creates a prepared statement for later queries or executions.
...@@ -446,22 +477,24 @@ func (db *DB) prepare(query string) (*Stmt, error) { ...@@ -446,22 +477,24 @@ func (db *DB) prepare(query string) (*Stmt, error) {
// to a connection, and to execute this prepared statement // to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else // we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one. // get a new connection + re-prepare + execute on that one.
ci, err := db.conn() dc, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
si, err := ci.Prepare(query) dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
if err != nil { if err != nil {
db.putConn(ci, err) db.putConn(dc, err)
return nil, err return nil, err
} }
stmt := &Stmt{ stmt := &Stmt{
db: db, db: db,
query: query, query: query,
css: []connStmt{{ci, si}}, css: []connStmt{{dc, si}},
} }
db.addDep(stmt, stmt) db.addDep(stmt, stmt)
db.putConn(ci, nil) db.putConn(dc, nil)
return stmt, nil return stmt, nil
} }
...@@ -480,35 +513,39 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { ...@@ -480,35 +513,39 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
} }
func (db *DB) exec(query string, args []interface{}) (res Result, err error) { func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
ci, err := db.conn() dc, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
db.putConn(ci, err) db.putConn(dc, err)
}() }()
if execer, ok := ci.(driver.Execer); ok { if execer, ok := dc.ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args) dargs, err := driverArgs(nil, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dc.Lock()
resi, err := execer.Exec(query, dargs) resi, err := execer.Exec(query, dargs)
dc.Unlock()
if err != driver.ErrSkip { if err != driver.ErrSkip {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return result{resi}, nil return driverResult{dc, resi}, nil
} }
} }
sti, err := ci.Prepare(query) dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer sti.Close() defer withLock(dc, func() { si.Close() })
return resultFromStatement(sti, args...) return resultFromStatement(driverStmt{dc, si}, args...)
} }
// Query executes a query that returns rows, typically a SELECT. // Query executes a query that returns rows, typically a SELECT.
...@@ -538,14 +575,16 @@ func (db *DB) query(query string, args []interface{}) (*Rows, error) { ...@@ -538,14 +575,16 @@ func (db *DB) query(query string, args []interface{}) (*Rows, error) {
// queryConn executes a query on the given connection. // queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function. // The connection gets released by the releaseConn function.
func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := ci.(driver.Queryer); ok { if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args) dargs, err := driverArgs(nil, args)
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
return nil, err return nil, err
} }
dc.Lock()
rowsi, err := queryer.Query(query, dargs) rowsi, err := queryer.Query(query, dargs)
dc.Unlock()
if err != driver.ErrSkip { if err != driver.ErrSkip {
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
...@@ -555,7 +594,7 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a ...@@ -555,7 +594,7 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a
// with releaseConn. // with releaseConn.
rows := &Rows{ rows := &Rows{
db: db, db: db,
ci: ci, dc: dc,
releaseConn: releaseConn, releaseConn: releaseConn,
rowsi: rowsi, rowsi: rowsi,
} }
...@@ -563,16 +602,21 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a ...@@ -563,16 +602,21 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a
} }
} }
sti, err := ci.Prepare(query) dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
return nil, err return nil, err
} }
rowsi, err := rowsiFromStatement(sti, args...) ds := driverStmt{dc, si}
rowsi, err := rowsiFromStatement(ds, args...)
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
sti.Close() dc.Lock()
si.Close()
dc.Unlock()
return nil, err return nil, err
} }
...@@ -580,10 +624,10 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a ...@@ -580,10 +624,10 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a
// with releaseConn. // with releaseConn.
rows := &Rows{ rows := &Rows{
db: db, db: db,
ci: ci, dc: dc,
releaseConn: releaseConn, releaseConn: releaseConn,
rowsi: rowsi, rowsi: rowsi,
closeStmt: sti, closeStmt: si,
} }
return rows, nil return rows, nil
} }
...@@ -611,18 +655,20 @@ func (db *DB) Begin() (*Tx, error) { ...@@ -611,18 +655,20 @@ func (db *DB) Begin() (*Tx, error) {
} }
func (db *DB) begin() (tx *Tx, err error) { func (db *DB) begin() (tx *Tx, err error) {
ci, err := db.conn() dc, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
txi, err := ci.Begin() dc.Lock()
txi, err := dc.ci.Begin()
dc.Unlock()
if err != nil { if err != nil {
db.putConn(ci, err) db.putConn(dc, err)
return nil, err return nil, err
} }
return &Tx{ return &Tx{
db: db, db: db,
ci: ci, dc: dc,
txi: txi, txi: txi,
}, nil }, nil
} }
...@@ -641,13 +687,15 @@ func (db *DB) Driver() driver.Driver { ...@@ -641,13 +687,15 @@ func (db *DB) Driver() driver.Driver {
type Tx struct { type Tx struct {
db *DB db *DB
// ci is owned exclusively until Commit or Rollback, at which point // dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn. // it's returned with putConn.
ci driver.Conn // TODO(bradfitz): golang.org/issue/3857
dc *driverConn
txi driver.Tx txi driver.Tx
// cimu is held while somebody is using ci (between grabConn // cimu is held while somebody is using ci (between grabConn
// and releaseConn) // and releaseConn)
// TODO(bradfitz): golang.org/issue/3857
cimu sync.Mutex cimu sync.Mutex
// done transitions from false to true exactly once, on Commit // done transitions from false to true exactly once, on Commit
...@@ -663,17 +711,17 @@ func (tx *Tx) close() { ...@@ -663,17 +711,17 @@ func (tx *Tx) close() {
panic("double close") // internal error panic("double close") // internal error
} }
tx.done = true tx.done = true
tx.db.putConn(tx.ci, nil) tx.db.putConn(tx.dc, nil)
tx.ci = nil tx.dc = nil
tx.txi = nil tx.txi = nil
} }
func (tx *Tx) grabConn() (driver.Conn, error) { func (tx *Tx) grabConn() (*driverConn, error) {
if tx.done { if tx.done {
return nil, ErrTxDone return nil, ErrTxDone
} }
tx.cimu.Lock() tx.cimu.Lock()
return tx.ci, nil return tx.dc, nil
} }
func (tx *Tx) releaseConn() { func (tx *Tx) releaseConn() {
...@@ -686,6 +734,8 @@ func (tx *Tx) Commit() error { ...@@ -686,6 +734,8 @@ func (tx *Tx) Commit() error {
return ErrTxDone return ErrTxDone
} }
defer tx.close() defer tx.close()
tx.dc.Lock()
defer tx.dc.Unlock()
return tx.txi.Commit() return tx.txi.Commit()
} }
...@@ -695,6 +745,8 @@ func (tx *Tx) Rollback() error { ...@@ -695,6 +745,8 @@ func (tx *Tx) Rollback() error {
return ErrTxDone return ErrTxDone
} }
defer tx.close() defer tx.close()
tx.dc.Lock()
defer tx.dc.Unlock()
return tx.txi.Rollback() return tx.txi.Rollback()
} }
...@@ -718,21 +770,26 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { ...@@ -718,21 +770,26 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// Perhaps just looking at the reference count (by noting // Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer // Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count. // on Stmt to drop the reference count.
ci, err := tx.grabConn() dc, err := tx.grabConn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.releaseConn() defer tx.releaseConn()
si, err := ci.Prepare(query) dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := &Stmt{ stmt := &Stmt{
db: tx.db, db: tx.db,
tx: tx, tx: tx,
txsi: si, txsi: &driverStmt{
Locker: dc,
si: si,
},
query: query, query: query,
} }
return stmt, nil return stmt, nil
...@@ -756,16 +813,21 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { ...@@ -756,16 +813,21 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if tx.db != stmt.db { if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
} }
ci, err := tx.grabConn() dc, err := tx.grabConn()
if err != nil { if err != nil {
return &Stmt{stickyErr: err} return &Stmt{stickyErr: err}
} }
defer tx.releaseConn() defer tx.releaseConn()
si, err := ci.Prepare(stmt.query) dc.Lock()
si, err := dc.ci.Prepare(stmt.query)
dc.Unlock()
return &Stmt{ return &Stmt{
db: tx.db, db: tx.db,
tx: tx, tx: tx,
txsi: si, txsi: &driverStmt{
Locker: dc,
si: si,
},
query: stmt.query, query: stmt.query,
stickyErr: err, stickyErr: err,
} }
...@@ -774,33 +836,37 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { ...@@ -774,33 +836,37 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// Exec executes a query that doesn't return rows. // Exec executes a query that doesn't return rows.
// For example: an INSERT and UPDATE. // For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
ci, err := tx.grabConn() dc, err := tx.grabConn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.releaseConn() defer tx.releaseConn()
if execer, ok := ci.(driver.Execer); ok { if execer, ok := dc.ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args) dargs, err := driverArgs(nil, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dc.Lock()
resi, err := execer.Exec(query, dargs) resi, err := execer.Exec(query, dargs)
dc.Unlock()
if err == nil { if err == nil {
return result{resi}, nil return driverResult{dc, resi}, nil
} }
if err != driver.ErrSkip { if err != driver.ErrSkip {
return nil, err return nil, err
} }
} }
sti, err := ci.Prepare(query) dc.Lock()
si, err := dc.ci.Prepare(query)
dc.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer sti.Close() defer withLock(dc, func() { si.Close() })
return resultFromStatement(sti, args...) return resultFromStatement(driverStmt{dc, si}, args...)
} }
// Query executes a query that returns rows, typically a SELECT. // Query executes a query that returns rows, typically a SELECT.
...@@ -825,7 +891,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { ...@@ -825,7 +891,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
// connStmt is a prepared statement on a particular connection. // connStmt is a prepared statement on a particular connection.
type connStmt struct { type connStmt struct {
ci driver.Conn dc *driverConn
si driver.Stmt si driver.Stmt
} }
...@@ -840,7 +906,7 @@ type Stmt struct { ...@@ -840,7 +906,7 @@ type Stmt struct {
// If in a transaction, else both nil: // If in a transaction, else both nil:
tx *Tx tx *Tx
txsi driver.Stmt txsi *driverStmt
mu sync.Mutex // protects the rest of the fields mu sync.Mutex // protects the rest of the fields
closed bool closed bool
...@@ -857,39 +923,45 @@ type Stmt struct { ...@@ -857,39 +923,45 @@ type Stmt struct {
func (s *Stmt) Exec(args ...interface{}) (Result, error) { func (s *Stmt) Exec(args ...interface{}) (Result, error) {
s.closemu.RLock() s.closemu.RLock()
defer s.closemu.RUnlock() defer s.closemu.RUnlock()
_, releaseConn, si, err := s.connStmt() dc, releaseConn, si, err := s.connStmt()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer releaseConn(nil) defer releaseConn(nil)
return resultFromStatement(si, args...) return resultFromStatement(driverStmt{dc, si}, args...)
} }
func resultFromStatement(si driver.Stmt, args ...interface{}) (Result, error) { func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
ds.Lock()
want := ds.si.NumInput()
ds.Unlock()
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the // placeholders, so we won't sanity check input here and instead let the
// driver deal with errors. // driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want { if want != -1 && len(args) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
} }
dargs, err := driverArgs(si, args) dargs, err := driverArgs(&ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resi, err := si.Exec(dargs) ds.Lock()
resi, err := ds.si.Exec(dargs)
ds.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return result{resi}, nil return driverResult{ds.Locker, resi}, nil
} }
// connStmt returns a free driver connection on which to execute the // connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a // statement, a function to call to release the connection, and a
// statement bound to that connection. // statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) { func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
if err = s.stickyErr; err != nil { if err = s.stickyErr; err != nil {
return return
} }
...@@ -909,7 +981,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St ...@@ -909,7 +981,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
return return
} }
releaseConn = func(error) { s.tx.releaseConn() } releaseConn = func(error) { s.tx.releaseConn() }
return ci, releaseConn, s.txsi, nil return ci, releaseConn, s.txsi.si, nil
} }
var cs connStmt var cs connStmt
...@@ -917,7 +989,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St ...@@ -917,7 +989,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
for _, v := range s.css { for _, v := range s.css {
// TODO(bradfitz): lazily clean up entries in this // TODO(bradfitz): lazily clean up entries in this
// list with dead conns while enumerating // list with dead conns while enumerating
if _, match = s.db.connIfFree(v.ci); match { if _, match = s.db.connIfFree(v.dc); match {
cs = v cs = v
break break
} }
...@@ -928,11 +1000,13 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St ...@@ -928,11 +1000,13 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
// TODO(bradfitz): or wait for one? make configurable later? // TODO(bradfitz): or wait for one? make configurable later?
if !match { if !match {
for i := 0; ; i++ { for i := 0; ; i++ {
ci, err := s.db.conn() dc, err := s.db.conn()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
si, err := ci.Prepare(s.query) dc.Lock()
si, err := dc.ci.Prepare(s.query)
dc.Unlock()
if err == driver.ErrBadConn && i < 10 { if err == driver.ErrBadConn && i < 10 {
continue continue
} }
...@@ -940,14 +1014,14 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St ...@@ -940,14 +1014,14 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
return nil, nil, nil, err return nil, nil, nil, err
} }
s.mu.Lock() s.mu.Lock()
cs = connStmt{ci, si} cs = connStmt{dc, si}
s.css = append(s.css, cs) s.css = append(s.css, cs)
s.mu.Unlock() s.mu.Unlock()
break break
} }
} }
conn := cs.ci conn := cs.dc
releaseConn = func(err error) { s.db.putConn(conn, err) } releaseConn = func(err error) { s.db.putConn(conn, err) }
return conn, releaseConn, cs.si, nil return conn, releaseConn, cs.si, nil
} }
...@@ -958,12 +1032,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -958,12 +1032,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
s.closemu.RLock() s.closemu.RLock()
defer s.closemu.RUnlock() defer s.closemu.RUnlock()
ci, releaseConn, si, err := s.connStmt() dc, releaseConn, si, err := s.connStmt()
if err != nil { if err != nil {
return nil, err return nil, err
} }
rowsi, err := rowsiFromStatement(si, args...) ds := driverStmt{dc, si}
rowsi, err := rowsiFromStatement(ds, args...)
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
return nil, err return nil, err
...@@ -973,7 +1048,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -973,7 +1048,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
// with releaseConn. // with releaseConn.
rows := &Rows{ rows := &Rows{
db: s.db, db: s.db,
ci: ci, dc: dc,
rowsi: rowsi, rowsi: rowsi,
// releaseConn set below // releaseConn set below
} }
...@@ -985,20 +1060,26 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -985,20 +1060,26 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return rows, nil return rows, nil
} }
func rowsiFromStatement(si driver.Stmt, args ...interface{}) (driver.Rows, error) { func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
ds.Lock()
want := ds.si.NumInput()
ds.Unlock()
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the // placeholders, so we won't sanity check input here and instead let the
// driver deal with errors. // driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want { if want != -1 && len(args) != want {
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args)) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
} }
dargs, err := driverArgs(si, args) dargs, err := driverArgs(&ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rowsi, err := si.Query(dargs) ds.Lock()
rowsi, err := ds.si.Query(dargs)
ds.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1049,7 +1130,7 @@ func (s *Stmt) Close() error { ...@@ -1049,7 +1130,7 @@ func (s *Stmt) Close() error {
func (s *Stmt) finalClose() error { func (s *Stmt) finalClose() error {
for _, v := range s.css { for _, v := range s.css {
s.db.noteUnusedDriverStatement(v.ci, v.si) s.db.noteUnusedDriverStatement(v.dc, v.si)
} }
s.css = nil s.css = nil
return nil return nil
...@@ -1070,7 +1151,7 @@ func (s *Stmt) finalClose() error { ...@@ -1070,7 +1151,7 @@ func (s *Stmt) finalClose() error {
// ... // ...
type Rows struct { type Rows struct {
db *DB db *DB
ci driver.Conn // owned; must call releaseConn when closed to release dc *driverConn // owned; must call releaseConn when closed to release
releaseConn func(error) releaseConn func(error)
rowsi driver.Rows rowsi driver.Rows
...@@ -1243,11 +1324,31 @@ type Result interface { ...@@ -1243,11 +1324,31 @@ type Result interface {
RowsAffected() (int64, error) RowsAffected() (int64, error)
} }
type result struct { type driverResult struct {
driver.Result sync.Locker // the *driverConn
resi driver.Result
}
func (dr driverResult) LastInsertId() (int64, error) {
dr.Lock()
defer dr.Unlock()
return dr.resi.LastInsertId()
}
func (dr driverResult) RowsAffected() (int64, error) {
dr.Lock()
defer dr.Unlock()
return dr.resi.RowsAffected()
} }
func stack() string { func stack() string {
var buf [1024]byte var buf [1024]byte
return string(buf[:runtime.Stack(buf[:], false)]) return string(buf[:runtime.Stack(buf[:], false)])
} }
// withLock runs while holding lk.
func withLock(lk sync.Locker, fn func()) {
lk.Lock()
fn()
lk.Unlock()
}
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
package sql package sql
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
...@@ -16,10 +15,10 @@ import ( ...@@ -16,10 +15,10 @@ import (
func init() { func init() {
type dbConn struct { type dbConn struct {
db *DB db *DB
c driver.Conn c *driverConn
} }
freedFrom := make(map[dbConn]string) freedFrom := make(map[dbConn]string)
putConnHook = func(db *DB, c driver.Conn) { putConnHook = func(db *DB, c *driverConn) {
for _, oc := range db.freeConn { for _, oc := range db.freeConn {
if oc == c { if oc == c {
// print before panic, as panic may get lost due to conflicting panic // print before panic, as panic may get lost due to conflicting panic
...@@ -78,7 +77,7 @@ func numPrepares(t *testing.T, db *DB) int { ...@@ -78,7 +77,7 @@ func numPrepares(t *testing.T, db *DB) int {
if n := len(db.freeConn); n != 1 { if n := len(db.freeConn); n != 1 {
t.Fatalf("free conns = %d; want 1", n) t.Fatalf("free conns = %d; want 1", n)
} }
return db.freeConn[0].(*fakeConn).numPrepare return db.freeConn[0].ci.(*fakeConn).numPrepare
} }
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
...@@ -576,7 +575,7 @@ func TestQueryRowClosingStmt(t *testing.T) { ...@@ -576,7 +575,7 @@ func TestQueryRowClosingStmt(t *testing.T) {
if len(db.freeConn) != 1 { if len(db.freeConn) != 1 {
t.Fatalf("expected 1 free conn") t.Fatalf("expected 1 free conn")
} }
fakeConn := db.freeConn[0].(*fakeConn) fakeConn := db.freeConn[0].ci.(*fakeConn)
if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
t.Errorf("statement close mismatch: made %d, closed %d", made, closed) t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment