Commit 93fe8c0c authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

database/sql: use driver.ColumnConverter everywhere consistently

It was only being used for (*Stmt).Exec, not Query, and not for
the same two methods on *DB.

This unifies (*Stmt).Exec's old inline code into the old
subsetArgs function, renaming it in the process (changing the
old word "subset" to "driver", mostly converted earlier)

Fixes #3640

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6258045
parent 6dbaa206
...@@ -14,19 +14,61 @@ import ( ...@@ -14,19 +14,61 @@ import (
"strconv" "strconv"
) )
// subsetTypeArgs takes a slice of arguments from callers of the sql // driverArgs converts arguments from callers of Stmt.Exec and
// package and converts them into a slice of the driver package's // Stmt.Query into driver Values.
// "subset types". //
func subsetTypeArgs(args []interface{}) ([]driver.Value, error) { // The statement si may be nil, if no statement is available.
out := make([]driver.Value, len(args)) func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
dargs := make([]driver.Value, len(args))
cc, ok := si.(driver.ColumnConverter)
// Normal path, for a driver.Stmt that is not a ColumnConverter.
if !ok {
for n, arg := range args { for n, arg := range args {
var err error var err error
out[n], err = driver.DefaultParameterConverter.ConvertValue(arg) dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
}
}
return dargs, nil
}
// Let the Stmt convert its own arguments.
for n, arg := range args {
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
if svi, ok := arg.(driver.Valuer); ok {
sv, err := svi.Value()
if err != nil { if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err) return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err)
}
if !driver.IsValue(sv) {
return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv)
} }
arg = sv
} }
return out, nil
// Second, ask the column to sanity check itself. For
// example, drivers might use this to make sure that
// an int64 values being inserted into a 16-bit
// integer field is in range (before getting
// truncated), or that a nil can't go into a NOT NULL
// column before going across the network to get the
// same error.
var err error
dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
}
if !driver.IsValue(dargs[n]) {
return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
arg, dargs[n])
}
}
return dargs, nil
} }
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
......
...@@ -383,6 +383,9 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { ...@@ -383,6 +383,9 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
} }
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
if len(s.placeholderConverter) == 0 {
return driver.DefaultParameterConverter
}
return s.placeholderConverter[idx] return s.placeholderConverter[idx]
} }
...@@ -598,6 +601,28 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { ...@@ -598,6 +601,28 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
return nil return nil
} }
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
//
type fakeDriverString struct{}
func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
switch c := v.(type) {
case string, []byte:
return v, nil
case *string:
if c == nil {
return nil, nil
}
return *c, nil
}
return fmt.Sprintf("%v", v), nil
}
func converterForType(typ string) driver.ValueConverter { func converterForType(typ string) driver.ValueConverter {
switch typ { switch typ {
case "bool": case "bool":
...@@ -607,9 +632,9 @@ func converterForType(typ string) driver.ValueConverter { ...@@ -607,9 +632,9 @@ func converterForType(typ string) driver.ValueConverter {
case "int32": case "int32":
return driver.Int32 return driver.Int32
case "string": case "string":
return driver.NotNull{Converter: driver.String} return driver.NotNull{Converter: fakeDriverString{}}
case "nullstring": case "nullstring":
return driver.Null{Converter: driver.String} return driver.Null{Converter: fakeDriverString{}}
case "int64": case "int64":
// TODO(coopernurse): add type-specific converter // TODO(coopernurse): add type-specific converter
return driver.NotNull{Converter: driver.DefaultParameterConverter} return driver.NotNull{Converter: driver.DefaultParameterConverter}
......
...@@ -326,13 +326,10 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) { ...@@ -326,13 +326,10 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) {
// Exec executes a query without returning any rows. // Exec executes a query without returning any rows.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) { func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
var res Result var res Result
var err error
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
res, err = db.exec(query, sargs) res, err = db.exec(query, args)
if err != driver.ErrBadConn { if err != driver.ErrBadConn {
break break
} }
...@@ -340,7 +337,7 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { ...@@ -340,7 +337,7 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
return res, err return res, err
} }
func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) { func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
ci, err := db.conn() ci, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -348,7 +345,11 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) { ...@@ -348,7 +345,11 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
defer db.putConn(ci, err) defer db.putConn(ci, err)
if execer, ok := ci.(driver.Execer); ok { if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, sargs) dargs, err := driverArgs(nil, args)
if err != nil {
return nil, err
}
resi, err := execer.Exec(query, dargs)
if err != driver.ErrSkip { if err != driver.ErrSkip {
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -363,7 +364,12 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) { ...@@ -363,7 +364,12 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
} }
defer sti.Close() defer sti.Close()
resi, err := sti.Exec(sargs) dargs, err := driverArgs(sti, args)
if err != nil {
return nil, err
}
resi, err := sti.Exec(dargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -577,13 +583,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { ...@@ -577,13 +583,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
} }
defer tx.releaseConn() defer tx.releaseConn()
sargs, err := subsetTypeArgs(args) if execer, ok := ci.(driver.Execer); ok {
dargs, err := driverArgs(nil, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resi, err := execer.Exec(query, dargs)
if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, sargs)
if err == nil { if err == nil {
return result{resi}, nil return result{resi}, nil
} }
...@@ -598,7 +603,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { ...@@ -598,7 +603,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
} }
defer sti.Close() defer sti.Close()
resi, err := sti.Exec(sargs) dargs, err := driverArgs(sti, args)
if err != nil {
return nil, err
}
resi, err := sti.Exec(dargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -674,51 +684,12 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { ...@@ -674,51 +684,12 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
} }
sargs := make([]driver.Value, len(args)) dargs, err := driverArgs(si, args)
// Convert args to subset types.
if cc, ok := si.(driver.ColumnConverter); ok {
for n, arg := range args {
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
if svi, ok := arg.(driver.Valuer); ok {
sv, err := svi.Value()
if err != nil { if err != nil {
return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err) return nil, err
}
if !driver.IsValue(sv) {
return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv)
}
arg = sv
}
// Second, ask the column to sanity check itself. For
// example, drivers might use this to make sure that
// an int64 values being inserted into a 16-bit
// integer field is in range (before getting
// truncated), or that a nil can't go into a NOT NULL
// column before going across the network to get the
// same error.
sargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
}
if !driver.IsValue(sargs[n]) {
return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
arg, sargs[n])
}
}
} else {
for n, arg := range args {
sargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
}
}
} }
resi, err := si.Exec(sargs) resi, err := si.Exec(dargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -805,11 +776,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -805,11 +776,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
if want := si.NumInput(); want != -1 && len(args) != want { if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args)) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args))
} }
sargs, err := subsetTypeArgs(args)
dargs, err := driverArgs(si, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rowsi, err := si.Query(sargs)
rowsi, err := si.Query(dargs)
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
return nil, err return nil, err
......
...@@ -306,8 +306,8 @@ func TestExec(t *testing.T) { ...@@ -306,8 +306,8 @@ func TestExec(t *testing.T) {
{[]interface{}{7, 9}, ""}, {[]interface{}{7, 9}, ""},
// Invalid conversions: // Invalid conversions:
{[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting Exec argument #1's type: sql/driver: value 4294967295 overflows int32"}, {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"},
{[]interface{}{"Brad", "strconv fail"}, "sql: converting Exec argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"}, {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"},
// Wrong number of args: // Wrong number of args:
{[]interface{}{}, "sql: expected 2 arguments, got 0"}, {[]interface{}{}, "sql: expected 2 arguments, got 0"},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment