Commit 3c1e1c30 authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo: use alias for unsafe rather than separate functions

When we need to generate a call to _cgoCheckPointer, we need to type
assert the result back to the desired type. That is harder when the type
is unsafe.Pointer, as the package can have values of unsafe.Pointer
types without actually importing unsafe, by mixing C void* and :=. We
used to handle this by generating a special function for each needed
type, and defining that function in a separate file where we did import
unsafe.

Simplify the code by not generating those functions, but instead just
import unsafe under the alias _cgo_unsafe. This is a simplification step
toward a fix for issue #16591.

Change-Id: I0edb3e04b6400ca068751709fe063397cf960a54
Reviewed-on: https://go-review.googlesource.com/30973
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent dc46b882
...@@ -167,7 +167,23 @@ func (p *Package) Translate(f *File) { ...@@ -167,7 +167,23 @@ func (p *Package) Translate(f *File) {
if len(needType) > 0 { if len(needType) > 0 {
p.loadDWARF(f, needType) p.loadDWARF(f, needType)
} }
p.rewriteCalls(f) if p.rewriteCalls(f) {
// Add `import _cgo_unsafe "unsafe"` as the first decl
// after the package statement.
imp := &ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
&ast.ImportSpec{
Name: ast.NewIdent("_cgo_unsafe"),
Path: &ast.BasicLit{
Kind: token.STRING,
Value: `"unsafe"`,
},
},
},
}
f.AST.Decls = append([]ast.Decl{imp}, f.AST.Decls...)
}
p.rewriteRef(f) p.rewriteRef(f)
} }
...@@ -578,7 +594,9 @@ func (p *Package) mangleName(n *Name) { ...@@ -578,7 +594,9 @@ func (p *Package) mangleName(n *Name) {
// rewriteCalls rewrites all calls that pass pointers to check that // rewriteCalls rewrites all calls that pass pointers to check that
// they follow the rules for passing pointers between Go and C. // they follow the rules for passing pointers between Go and C.
func (p *Package) rewriteCalls(f *File) { // This returns whether the package needs to import unsafe as _cgo_unsafe.
func (p *Package) rewriteCalls(f *File) bool {
needsUnsafe := false
for _, call := range f.Calls { for _, call := range f.Calls {
// This is a call to C.xxx; set goname to "xxx". // This is a call to C.xxx; set goname to "xxx".
goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
...@@ -590,18 +608,22 @@ func (p *Package) rewriteCalls(f *File) { ...@@ -590,18 +608,22 @@ func (p *Package) rewriteCalls(f *File) {
// Probably a type conversion. // Probably a type conversion.
continue continue
} }
p.rewriteCall(f, call, name) if p.rewriteCall(f, call, name) {
needsUnsafe = true
} }
}
return needsUnsafe
} }
// rewriteCall rewrites one call to add pointer checks. We replace // rewriteCall rewrites one call to add pointer checks. We replace
// each pointer argument x with _cgoCheckPointer(x).(T). // each pointer argument x with _cgoCheckPointer(x).(T).
func (p *Package) rewriteCall(f *File, call *Call, name *Name) { // This returns whether the package needs to import unsafe as _cgo_unsafe.
func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// Avoid a crash if the number of arguments is // Avoid a crash if the number of arguments is
// less than the number of parameters. // less than the number of parameters.
// This will be caught when the generated file is compiled. // This will be caught when the generated file is compiled.
if len(call.Call.Args) < len(name.FuncType.Params) { if len(call.Call.Args) < len(name.FuncType.Params) {
return return false
} }
any := false any := false
...@@ -612,7 +634,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) { ...@@ -612,7 +634,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
} }
} }
if !any { if !any {
return return false
} }
// We need to rewrite this call. // We need to rewrite this call.
...@@ -622,6 +644,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) { ...@@ -622,6 +644,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
// point of the defer statement, not when the function is called, so // point of the defer statement, not when the function is called, so
// rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p) // rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p)
needsUnsafe := false
var dargs []ast.Expr var dargs []ast.Expr
if call.Deferred { if call.Deferred {
dargs = make([]ast.Expr, len(name.FuncType.Params)) dargs = make([]ast.Expr, len(name.FuncType.Params))
...@@ -651,57 +674,42 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) { ...@@ -651,57 +674,42 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
// expression. // expression.
c.Args = p.checkAddrArgs(f, c.Args, origArg) c.Args = p.checkAddrArgs(f, c.Args, origArg)
// _cgoCheckPointer returns interface{}. // The Go version of the C type might use unsafe.Pointer,
// We need to type assert that to the type we want. // but the file might not import unsafe.
// If the Go version of this C type uses // Rewrite the Go type if necessary to use _cgo_unsafe.
// unsafe.Pointer, we can't use a type assertion, ptype := p.rewriteUnsafe(param.Go)
// because the Go file might not import unsafe. if ptype != param.Go {
// Instead we use a local variant of _cgoCheckPointer. needsUnsafe = true
}
var arg ast.Expr
if n := p.unsafeCheckPointerName(param.Go, call.Deferred); n != "" { // In order for the type assertion to succeed, we need
c.Fun = ast.NewIdent(n) // it to match the actual type of the argument. The
arg = c // only type we have is the type of the function
} else { // parameter. We know that the argument type must be
// In order for the type assertion to succeed, // assignable to the function parameter type, or the
// we need it to match the actual type of the // code would not compile, but there is nothing
// argument. The only type we have is the // requiring that the types be exactly the same. Add a
// type of the function parameter. We know // type conversion to the argument so that the type
// that the argument type must be assignable // assertion will succeed.
// to the function parameter type, or the code
// would not compile, but there is nothing
// requiring that the types be exactly the
// same. Add a type conversion to the
// argument so that the type assertion will
// succeed.
c.Args[0] = &ast.CallExpr{ c.Args[0] = &ast.CallExpr{
Fun: param.Go, Fun: ptype,
Args: []ast.Expr{ Args: []ast.Expr{
c.Args[0], c.Args[0],
}, },
} }
arg = &ast.TypeAssertExpr{ call.Call.Args[i] = &ast.TypeAssertExpr{
X: c, X: c,
Type: param.Go, Type: ptype,
}
} }
call.Call.Args[i] = arg
} }
if call.Deferred { if call.Deferred {
params := make([]*ast.Field, len(name.FuncType.Params)) params := make([]*ast.Field, len(name.FuncType.Params))
for i, param := range name.FuncType.Params { for i, param := range name.FuncType.Params {
ptype := param.Go ptype := p.rewriteUnsafe(param.Go)
if p.hasUnsafePointer(ptype) { if ptype != param.Go {
// Avoid generating unsafe.Pointer by using needsUnsafe = true
// interface{}. This works because we are
// going to call a _cgoCheckPointer function
// anyhow.
ptype = &ast.InterfaceType{
Methods: &ast.FieldList{},
}
} }
params[i] = &ast.Field{ params[i] = &ast.Field{
Names: []*ast.Ident{ Names: []*ast.Ident{
...@@ -740,6 +748,8 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) { ...@@ -740,6 +748,8 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
} }
} }
} }
return needsUnsafe
} }
// needsPointerCheck returns whether the type t needs a pointer check. // needsPointerCheck returns whether the type t needs a pointer check.
...@@ -935,69 +945,52 @@ func (p *Package) isType(t ast.Expr) bool { ...@@ -935,69 +945,52 @@ func (p *Package) isType(t ast.Expr) bool {
return false return false
} }
// unsafeCheckPointerName is given the Go version of a C type. If the // rewriteUnsafe returns a version of t with references to unsafe.Pointer
// type uses unsafe.Pointer, we arrange to build a version of // rewritten to use _cgo_unsafe.Pointer instead.
// _cgoCheckPointer that returns that type. This avoids using a type func (p *Package) rewriteUnsafe(t ast.Expr) ast.Expr {
// assertion to unsafe.Pointer in our copy of user code. We return
// the name of the _cgoCheckPointer function we are going to build, or
// the empty string if the type does not use unsafe.Pointer.
//
// The deferred parameter is true if this check is for the argument of
// a deferred function. In that case we need to use an empty interface
// as the argument type, because the deferred function we introduce in
// rewriteCall will use an empty interface type, and we can't add a
// type assertion. This is handled by keeping a separate list, and
// writing out the lists separately in writeDefs.
func (p *Package) unsafeCheckPointerName(t ast.Expr, deferred bool) string {
if !p.hasUnsafePointer(t) {
return ""
}
var buf bytes.Buffer
conf.Fprint(&buf, fset, t)
s := buf.String()
checks := &p.CgoChecks
if deferred {
checks = &p.DeferredCgoChecks
}
for i, t := range *checks {
if s == t {
return p.unsafeCheckPointerNameIndex(i, deferred)
}
}
*checks = append(*checks, s)
return p.unsafeCheckPointerNameIndex(len(*checks)-1, deferred)
}
// hasUnsafePointer returns whether the Go type t uses unsafe.Pointer.
// t is the Go version of a C type, so we don't need to handle every case.
// We only care about direct references, not references via typedefs.
func (p *Package) hasUnsafePointer(t ast.Expr) bool {
switch t := t.(type) { switch t := t.(type) {
case *ast.Ident: case *ast.Ident:
// We don't see a SelectorExpr for unsafe.Pointer; // We don't see a SelectorExpr for unsafe.Pointer;
// this is created by code in this file. // this is created by code in this file.
return t.Name == "unsafe.Pointer" if t.Name == "unsafe.Pointer" {
return ast.NewIdent("_cgo_unsafe.Pointer")
}
case *ast.ArrayType: case *ast.ArrayType:
return p.hasUnsafePointer(t.Elt) t1 := p.rewriteUnsafe(t.Elt)
if t1 != t.Elt {
r := *t
r.Elt = t1
return &r
}
case *ast.StructType: case *ast.StructType:
changed := false
fields := *t.Fields
fields.List = nil
for _, f := range t.Fields.List { for _, f := range t.Fields.List {
if p.hasUnsafePointer(f.Type) { ft := p.rewriteUnsafe(f.Type)
return true if ft == f.Type {
fields.List = append(fields.List, f)
} else {
fn := *f
fn.Type = ft
fields.List = append(fields.List, &fn)
changed = true
}
} }
if changed {
r := *t
r.Fields = &fields
return &r
} }
case *ast.StarExpr: // Pointer type. case *ast.StarExpr: // Pointer type.
return p.hasUnsafePointer(t.X) x1 := p.rewriteUnsafe(t.X)
if x1 != t.X {
r := *t
r.X = x1
return &r
} }
return false
}
// unsafeCheckPointerNameIndex returns the name to use for a
// _cgoCheckPointer variant based on the index in the CgoChecks slice.
func (p *Package) unsafeCheckPointerNameIndex(i int, deferred bool) string {
if deferred {
return fmt.Sprintf("_cgoCheckPointerInDefer%d", i)
} }
return fmt.Sprintf("_cgoCheckPointer%d", i) return t
} }
// rewriteRef rewrites all the C.xxx references in f.AST to refer to the // rewriteRef rewrites all the C.xxx references in f.AST to refer to the
......
...@@ -42,10 +42,6 @@ type Package struct { ...@@ -42,10 +42,6 @@ type Package struct {
GoFiles []string // list of Go files GoFiles []string // list of Go files
GccFiles []string // list of gcc output files GccFiles []string // list of gcc output files
Preamble string // collected preamble for _cgo_export.h Preamble string // collected preamble for _cgo_export.h
// See unsafeCheckPointerName.
CgoChecks []string
DeferredCgoChecks []string
} }
// A File collects information about a single Go input file. // A File collects information about a single Go input file.
......
...@@ -126,19 +126,6 @@ func (p *Package) writeDefs() { ...@@ -126,19 +126,6 @@ func (p *Package) writeDefs() {
fmt.Fprint(fgo2, goProlog) fmt.Fprint(fgo2, goProlog)
} }
for i, t := range p.CgoChecks {
n := p.unsafeCheckPointerNameIndex(i, false)
fmt.Fprintf(fgo2, "\nfunc %s(p %s, args ...interface{}) %s {\n", n, t, t)
fmt.Fprintf(fgo2, "\treturn _cgoCheckPointer(p, args...).(%s)\n", t)
fmt.Fprintf(fgo2, "}\n")
}
for i, t := range p.DeferredCgoChecks {
n := p.unsafeCheckPointerNameIndex(i, true)
fmt.Fprintf(fgo2, "\nfunc %s(p interface{}, args ...interface{}) %s {\n", n, t)
fmt.Fprintf(fgo2, "\treturn _cgoCheckPointer(p, args...).(%s)\n", t)
fmt.Fprintf(fgo2, "}\n")
}
gccgoSymbolPrefix := p.gccgoSymbolPrefix() gccgoSymbolPrefix := p.gccgoSymbolPrefix()
cVars := make(map[string]bool) cVars := make(map[string]bool)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment