Commit a70a2a8a authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo: don't update each call in place

Updating each call in place broke when there were multiple cgo calls
used as arguments to another cgo call where some required rewriting.
Instead, rewrite calls to strings via the existing mangling mechanism,
and only substitute the top level call in place.

Fixes #28540

Change-Id: Ifd66f04c205adc4ad6dd5ee8e79e57dce17e86bb
Reviewed-on: https://go-review.googlesource.com/c/146860Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarDmitri Shuralyov <dmitshur@golang.org>
parent 2182bb09
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Crash from call with two arguments that need pointer checking.
// No runtime test; just make sure it compiles.
package cgotest
/*
static void twoargs1(void *p, int n) {}
static void *twoargs2() { return 0; }
static int twoargs3(void * p) { return 0; }
*/
import "C"
import "unsafe"
func twoargsF() {
v := []string{}
C.twoargs1(C.twoargs2(), C.twoargs3(unsafe.Pointer(&v)))
}
...@@ -722,20 +722,18 @@ func (p *Package) mangleName(n *Name) { ...@@ -722,20 +722,18 @@ func (p *Package) mangleName(n *Name) {
func (p *Package) rewriteCalls(f *File) bool { func (p *Package) rewriteCalls(f *File) bool {
needsUnsafe := false needsUnsafe := false
// Walk backward so that in C.f1(C.f2()) we rewrite C.f2 first. // Walk backward so that in C.f1(C.f2()) we rewrite C.f2 first.
for i := len(f.Calls) - 1; i >= 0; i-- { for _, call := range f.Calls {
call := f.Calls[i] if call.Done {
// This is a call to C.xxx; set goname to "xxx".
goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
if goname == "malloc" {
continue continue
} }
name := f.Name[goname] start := f.offset(call.Call.Pos())
if name.Kind != "func" { end := f.offset(call.Call.End())
// Probably a type conversion. str, nu := p.rewriteCall(f, call)
continue if str != "" {
} f.Edit.Replace(start, end, str)
if p.rewriteCall(f, call, name) { if nu {
needsUnsafe = true needsUnsafe = true
}
} }
} }
return needsUnsafe return needsUnsafe
...@@ -745,8 +743,29 @@ func (p *Package) rewriteCalls(f *File) bool { ...@@ -745,8 +743,29 @@ func (p *Package) rewriteCalls(f *File) bool {
// If any pointer checks are required, we rewrite the call into a // If any pointer checks are required, we rewrite the call into a
// function literal that calls _cgoCheckPointer for each pointer // function literal that calls _cgoCheckPointer for each pointer
// argument and then calls the original function. // argument and then calls the original function.
// This returns whether the package needs to import unsafe as _cgo_unsafe. // This returns the rewritten call and whether the package needs to
func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { // import unsafe as _cgo_unsafe.
// If it returns the empty string, the call did not need to be rewritten.
func (p *Package) rewriteCall(f *File, call *Call) (string, bool) {
// This is a call to C.xxx; set goname to "xxx".
// It may have already been mangled by rewriteName.
var goname string
switch fun := call.Call.Fun.(type) {
case *ast.SelectorExpr:
goname = fun.Sel.Name
case *ast.Ident:
goname = strings.TrimPrefix(fun.Name, "_C2func_")
goname = strings.TrimPrefix(goname, "_Cfunc_")
}
if goname == "" || goname == "malloc" {
return "", false
}
name := f.Name[goname]
if name == nil || name.Kind != "func" {
// Probably a type conversion.
return "", false
}
params := name.FuncType.Params params := name.FuncType.Params
args := call.Call.Args args := call.Call.Args
...@@ -754,7 +773,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -754,7 +773,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// less than the number of parameters. // less than the number of parameters.
// This will be caught when the generated file is compiled. // This will be caught when the generated file is compiled.
if len(args) < len(params) { if len(args) < len(params) {
return false return "", false
} }
any := false any := false
...@@ -765,7 +784,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -765,7 +784,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
} }
} }
if !any { if !any {
return false return "", false
} }
// We need to rewrite this call. // We need to rewrite this call.
...@@ -848,7 +867,10 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -848,7 +867,10 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// Write _cgoCheckPointer calls to sbCheck. // Write _cgoCheckPointer calls to sbCheck.
var sbCheck bytes.Buffer var sbCheck bytes.Buffer
for i, param := range params { for i, param := range params {
arg := p.mangle(f, &args[i]) arg, nu := p.mangle(f, &args[i])
if nu {
needsUnsafe = true
}
// Explicitly convert untyped constants to the // Explicitly convert untyped constants to the
// parameter type, to avoid a type mismatch. // parameter type, to avoid a type mismatch.
...@@ -893,12 +915,12 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -893,12 +915,12 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
sb.WriteString("return ") sb.WriteString("return ")
} }
// Now we are ready to call the C function. m, nu := p.mangle(f, &call.Call.Fun)
// To work smoothly with rewriteRef we leave the call in place if nu {
// and just replace the old arguments with our new ones. needsUnsafe = true
f.Edit.Insert(f.offset(call.Call.Fun.Pos()), sb.String()) }
sb.WriteString(gofmtLine(m))
sb.Reset()
sb.WriteString("(") sb.WriteString("(")
for i := range params { for i := range params {
if i > 0 { if i > 0 {
...@@ -916,9 +938,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -916,9 +938,7 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
} }
sb.WriteString("()") sb.WriteString("()")
f.Edit.Replace(f.offset(call.Call.Lparen), f.offset(call.Call.Rparen)+1, sb.String()) return sb.String(), needsUnsafe
return needsUnsafe
} }
// needsPointerCheck returns whether the type t needs a pointer check. // needsPointerCheck returns whether the type t needs a pointer check.
...@@ -1025,32 +1045,54 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool { ...@@ -1025,32 +1045,54 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
} }
} }
// mangle replaces references to C names in arg with the mangled names. // mangle replaces references to C names in arg with the mangled names,
// It removes the corresponding references in f.Ref, so that we don't // rewriting calls when it finds them.
// try to do the replacement again in rewriteRef. // It removes the corresponding references in f.Ref and f.Calls, so that we
func (p *Package) mangle(f *File, arg *ast.Expr) ast.Expr { // don't try to do the replacement again in rewriteRef or rewriteCall.
func (p *Package) mangle(f *File, arg *ast.Expr) (ast.Expr, bool) {
needsUnsafe := false
f.walk(arg, ctxExpr, func(f *File, arg interface{}, context astContext) { f.walk(arg, ctxExpr, func(f *File, arg interface{}, context astContext) {
px, ok := arg.(*ast.Expr) px, ok := arg.(*ast.Expr)
if !ok { if !ok {
return return
} }
sel, ok := (*px).(*ast.SelectorExpr) sel, ok := (*px).(*ast.SelectorExpr)
if !ok { if ok {
if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
return
}
for _, r := range f.Ref {
if r.Expr == px {
*px = p.rewriteName(f, r)
r.Done = true
break
}
}
return return
} }
if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
call, ok := (*px).(*ast.CallExpr)
if !ok {
return return
} }
for _, r := range f.Ref { for _, c := range f.Calls {
if r.Expr == px { if !c.Done && c.Call.Lparen == call.Lparen {
*px = p.rewriteName(f, r) cstr, nu := p.rewriteCall(f, c)
r.Done = true if cstr != "" {
break // Smuggle the rewritten call through an ident.
*px = ast.NewIdent(cstr)
if nu {
needsUnsafe = true
}
c.Done = true
}
} }
} }
}) })
return *arg return *arg, needsUnsafe
} }
// checkIndex checks whether arg the form &a[i], possibly inside type // checkIndex checks whether arg the form &a[i], possibly inside type
......
...@@ -81,6 +81,7 @@ func nameKeys(m map[string]*Name) []string { ...@@ -81,6 +81,7 @@ func nameKeys(m map[string]*Name) []string {
type Call struct { type Call struct {
Call *ast.CallExpr Call *ast.CallExpr
Deferred bool Deferred bool
Done bool
} }
// A Ref refers to an expression of the form C.xxx in the AST. // A Ref refers to an expression of the form C.xxx in the AST.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment