Commit a16954b8 authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo: always use a function literal for pointer checking

The pointer checking code needs to know the exact type of the parameter
expected by the C function, so that it can use a type assertion to
convert the empty interface returned by cgoCheckPointer to the correct
type. Previously this was done by using a type conversion, but that
meant that the code accepted arguments that were convertible to the
parameter type, rather than arguments that were assignable as in a
normal function call. In other words, some code that should not have
passed type checking was accepted.

This CL changes cgo to always use a function literal for pointer
checking. Now the argument is passed to the function literal, which has
the correct argument type, so type checking is performed just as for a
function call as it should be.

Since we now always use a function literal, simplify the checking code
to run as a statement by itself. It now no longer needs to return a
value, and we no longer need a type assertion.

This does have the cost of introducing another function call into any
call to a C function that requires pointer checking, but the cost of the
additional call should be minimal compared to the cost of pointer
checking.

Fixes #16591.

Change-Id: I220165564cf69db9fd5f746532d7f977a5b2c989
Reviewed-on: https://go-review.googlesource.com/31233
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarMatthew Dempsky <mdempsky@google.com>
parent e32ac797
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Issue 16591: Test that we detect an invalid call that was being
// hidden by a type conversion inserted by cgo checking.
package p
// void f(int** p) { }
import "C"
type x *C.int
func F(p *x) {
C.f(p) // ERROR HERE
}
...@@ -46,6 +46,7 @@ check issue13423.go ...@@ -46,6 +46,7 @@ check issue13423.go
expect issue13635.go C.uchar C.schar C.ushort C.uint C.ulong C.longlong C.ulonglong C.complexfloat C.complexdouble expect issue13635.go C.uchar C.schar C.ushort C.uint C.ulong C.longlong C.ulonglong C.complexfloat C.complexdouble
check issue13830.go check issue13830.go
check issue16116.go check issue16116.go
check issue16591.go
if ! go build issue14669.go; then if ! go build issue14669.go; then
exit 1 exit 1
......
...@@ -186,6 +186,7 @@ func testCallbackCallers(t *testing.T) { ...@@ -186,6 +186,7 @@ func testCallbackCallers(t *testing.T) {
"runtime.asmcgocall", "runtime.asmcgocall",
"runtime.cgocall", "runtime.cgocall",
"test._Cfunc_callback", "test._Cfunc_callback",
"test.nestedCall.func1",
"test.nestedCall", "test.nestedCall",
"test.testCallbackCallers", "test.testCallbackCallers",
"test.TestCallbackCallers", "test.TestCallbackCallers",
......
...@@ -639,34 +639,57 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -639,34 +639,57 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// We need to rewrite this call. // We need to rewrite this call.
// //
// We are going to rewrite C.f(p) to C.f(_cgoCheckPointer(p)). // We are going to rewrite C.f(p) to
// If the call to C.f is deferred, that will check p at the // func (_cgo0 ptype) {
// point of the defer statement, not when the function is called, so // _cgoCheckPointer(_cgo0)
// rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p) // C.f(_cgo0)
// }(p)
// Using a function literal like this lets us do correct
// argument type checking, and works correctly if the call is
// deferred.
needsUnsafe := false needsUnsafe := false
var dargs []ast.Expr params := make([]*ast.Field, len(name.FuncType.Params))
if call.Deferred { args := make([]ast.Expr, len(name.FuncType.Params))
dargs = make([]ast.Expr, len(name.FuncType.Params)) var stmts []ast.Stmt
}
for i, param := range name.FuncType.Params { for i, param := range name.FuncType.Params {
// params is going to become the parameters of the
// function literal.
// args is going to become the list of arguments to the
// function literal.
// nparam is the parameter of the function literal that
// corresponds to param.
origArg := call.Call.Args[i] origArg := call.Call.Args[i]
darg := origArg args[i] = origArg
nparam := ast.NewIdent(fmt.Sprintf("_cgo%d", i))
// The Go version of the C type might use unsafe.Pointer,
// but the file might not import unsafe.
// Rewrite the Go type if necessary to use _cgo_unsafe.
ptype := p.rewriteUnsafe(param.Go)
if ptype != param.Go {
needsUnsafe = true
}
if call.Deferred { params[i] = &ast.Field{
dargs[i] = darg Names: []*ast.Ident{nparam},
darg = ast.NewIdent(fmt.Sprintf("_cgo%d", i)) Type: ptype,
call.Call.Args[i] = darg
} }
call.Call.Args[i] = nparam
if !p.needsPointerCheck(f, param.Go, origArg) { if !p.needsPointerCheck(f, param.Go, origArg) {
continue continue
} }
// Run the cgo pointer checks on nparam.
// Change the function literal to call the real function
// with the parameter passed through _cgoCheckPointer.
c := &ast.CallExpr{ c := &ast.CallExpr{
Fun: ast.NewIdent("_cgoCheckPointer"), Fun: ast.NewIdent("_cgoCheckPointer"),
Args: []ast.Expr{ Args: []ast.Expr{
darg, nparam,
}, },
} }
...@@ -674,77 +697,64 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -674,77 +697,64 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// expression. // expression.
c.Args = p.checkAddrArgs(f, c.Args, origArg) c.Args = p.checkAddrArgs(f, c.Args, origArg)
// The Go version of the C type might use unsafe.Pointer, stmt := &ast.ExprStmt{
// but the file might not import unsafe. X: c,
// Rewrite the Go type if necessary to use _cgo_unsafe.
ptype := p.rewriteUnsafe(param.Go)
if ptype != param.Go {
needsUnsafe = true
}
// In order for the type assertion to succeed, we need
// it to match the actual type of the argument. The
// only type we have is the type of the function
// parameter. We know that the argument type must be
// assignable to the function parameter type, or the
// code would not compile, but there is nothing
// requiring that the types be exactly the same. Add a
// type conversion to the argument so that the type
// assertion will succeed.
c.Args[0] = &ast.CallExpr{
Fun: ptype,
Args: []ast.Expr{
c.Args[0],
},
}
call.Call.Args[i] = &ast.TypeAssertExpr{
X: c,
Type: ptype,
} }
stmts = append(stmts, stmt)
} }
if call.Deferred { fcall := &ast.CallExpr{
params := make([]*ast.Field, len(name.FuncType.Params)) Fun: call.Call.Fun,
for i, param := range name.FuncType.Params { Args: call.Call.Args,
ptype := p.rewriteUnsafe(param.Go) }
if ptype != param.Go { ftype := &ast.FuncType{
needsUnsafe = true Params: &ast.FieldList{
} List: params,
params[i] = &ast.Field{ },
Names: []*ast.Ident{ }
ast.NewIdent(fmt.Sprintf("_cgo%d", i)), var fbody ast.Stmt
}, if name.FuncType.Result == nil {
Type: ptype, fbody = &ast.ExprStmt{
} X: fcall,
} }
} else {
dbody := &ast.CallExpr{ fbody = &ast.ReturnStmt{
Fun: call.Call.Fun, Results: []ast.Expr{fcall},
Args: call.Call.Args,
} }
call.Call.Fun = &ast.FuncLit{ rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
Type: &ast.FuncType{ if rtype != name.FuncType.Result.Go {
Params: &ast.FieldList{ needsUnsafe = true
List: params, }
}, ftype.Results = &ast.FieldList{
}, List: []*ast.Field{
Body: &ast.BlockStmt{ &ast.Field{
List: []ast.Stmt{ Type: rtype,
&ast.ExprStmt{
X: dbody,
},
}, },
}, },
} }
call.Call.Args = dargs }
call.Call.Lparen = token.NoPos call.Call.Fun = &ast.FuncLit{
call.Call.Rparen = token.NoPos Type: ftype,
Body: &ast.BlockStmt{
List: append(stmts, fbody),
},
}
call.Call.Args = args
call.Call.Lparen = token.NoPos
call.Call.Rparen = token.NoPos
// There is a Ref pointing to the old call.Call.Fun. // There is a Ref pointing to the old call.Call.Fun.
for _, ref := range f.Ref { for _, ref := range f.Ref {
if ref.Expr == &call.Call.Fun { if ref.Expr == &call.Call.Fun {
ref.Expr = &dbody.Fun ref.Expr = &fcall.Fun
// If this call expects two results, we have to
// adjust the results of the function we generated.
if ref.Context == "call2" {
ftype.Results.List = append(ftype.Results.List,
&ast.Field{
Type: ast.NewIdent("error"),
})
} }
} }
} }
......
...@@ -1379,14 +1379,14 @@ func _cgo_runtime_cgocall(unsafe.Pointer, uintptr) int32 ...@@ -1379,14 +1379,14 @@ func _cgo_runtime_cgocall(unsafe.Pointer, uintptr) int32
func _cgo_runtime_cgocallback(unsafe.Pointer, unsafe.Pointer, uintptr, uintptr) func _cgo_runtime_cgocallback(unsafe.Pointer, unsafe.Pointer, uintptr, uintptr)
//go:linkname _cgoCheckPointer runtime.cgoCheckPointer //go:linkname _cgoCheckPointer runtime.cgoCheckPointer
func _cgoCheckPointer(interface{}, ...interface{}) interface{} func _cgoCheckPointer(interface{}, ...interface{})
//go:linkname _cgoCheckResult runtime.cgoCheckResult //go:linkname _cgoCheckResult runtime.cgoCheckResult
func _cgoCheckResult(interface{}) func _cgoCheckResult(interface{})
` `
const gccgoGoProlog = ` const gccgoGoProlog = `
func _cgoCheckPointer(interface{}, ...interface{}) interface{} func _cgoCheckPointer(interface{}, ...interface{})
func _cgoCheckResult(interface{}) func _cgoCheckResult(interface{})
` `
...@@ -1566,18 +1566,17 @@ typedef struct __go_empty_interface { ...@@ -1566,18 +1566,17 @@ typedef struct __go_empty_interface {
void *__object; void *__object;
} Eface; } Eface;
extern Eface runtimeCgoCheckPointer(Eface, Slice) extern void runtimeCgoCheckPointer(Eface, Slice)
__asm__("runtime.cgoCheckPointer") __asm__("runtime.cgoCheckPointer")
__attribute__((weak)); __attribute__((weak));
extern Eface localCgoCheckPointer(Eface, Slice) extern void localCgoCheckPointer(Eface, Slice)
__asm__("GCCGOSYMBOLPREF._cgoCheckPointer"); __asm__("GCCGOSYMBOLPREF._cgoCheckPointer");
Eface localCgoCheckPointer(Eface ptr, Slice args) { void localCgoCheckPointer(Eface ptr, Slice args) {
if(runtimeCgoCheckPointer) { if(runtimeCgoCheckPointer) {
return runtimeCgoCheckPointer(ptr, args); runtimeCgoCheckPointer(ptr, args);
} }
return ptr;
} }
extern void runtimeCgoCheckResult(Eface) extern void runtimeCgoCheckResult(Eface)
......
...@@ -370,10 +370,10 @@ var racecgosync uint64 // represents possible synchronization in C code ...@@ -370,10 +370,10 @@ var racecgosync uint64 // represents possible synchronization in C code
// pointers.) // pointers.)
// cgoCheckPointer checks if the argument contains a Go pointer that // cgoCheckPointer checks if the argument contains a Go pointer that
// points to a Go pointer, and panics if it does. It returns the pointer. // points to a Go pointer, and panics if it does.
func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} { func cgoCheckPointer(ptr interface{}, args ...interface{}) {
if debug.cgocheck == 0 { if debug.cgocheck == 0 {
return ptr return
} }
ep := (*eface)(unsafe.Pointer(&ptr)) ep := (*eface)(unsafe.Pointer(&ptr))
...@@ -386,7 +386,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} { ...@@ -386,7 +386,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
p = *(*unsafe.Pointer)(p) p = *(*unsafe.Pointer)(p)
} }
if !cgoIsGoPointer(p) { if !cgoIsGoPointer(p) {
return ptr return
} }
aep := (*eface)(unsafe.Pointer(&args[0])) aep := (*eface)(unsafe.Pointer(&args[0]))
switch aep._type.kind & kindMask { switch aep._type.kind & kindMask {
...@@ -397,7 +397,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} { ...@@ -397,7 +397,7 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
} }
pt := (*ptrtype)(unsafe.Pointer(t)) pt := (*ptrtype)(unsafe.Pointer(t))
cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail) cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail)
return ptr return
case kindSlice: case kindSlice:
// Check the slice rather than the pointer. // Check the slice rather than the pointer.
ep = aep ep = aep
...@@ -415,7 +415,6 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} { ...@@ -415,7 +415,6 @@ func cgoCheckPointer(ptr interface{}, args ...interface{}) interface{} {
} }
cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top, cgoCheckPointerFail) cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top, cgoCheckPointerFail)
return ptr
} }
const cgoCheckPointerFail = "cgo argument has Go pointer to Go pointer" const cgoCheckPointerFail = "cgo argument has Go pointer to Go pointer"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment