Commit cd24a8a5 authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: ensure a Stmt from a Conn executes on the same driver.Conn

Ensure a Stmt prepared on a Conn executes on the same driver.Conn.
This also removes another instance of duplicated prepare logic
as a side effect.

Fixes #20647

Change-Id: Ia00a19e4dd15e19e4d754105babdff5dc127728f
Reviewed-on: https://go-review.googlesource.com/45391Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 200d0cc1
......@@ -393,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
// the prepared statements in a pool.
func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err != nil {
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
// No need to manage open statements if there is a single connection grabber.
if cg != nil {
return ds, nil
}
// Track each driverConn's open statements, so we can close them
// before closing the conn.
......@@ -406,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS
if dc.openStmt == nil {
dc.openStmt = make(map[*driverStmt]bool)
}
ds := &driverStmt{Locker: dc, si: si}
dc.openStmt[ds] = true
return ds, nil
}
......@@ -1165,17 +1171,20 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
if err != nil {
return nil, err
}
return db.prepareDC(ctx, dc, dc.releaseConn, query)
return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
}
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), query string) (*Stmt, error) {
// prepareDC prepares a query on the driverConn and calls release before
// returning. When cg == nil it implies that a connection pool is used, and
// when cg != nil only a single driver connection is used.
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
var ds *driverStmt
var err error
defer func() {
release(err)
}()
withLock(dc, func() {
ds, err = dc.prepareLocked(ctx, query)
ds, err = dc.prepareLocked(ctx, cg, query)
})
if err != nil {
return nil, err
......@@ -1183,10 +1192,18 @@ func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error)
stmt := &Stmt{
db: db,
query: query,
css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed),
cg: cg,
cgds: ds,
}
// When cg == nil this statement will need to keep track of various
// connections they are prepared on and record the stmt dependency on
// the DB.
if cg == nil {
stmt.css = []connStmt{{dc, ds}}
stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
db.addDep(stmt, stmt)
}
return stmt, nil
}
......@@ -1474,6 +1491,8 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) {
return conn, nil
}
type releaseConn func(error)
// Conn represents a single database session rather a pool of database
// sessions. Prefer running queries from DB unless there is a specific
// need for a continuous single database session.
......@@ -1501,46 +1520,41 @@ type Conn struct {
done int32
}
func (c *Conn) grabConn() (*driverConn, error) {
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
if atomic.LoadInt32(&c.done) != 0 {
return nil, ErrConnDone
return nil, nil, ErrConnDone
}
return c.dc, nil
c.closemu.RLock()
return c.dc, c.closemuRUnlockCondReleaseConn, nil
}
// PingContext verifies the connection to the database is still alive.
func (c *Conn) PingContext(ctx context.Context) error {
dc, err := c.grabConn()
dc, release, err := c.grabConn(ctx)
if err != nil {
return err
}
c.closemu.RLock()
return c.db.pingDC(ctx, dc, c.closemuRUnlockCondReleaseConn)
return c.db.pingDC(ctx, dc, release)
}
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
dc, err := c.grabConn()
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
c.closemu.RLock()
return c.db.execDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args)
return c.db.execDC(ctx, dc, release, query, args)
}
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
dc, err := c.grabConn()
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
c.closemu.RLock()
return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args)
return c.db.queryDC(ctx, nil, dc, release, query, args)
}
// QueryRowContext executes a query that is expected to return at most one row.
......@@ -1563,13 +1577,11 @@ func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interf
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
dc, err := c.grabConn()
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
c.closemu.RLock()
return c.db.prepareDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query)
return c.db.prepareDC(ctx, dc, release, c, query)
}
// BeginTx starts a transaction.
......@@ -1583,13 +1595,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
dc, err := c.grabConn()
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
c.closemu.RLock()
return c.db.beginDC(ctx, dc, c.closemuRUnlockCondReleaseConn, opts)
return c.db.beginDC(ctx, dc, release, opts)
}
// closemuRUnlockCondReleaseConn read unlocks closemu
......@@ -1601,6 +1611,10 @@ func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
}
}
func (c *Conn) txCtx() context.Context {
return nil
}
func (c *Conn) close(err error) error {
if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
return ErrConnDone
......@@ -1712,19 +1726,28 @@ func (tx *Tx) close(err error) {
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
return nil, nil, ctx.Err()
}
// closeme.RLock must come before the check for isDone to prevent the Tx from
// closing while a query is executing.
tx.closemu.RLock()
if tx.isDone() {
return nil, ErrTxDone
tx.closemu.RUnlock()
return nil, nil, ErrTxDone
}
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
return tx.dc, nil
return tx.dc, tx.closemuRUnlockRelease, nil
}
func (tx *Tx) txCtx() context.Context {
return tx.ctx
}
// closemuRUnlockRelease is used as a func(error) method value in
......@@ -1801,31 +1824,15 @@ func (tx *Tx) Rollback() error {
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
dc, err := tx.grabConn(ctx)
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
var si driver.Stmt
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
if err != nil {
return nil, err
}
stmt := &Stmt{
db: tx.db,
tx: tx,
txds: &driverStmt{
Locker: dc,
si: si,
},
query: query,
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
......@@ -1855,20 +1862,19 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
dc, release, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
defer release(nil)
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
dc, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
var si driver.Stmt
var parentStmt *Stmt
stmt.mu.Lock()
if stmt.closed || stmt.tx != nil {
if stmt.closed || stmt.cg != nil {
// If the statement has been closed or already belongs to a
// transaction, we can't reuse it in this connection.
// Since tx.StmtContext should never need to be called with a
......@@ -1907,8 +1913,8 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
txs := &Stmt{
db: tx.db,
tx: tx,
txds: &driverStmt{
cg: tx,
cgds: &driverStmt{
Locker: dc,
si: si,
},
......@@ -1943,14 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
tx.closemu.RLock()
dc, err := tx.grabConn(ctx)
dc, release, err := tx.grabConn(ctx)
if err != nil {
tx.closemu.RUnlock()
return nil, err
}
return tx.db.execDC(ctx, dc, tx.closemuRUnlockRelease, query, args)
return tx.db.execDC(ctx, dc, release, query, args)
}
// Exec executes a query that doesn't return rows.
......@@ -1961,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
tx.closemu.RLock()
dc, err := tx.grabConn(ctx)
dc, release, err := tx.grabConn(ctx)
if err != nil {
tx.closemu.RUnlock()
return nil, err
}
return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args)
return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
}
// Query executes a query that returns rows, typically a SELECT.
......@@ -2004,6 +2004,24 @@ type connStmt struct {
ds *driverStmt
}
// stmtConnGrabber represents a Tx or Conn that will return the underlying
// driverConn and release function.
type stmtConnGrabber interface {
// grabConn returns the driverConn and the associated release function
// that must be called when the operation completes.
grabConn(context.Context) (*driverConn, releaseConn, error)
// txCtx returns the transaction context if available.
// The returned context should be selected on along with
// any query context when awaiting a cancel.
txCtx() context.Context
}
var (
_ stmtConnGrabber = &Tx{}
_ stmtConnGrabber = &Conn{}
)
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
type Stmt struct {
......@@ -2014,9 +2032,13 @@ type Stmt struct {
closemu sync.RWMutex // held exclusively during close, for read otherwise.
// If in a transaction, else both nil:
tx *Tx
txds *driverStmt
// If Stmt is prepared on a Tx or Conn then cg is present and will
// only ever grab a connection from cg.
// If cg is nil then the Stmt must grab an arbitrary connection
// from db and determine if it must prepare the stmt again by
// inspecting css.
cg stmtConnGrabber
cgds *driverStmt
// parentStmt is set when a transaction-specific statement
// is requested from an identical statement prepared on the same
......@@ -2031,8 +2053,8 @@ type Stmt struct {
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
// used if tx == nil and one is found that has idle
// connections. If tx != nil, txds is always used.
// used if cg == nil and one is found that has idle
// connections. If cg != nil, cgds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
......@@ -2120,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
......@@ -2131,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
return
}
// In a transaction, we always use the connection that the
// transaction was created on.
if s.tx != nil {
// In a transaction or connection, we always use the connection that the
// the stmt was created on.
if s.cg != nil {
s.mu.Unlock()
ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
releaseConn = func(error) {}
return ci, releaseConn, s.txds, nil
return dc, releaseConn, s.cgds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
dc, err := s.db.conn(ctx, strategy)
dc, err = s.db.conn(ctx, strategy)
if err != nil {
return nil, nil, nil, err
}
......@@ -2175,7 +2196,7 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
// open connStmt on the statement. It assumes the caller is holding the lock on dc.
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
si, err := dc.prepareLocked(ctx, s.query)
si, err := dc.prepareLocked(ctx, s.cg, s.query)
if err != nil {
return nil, err
}
......@@ -2226,8 +2247,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
s.db.removeDep(s, rows)
}
var txctx context.Context
if s.tx != nil {
txctx = s.tx.ctx
if s.cg != nil {
txctx = s.cg.txCtx()
}
rows.initContextClose(ctx, txctx)
return rows, nil
......@@ -2323,9 +2344,12 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
txds := s.cgds
s.cgds = nil
s.mu.Unlock()
if s.tx == nil {
if s.cg == nil {
return s.db.removeDep(s, s)
}
......@@ -2334,7 +2358,7 @@ func (s *Stmt) Close() error {
// in the css array of the parentStmt.
return s.db.removeDep(s.parentStmt, s)
}
return s.txds.Close()
return txds.Close()
}
func (s *Stmt) finalClose() error {
......
......@@ -877,7 +877,7 @@ func TestStatementClose(t *testing.T) {
msg string
}{
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
{&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
{&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
}
for _, test := range tests {
if err := test.stmt.Close(); err != want {
......@@ -3231,6 +3231,42 @@ func TestIssue18719(t *testing.T) {
cancel()
}
func TestIssue20647(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
rows1, err := stmt.QueryContext(ctx)
if err != nil {
t.Fatal("rows1", err)
}
defer rows1.Close()
rows2, err := stmt.QueryContext(ctx)
if err != nil {
t.Fatal("rows2", err)
}
defer rows2.Close()
if rows1.dc != rows2.dc {
t.Fatal("stmt prepared on Conn does not use same connection")
}
}
func TestConcurrency(t *testing.T) {
list := []struct {
name string
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment