Commit 707a8334 authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: add option to use named parameter in query arguments

Modify the new Context methods to take a name-value driver struct.
This will require more modifications to drivers to use, but will
reduce the overall number of structures that need to be maintained
over time.

Fixes #12381

Change-Id: I30747533ce418a1be5991a0c8767a26e8451adbd
Reviewed-on: https://go-review.googlesource.com/30166Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 99df54f1
...@@ -21,8 +21,8 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript ...@@ -21,8 +21,8 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript
// Stmt.Query into driver Values. // Stmt.Query into driver Values.
// //
// The statement ds may be nil, if no statement is available. // The statement ds may be nil, if no statement is available.
func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
dargs := make([]driver.Value, len(args)) nvargs := make([]driver.NamedValue, len(args))
var si driver.Stmt var si driver.Stmt
if ds != nil { if ds != nil {
si = ds.si si = ds.si
...@@ -33,16 +33,27 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { ...@@ -33,16 +33,27 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
if !ok { if !ok {
for n, arg := range args { for n, arg := range args {
var err error var err error
dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg) nvargs[n].Ordinal = n + 1
if np, ok := arg.(NamedParam); ok {
arg = np.Value
nvargs[n].Name = np.Name
}
nvargs[n].Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil { if err != nil {
return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err) return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
} }
} }
return dargs, nil return nvargs, nil
} }
// Let the Stmt convert its own arguments. // Let the Stmt convert its own arguments.
for n, arg := range args { for n, arg := range args {
nvargs[n].Ordinal = n + 1
if np, ok := arg.(NamedParam); ok {
arg = np.Value
nvargs[n].Name = np.Name
}
// First, see if the value itself knows how to convert // First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString // itself to a driver type. For example, a NullString
// struct changing into a string or nil. // struct changing into a string or nil.
...@@ -66,18 +77,18 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { ...@@ -66,18 +77,18 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
// same error. // same error.
var err error var err error
ds.Lock() ds.Lock()
dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) nvargs[n].Value, err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock() ds.Unlock()
if err != nil { if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err) return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
} }
if !driver.IsValue(dargs[n]) { if !driver.IsValue(nvargs[n].Value) {
return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T", return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
arg, dargs[n]) arg, nvargs[n].Value)
} }
} }
return dargs, nil return nvargs, nil
} }
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
......
...@@ -50,9 +50,13 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver ...@@ -50,9 +50,13 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
} }
} }
func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) { func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if execerCtx, is := execer.(driver.ExecerContext); is { if execerCtx, is := execer.(driver.ExecerContext); is {
return execerCtx.ExecContext(ctx, query, dargs) return execerCtx.ExecContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
} }
if ctx.Done() == context.Background().Done() { if ctx.Done() == context.Background().Done() {
return execer.Exec(query, dargs) return execer.Exec(query, dargs)
...@@ -90,9 +94,13 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg ...@@ -90,9 +94,13 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg
} }
} }
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) { func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx, is := queryer.(driver.QueryerContext); is { if queryerCtx, is := queryer.(driver.QueryerContext); is {
return queryerCtx.QueryContext(ctx, query, dargs) return queryerCtx.QueryContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
} }
if ctx.Done() == context.Background().Done() { if ctx.Done() == context.Background().Done() {
return queryer.Query(query, dargs) return queryer.Query(query, dargs)
...@@ -130,9 +138,13 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d ...@@ -130,9 +138,13 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d
} }
} }
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) { func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is { if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, dargs) return siCtx.ExecContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
} }
if ctx.Done() == context.Background().Done() { if ctx.Done() == context.Background().Done() {
return si.Exec(dargs) return si.Exec(dargs)
...@@ -170,9 +182,13 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value ...@@ -170,9 +182,13 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value
} }
} }
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) { func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is { if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, dargs) return siCtx.QueryContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
} }
if ctx.Done() == context.Background().Done() { if ctx.Done() == context.Background().Done() {
return si.Query(dargs) return si.Query(dargs)
...@@ -253,3 +269,14 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { ...@@ -253,3 +269,14 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
return r.txi, r.err return r.txi, r.err
} }
} }
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
...@@ -24,6 +24,16 @@ import ( ...@@ -24,6 +24,16 @@ import (
// time.Time // time.Time
type Value interface{} type Value interface{}
// NamedValue holds both the value name and value.
// The Ordinal is the position of the parameter starting from one and is always set.
// If the Name is not empty it should be used for the parameter identifier and
// not the ordinal position.
type NamedValue struct {
Name string
Ordinal int
Value Value
}
// Driver is the interface that must be implemented by a database // Driver is the interface that must be implemented by a database
// driver. // driver.
type Driver interface { type Driver interface {
...@@ -71,7 +81,7 @@ type Execer interface { ...@@ -71,7 +81,7 @@ type Execer interface {
// ExecerContext is like execer, but must honor the context timeout and return // ExecerContext is like execer, but must honor the context timeout and return
// when the context is cancelled. // when the context is cancelled.
type ExecerContext interface { type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []Value) (Result, error) ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
} }
// Queryer is an optional interface that may be implemented by a Conn. // Queryer is an optional interface that may be implemented by a Conn.
...@@ -88,7 +98,7 @@ type Queryer interface { ...@@ -88,7 +98,7 @@ type Queryer interface {
// QueryerContext is like Queryer, but most honor the context timeout and return // QueryerContext is like Queryer, but most honor the context timeout and return
// when the context is cancelled. // when the context is cancelled.
type QueryerContext interface { type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []Value) (Rows, error) QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
} }
// Conn is a connection to a database. It is not used concurrently // Conn is a connection to a database. It is not used concurrently
...@@ -174,13 +184,13 @@ type Stmt interface { ...@@ -174,13 +184,13 @@ type Stmt interface {
// StmtExecContext enhances the Stmt interface by providing Exec with context. // StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface { type StmtExecContext interface {
// ExecContext must honor the context timeout and return when it is cancelled. // ExecContext must honor the context timeout and return when it is cancelled.
ExecContext(ctx context.Context, args []Value) (Result, error) ExecContext(ctx context.Context, args []NamedValue) (Result, error)
} }
// StmtQueryContext enhances the Stmt interface by providing Query with context. // StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface { type StmtQueryContext interface {
// QueryContext must honor the context timeout and return when it is cancelled. // QueryContext must honor the context timeout and return when it is cancelled.
QueryContext(ctx context.Context, args []Value) (Rows, error) QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
} }
// ColumnConverter may be optionally implemented by Stmt if the // ColumnConverter may be optionally implemented by Stmt if the
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package sql package sql
import ( import (
"context"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
...@@ -32,6 +33,7 @@ var _ = log.Printf ...@@ -32,6 +33,7 @@ var _ = log.Printf
// where types are: "string", [u]int{8,16,32,64}, "bool" // where types are: "string", [u]int{8,16,32,64}, "bool"
// INSERT|<tablename>|col=val,col2=val2,col3=? // INSERT|<tablename>|col=val,col2=val2,col3=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
// //
// Any of these can be preceded by PANIC|<method>|, to cause the // Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic. // named method on fakeStmt to panic.
...@@ -103,6 +105,12 @@ type fakeTx struct { ...@@ -103,6 +105,12 @@ type fakeTx struct {
c *fakeConn c *fakeConn
} }
type boundCol struct {
Column string
Placeholder string
Ordinal int
}
type fakeStmt struct { type fakeStmt struct {
c *fakeConn c *fakeConn
q string // just for debugging q string // just for debugging
...@@ -120,7 +128,7 @@ type fakeStmt struct { ...@@ -120,7 +128,7 @@ type fakeStmt struct {
colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
placeholders int // used by INSERT/SELECT: number of ? params placeholders int // used by INSERT/SELECT: number of ? params
whereCol []string // used by SELECT (all placeholders) whereCol []boundCol // used by SELECT (all placeholders)
placeholderConverter []driver.ValueConverter // used by INSERT placeholderConverter []driver.ValueConverter // used by INSERT
} }
...@@ -339,18 +347,23 @@ func (c *fakeConn) Close() (err error) { ...@@ -339,18 +347,23 @@ func (c *fakeConn) Close() (err error) {
return nil return nil
} }
func checkSubsetTypes(args []driver.Value) error { func checkSubsetTypes(args []driver.NamedValue) error {
for n, arg := range args { for _, arg := range args {
switch arg.(type) { switch arg.Value.(type) {
case int64, float64, bool, nil, []byte, string, time.Time: case int64, float64, bool, nil, []byte, string, time.Time:
default: default:
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
} }
} }
return nil return nil
} }
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
// Ensure that ExecContext is called if available.
panic("ExecContext was not called.")
}
func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// This is an optional interface, but it's implemented here // This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types. // just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't // ErrSkip is returned so the caller acts as if we didn't
...@@ -363,6 +376,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error ...@@ -363,6 +376,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error
} }
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
// Ensure that ExecContext is called if available.
panic("QueryContext was not called.")
}
func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// This is an optional interface, but it's implemented here // This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types. // just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't // ErrSkip is returned so the caller acts as if we didn't
...@@ -403,13 +421,13 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err ...@@ -403,13 +421,13 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err
stmt.Close() stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
} }
if value != "?" { if !strings.HasPrefix(value, "?") {
stmt.Close() stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column) stmt.table, column)
} }
stmt.whereCol = append(stmt.whereCol, column)
stmt.placeholders++ stmt.placeholders++
stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
} }
return stmt, nil return stmt, nil
} }
...@@ -454,7 +472,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err ...@@ -454,7 +472,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err
} }
stmt.colName = append(stmt.colName, column) stmt.colName = append(stmt.colName, column)
if value != "?" { if !strings.HasPrefix(value, "?") {
var subsetVal interface{} var subsetVal interface{}
// Convert to driver subset type // Convert to driver subset type
switch ctype { switch ctype {
...@@ -477,7 +495,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err ...@@ -477,7 +495,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err
} else { } else {
stmt.placeholders++ stmt.placeholders++
stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
stmt.colValue = append(stmt.colValue, "?") stmt.colValue = append(stmt.colValue, value)
} }
} }
return stmt, nil return stmt, nil
...@@ -580,6 +598,9 @@ var errClosed = errors.New("fakedb: statement has been closed") ...@@ -580,6 +598,9 @@ var errClosed = errors.New("fakedb: statement has been closed")
var hookExecBadConn func() bool var hookExecBadConn func() bool
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
panic("Using ExecContext")
}
func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s.panic == "Exec" { if s.panic == "Exec" {
panic(s.panic) panic(s.panic)
} }
...@@ -620,7 +641,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { ...@@ -620,7 +641,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
// When doInsert is true, add the row to the table. // When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't // When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table. // actually add the row to the table.
func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
db := s.c.db db := s.c.db
if len(args) != s.placeholders { if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct") panic("error in pkg db; should only get here if size is correct")
...@@ -646,8 +667,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result ...@@ -646,8 +667,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
} }
var val interface{} var val interface{}
if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
val = args[argPos] if strvalue == "?" {
val = args[argPos].Value
} else {
// Assign value from argument placeholder name.
for _, a := range args {
if a.Name == strvalue {
val = a.Value
break
}
}
}
argPos++ argPos++
} else { } else {
val = s.colValue[n] val = s.colValue[n]
...@@ -667,6 +698,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result ...@@ -667,6 +698,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
var hookQueryBadConn func() bool var hookQueryBadConn func() bool
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("Use QueryContext")
}
func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s.panic == "Query" { if s.panic == "Query" {
panic(s.panic) panic(s.panic)
} }
...@@ -700,9 +735,9 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { ...@@ -700,9 +735,9 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
} }
if s.table == "magicquery" { if s.table == "magicquery" {
if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
if args[0] == "sleep" { if args[0].Value == "sleep" {
time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
} }
} }
} }
...@@ -725,8 +760,8 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { ...@@ -725,8 +760,8 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
// Process the where clause, skipping non-match rows. This is lazy // Process the where clause, skipping non-match rows. This is lazy
// and just uses fmt.Sprintf("%v") to test equality. Good enough // and just uses fmt.Sprintf("%v") to test equality. Good enough
// for test code. // for test code.
for widx, wcol := range s.whereCol { for _, wcol := range s.whereCol {
idx := t.columnIndex(wcol) idx := t.columnIndex(wcol.Column)
if idx == -1 { if idx == -1 {
t.mu.Unlock() t.mu.Unlock()
return nil, fmt.Errorf("db: invalid where clause column %q", wcol) return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
...@@ -736,7 +771,19 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { ...@@ -736,7 +771,19 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
// lazy hack to avoid sprintf %v on a []byte // lazy hack to avoid sprintf %v on a []byte
tcol = string(bs) tcol = string(bs)
} }
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { var argValue interface{}
if wcol.Placeholder == "?" {
argValue = args[wcol.Ordinal-1].Value
} else {
// Assign arg value from placeholder name.
for _, a := range args {
if a.Name == wcol.Placeholder {
argValue = a.Value
break
}
}
}
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
continue rows continue rows
} }
} }
......
...@@ -67,6 +67,27 @@ func Drivers() []string { ...@@ -67,6 +67,27 @@ func Drivers() []string {
return list return list
} }
// NamedParam may be passed into query parameter arguments to associate
// a named placeholder with a value.
type NamedParam struct {
// Name of the parameter placeholder. If empty the ordinal position in the
// argument list will be used.
Name string
// Value of the parameter. It may be assigned the same value types as
// the query arguments.
Value interface{}
}
// Param provides a more concise way to create NamedParam values.
func Param(name string, value interface{}) NamedParam {
// This method exists because the go1compat promise
// doesn't guarantee that structs don't grow more fields,
// so unkeyed struct literals are a vet error. Thus, we don't
// want to encourage sql.NamedParam{name, value}.
return NamedParam{Name: name, Value: value}
}
// RawBytes is a byte slice that holds a reference to memory owned by // RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a Scan into a RawBytes, the slice is only // the database itself. After a Scan into a RawBytes, the slice is only
// valid until the next call to Next, Scan, or Close. // valid until the next call to Next, Scan, or Close.
...@@ -1064,7 +1085,7 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate ...@@ -1064,7 +1085,7 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
}() }()
if execer, ok := dc.ci.(driver.Execer); ok { if execer, ok := dc.ci.(driver.Execer); ok {
var dargs []driver.Value var dargs []driver.NamedValue
dargs, err = driverArgs(nil, args) dargs, err = driverArgs(nil, args)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -395,6 +395,53 @@ func TestMultiResultSetQuery(t *testing.T) { ...@@ -395,6 +395,53 @@ func TestMultiResultSetQuery(t *testing.T) {
} }
} }
func TestQueryNamedParam(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
prepares0 := numPrepares(t, db)
rows, err := db.Query(
// Ensure the name and age parameters only match on placeholder name, not position.
"SELECT|people|age,name|name=?name,age=?age",
Param("?age", 2),
Param("?name", "Bob"),
)
if err != nil {
t.Fatalf("Query: %v", err)
}
type row struct {
age int
name string
}
got := []row{}
for rows.Next() {
var r row
err = rows.Scan(&r.age, &r.name)
if err != nil {
t.Fatalf("Scan: %v", err)
}
got = append(got, r)
}
err = rows.Err()
if err != nil {
t.Fatalf("Err: %v", err)
}
want := []row{
{age: 2, name: "Bob"},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
}
// And verify that the final rows.Next() call, which hit EOF,
// also closed the rows connection.
if n := db.numFreeConns(); n != 1 {
t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
func TestByteOwnership(t *testing.T) { func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment