Commit 729685c1 authored by Daniel Theophanes's avatar Daniel Theophanes

database/sql: ensure Rows is closed when Tx closes

Close any Rows queried within a Tx when the Tx is closed. This prevents
the Tx from blocking on rollback if a Rows query has not been closed yet.

Fixes #20575

Change-Id: I4efe9c4150e951d8a0f1c40d9d5e325964fdd608
Reviewed-on: https://go-review.googlesource.com/44812Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent eea8c88a
...@@ -1284,12 +1284,14 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat ...@@ -1284,12 +1284,14 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
return nil, err return nil, err
} }
return db.queryDC(ctx, dc, dc.releaseConn, query, args) return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
} }
// queryDC executes a query on the given connection. // queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function. // The connection gets released by the releaseConn function.
func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { // The ctx context is from a query method and the txctx context is from an
// optional transaction context.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok { if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(dc.ci, nil, args) dargs, err := driverArgs(dc.ci, nil, args)
if err != nil { if err != nil {
...@@ -1312,7 +1314,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro ...@@ -1312,7 +1314,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro
releaseConn: releaseConn, releaseConn: releaseConn,
rowsi: rowsi, rowsi: rowsi,
} }
rows.initContextClose(ctx) rows.initContextClose(ctx, txctx)
return rows, nil return rows, nil
} }
} }
...@@ -1343,7 +1345,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro ...@@ -1343,7 +1345,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro
rowsi: rowsi, rowsi: rowsi,
closeStmt: ds, closeStmt: ds,
} }
rows.initContextClose(ctx) rows.initContextClose(ctx, txctx)
return rows, nil return rows, nil
} }
...@@ -1532,7 +1534,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface ...@@ -1532,7 +1534,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface
} }
c.closemu.RLock() c.closemu.RLock()
return c.db.queryDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args) return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args)
} }
// QueryRowContext executes a query that is expected to return at most one row. // QueryRowContext executes a query that is expected to return at most one row.
...@@ -1687,11 +1689,12 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle ...@@ -1687,11 +1689,12 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
// close returns the connection to the pool and // close returns the connection to the pool and
// must only be called by Tx.rollback or Tx.Commit. // must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) { func (tx *Tx) close(err error) {
tx.cancel()
tx.closemu.Lock() tx.closemu.Lock()
defer tx.closemu.Unlock() defer tx.closemu.Unlock()
tx.releaseConn(err) tx.releaseConn(err)
tx.cancel()
tx.dc = nil tx.dc = nil
tx.txi = nil tx.txi = nil
} }
...@@ -1957,7 +1960,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{ ...@@ -1957,7 +1960,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{
return nil, err return nil, err
} }
return tx.db.queryDC(ctx, dc, tx.closemuRUnlockRelease, query, args) return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args)
} }
// Query executes a query that returns rows, typically a SELECT. // Query executes a query that returns rows, typically a SELECT.
...@@ -2207,7 +2210,11 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er ...@@ -2207,7 +2210,11 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
releaseConn(err) releaseConn(err)
s.db.removeDep(s, rows) s.db.removeDep(s, rows)
} }
rows.initContextClose(ctx) var txctx context.Context
if s.tx != nil {
txctx = s.tx.ctx
}
rows.initContextClose(ctx, txctx)
return rows, nil return rows, nil
} }
...@@ -2363,14 +2370,24 @@ type Rows struct { ...@@ -2363,14 +2370,24 @@ type Rows struct {
lastcols []driver.Value lastcols []driver.Value
} }
func (rs *Rows) initContextClose(ctx context.Context) { func (rs *Rows) initContextClose(ctx, txctx context.Context) {
ctx, rs.cancel = context.WithCancel(ctx) ctx, rs.cancel = context.WithCancel(ctx)
go rs.awaitDone(ctx) go rs.awaitDone(ctx, txctx)
} }
// awaitDone blocks until the rows are closed or the context canceled. // awaitDone blocks until either ctx or txctx is canceled. The ctx is provided
func (rs *Rows) awaitDone(ctx context.Context) { // from the query context and is canceled when the query Rows is closed.
<-ctx.Done() // If the query was issued in a transaction, the transaction's context
// is also provided in txctx to ensure Rows is closed if the Tx is closed.
func (rs *Rows) awaitDone(ctx, txctx context.Context) {
var txctxDone <-chan struct{}
if txctx != nil {
txctxDone = txctx.Done()
}
select {
case <-ctx.Done():
case <-txctxDone:
}
rs.close(ctx.Err()) rs.close(ctx.Err())
} }
......
...@@ -2467,6 +2467,32 @@ func TestManyErrBadConn(t *testing.T) { ...@@ -2467,6 +2467,32 @@ func TestManyErrBadConn(t *testing.T) {
} }
} }
// TestIssue20575 ensures the Rows from query does not block
// closing a transaction. Ensure Rows is closed while closing a trasaction.
func TestIssue20575(t *testing.T) {
db := newTestDB(t, "people")
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
if err != nil {
t.Fatal(err)
}
// Do not close Rows from QueryContext.
err = tx.Rollback()
if err != nil {
t.Fatal(err)
}
select {
default:
case <-ctx.Done():
t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
}
}
// golang.org/issue/5718 // golang.org/issue/5718
func TestErrBadConnReconnect(t *testing.T) { func TestErrBadConnReconnect(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment