Commit 2a85578b authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: support returning query database types

Creates a ColumnType structure that can be extended in to future.
Allow drivers to implement what makes sense for the database.

Fixes #16652

Change-Id: Ieb1fd64eac1460107b1d3474eba5201fa300a4ec
Reviewed-on: https://go-review.googlesource.com/29961
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 2ecaaf18
...@@ -11,6 +11,7 @@ package driver ...@@ -11,6 +11,7 @@ package driver
import ( import (
"context" "context"
"errors" "errors"
"reflect"
) )
// Value is a value that drivers must be able to handle. // Value is a value that drivers must be able to handle.
...@@ -239,6 +240,60 @@ type RowsNextResultSet interface { ...@@ -239,6 +240,60 @@ type RowsNextResultSet interface {
NextResultSet() error NextResultSet() error
} }
// RowsColumnTypeScanType may be implemented by Rows. It should return
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
type RowsColumnTypeScanType interface {
Rows
ColumnTypeScanType(index int) reflect.Type
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
// database system type name without the length. Type names should be uppercase.
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
type RowsColumnTypeDatabaseTypeName interface {
Rows
ColumnTypeDatabaseTypeName(index int) string
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
// of the column type if the column is a variable length type. If the column is
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
type RowsColumnTypeLength interface {
Rows
ColumnTypeLength(index int) (length int64, ok bool)
}
// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
// be true if it is known the column may be null, or false if the column is known
// to be not nullable.
// If the column nullability is unknown, ok should be false.
type RowsColumnTypeNullable interface {
Rows
ColumnTypeNullable(index int) (nullable, ok bool)
}
// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
// the precision and scale for decimal types. If not applicable, ok should be false.
// The following are examples of returned values for various types:
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
type RowsColumnTypePrecisionScale interface {
Rows
ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
}
// Tx is a transaction. // Tx is a transaction.
type Tx interface { type Tx interface {
Commit() error Commit() error
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"reflect"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
...@@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err ...@@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
stmt.colName = strings.Split(parts[1], ",") stmt.colName = strings.Split(parts[1], ",")
for n, colspec := range strings.Split(parts[2], ",") { for n, colspec := range strings.Split(parts[2], ",") {
if colspec == "" { if colspec == "" {
...@@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( ...@@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
setMRows := make([][]*row, 0, 1) setMRows := make([][]*row, 0, 1)
setColumns := make([][]string, 0, 1) setColumns := make([][]string, 0, 1)
setColType := make([][]string, 0, 1)
for { for {
db.mu.Lock() db.mu.Lock()
...@@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( ...@@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
mrows = append(mrows, mrow) mrows = append(mrows, mrow)
} }
var colType []string
for _, column := range s.colName {
colType = append(colType, t.coltype[t.columnIndex(column)])
}
t.mu.Unlock() t.mu.Unlock()
setMRows = append(setMRows, mrows) setMRows = append(setMRows, mrows)
setColumns = append(setColumns, s.colName) setColumns = append(setColumns, s.colName)
setColType = append(setColType, colType)
if s.next == nil { if s.next == nil {
break break
...@@ -809,6 +818,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( ...@@ -809,6 +818,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
posRow: -1, posRow: -1,
rows: setMRows, rows: setMRows,
cols: setColumns, cols: setColumns,
colType: setColType,
errPos: -1, errPos: -1,
} }
return cursor, nil return cursor, nil
...@@ -845,6 +855,7 @@ func (tx *fakeTx) Rollback() error { ...@@ -845,6 +855,7 @@ func (tx *fakeTx) Rollback() error {
type rowsCursor struct { type rowsCursor struct {
cols [][]string cols [][]string
colType [][]string
posSet int posSet int
posRow int posRow int
rows [][]*row rows [][]*row
...@@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string { ...@@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string {
return rc.cols[rc.posSet] return rc.cols[rc.posSet]
} }
func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
return colTypeToReflectType(rc.colType[rc.posSet][index])
}
var rowsCursorNextHook func(dest []driver.Value) error var rowsCursorNextHook func(dest []driver.Value) error
func (rc *rowsCursor) Next(dest []driver.Value) error { func (rc *rowsCursor) Next(dest []driver.Value) error {
...@@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter { ...@@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter {
} }
panic("invalid fakedb column type of " + typ) panic("invalid fakedb column type of " + typ)
} }
func colTypeToReflectType(typ string) reflect.Type {
switch typ {
case "bool":
return reflect.TypeOf(false)
case "nullbool":
return reflect.TypeOf(NullBool{})
case "int32":
return reflect.TypeOf(int32(0))
case "string":
return reflect.TypeOf("")
case "nullstring":
return reflect.TypeOf(NullString{})
case "int64":
return reflect.TypeOf(int64(0))
case "nullint64":
return reflect.TypeOf(NullInt64{})
case "float64":
return reflect.TypeOf(float64(0))
case "nullfloat64":
return reflect.TypeOf(NullFloat64{})
case "datetime":
return reflect.TypeOf(time.Time{})
}
panic("invalid fakedb column type of " + typ)
}
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"reflect"
"runtime" "runtime"
"sort" "sort"
"sync" "sync"
...@@ -996,8 +997,8 @@ const maxBadConnRetries = 2 ...@@ -996,8 +997,8 @@ const maxBadConnRetries = 2
// The caller must call the statement's Close method // The caller must call the statement's Close method
// when the statement is no longer needed. // when the statement is no longer needed.
// //
// The provided context is for the preparation of the statment, not for the execution of // The provided context is used for the preparation of the statement, not for the
// the statement. // execution of the statement.
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt var stmt *Stmt
var err error var err error
...@@ -2033,6 +2034,107 @@ func (rs *Rows) Columns() ([]string, error) { ...@@ -2033,6 +2034,107 @@ func (rs *Rows) Columns() ([]string, error) {
return rs.rowsi.Columns(), nil return rs.rowsi.Columns(), nil
} }
// ColumnTypes returns column information such as column type, length,
// and nullable. Some information may not be available from some drivers.
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
if rs.isClosed() {
return nil, errors.New("sql: Rows are closed")
}
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
return rowsColumnInfoSetup(rs.rowsi), nil
}
// ColumnType contains the name and type of a column.
type ColumnType struct {
name string
hasNullable bool
hasLength bool
hasPrecisionScale bool
nullable bool
length int64
databaseType string
precision int64
scale int64
scanType reflect.Type
}
// Name returns the name or alias of the column.
func (ci *ColumnType) Name() string {
return ci.name
}
// Length returns the column type length for variable length column types such
// as text and binary field types. If the type length is unbounded the value will
// be math.MaxInt64 (any database limits will still apply).
// If the column type is not variable length, such as an int, or if not supported
// by the driver ok is false.
func (ci *ColumnType) Length() (length int64, ok bool) {
return ci.length, ci.hasLength
}
// DecimalSize returns the scale and precision of a decimal type.
// If not applicable or if not supported ok is false.
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
return ci.precision, ci.scale, ci.hasPrecisionScale
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
// If a driver does not support this property ScanType will return
// the type of an empty interface.
func (ci *ColumnType) ScanType() reflect.Type {
return ci.scanType
}
// Nullable returns whether the column may be null.
// If a driver does not support this property ok will be false.
func (ci *ColumnType) Nullable() (nullable, ok bool) {
return ci.nullable, ci.hasNullable
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT".
func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType
}
func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns()
list := make([]*ColumnType, len(names))
for i := range list {
ci := &ColumnType{
name: names[i],
}
list[i] = ci
if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
ci.scanType = prop.ColumnTypeScanType(i)
} else {
ci.scanType = reflect.TypeOf(new(interface{})).Elem()
}
if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
ci.length, ci.hasLength = prop.ColumnTypeLength(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
}
}
return list
}
// Scan copies the columns in the current row into the values pointed // Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the // at by dest. The number of values in dest must be the same as the
// number of columns in Rows. // number of columns in Rows.
......
...@@ -499,6 +499,56 @@ func TestRowsColumns(t *testing.T) { ...@@ -499,6 +499,56 @@ func TestRowsColumns(t *testing.T) {
} }
} }
func TestRowsColumnTypes(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
rows, err := db.Query("SELECT|people|age,name|")
if err != nil {
t.Fatalf("Query: %v", err)
}
tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}
types := make([]reflect.Type, len(tt))
for i, tp := range tt {
st := tp.ScanType()
if st == nil {
t.Errorf("scantype is null for column %q", tp.Name())
continue
}
types[i] = st
}
values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
}
ct := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
ct++
if ct == 0 {
if values[0].(string) != "Bob" {
t.Errorf("Expected Bob, got %v", values[0])
}
if values[1].(int) != 2 {
t.Errorf("Expected 2, got %v", values[1])
}
}
}
if ct != 3 {
t.Errorf("expected 3 rows, got %d", ct)
}
if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
}
func TestQueryRow(t *testing.T) { func TestQueryRow(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment