Commit 1d18f66d authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo: write a string rather than building an AST

This generates the same code as before, but does so directly rather
than building an AST and printing that. This is in preparation for
later changes.

Change-Id: Ifec141120bcc74847f0bff8d3d47306bfe69b454
Reviewed-on: https://go-review.googlesource.com/c/142883
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent af951994
...@@ -744,16 +744,19 @@ func (p *Package) rewriteCalls(f *File) bool { ...@@ -744,16 +744,19 @@ func (p *Package) rewriteCalls(f *File) bool {
// argument and then calls the original function. // argument and then calls the original function.
// This returns whether the package needs to import unsafe as _cgo_unsafe. // This returns whether the package needs to import unsafe as _cgo_unsafe.
func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
params := name.FuncType.Params
args := call.Call.Args
// Avoid a crash if the number of arguments is // Avoid a crash if the number of arguments is
// less than the number of parameters. // less than the number of parameters.
// This will be caught when the generated file is compiled. // This will be caught when the generated file is compiled.
if len(call.Call.Args) < len(name.FuncType.Params) { if len(args) < len(params) {
return false return false
} }
any := false any := false
for i, param := range name.FuncType.Params { for i, param := range params {
if p.needsPointerCheck(f, param.Go, call.Call.Args[i]) { if p.needsPointerCheck(f, param.Go, args[i]) {
any = true any = true
break break
} }
...@@ -772,127 +775,108 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool { ...@@ -772,127 +775,108 @@ func (p *Package) rewriteCall(f *File, call *Call, name *Name) bool {
// Using a function literal like this lets us do correct // Using a function literal like this lets us do correct
// argument type checking, and works correctly if the call is // argument type checking, and works correctly if the call is
// deferred. // deferred.
var sb bytes.Buffer
sb.WriteString("func(")
needsUnsafe := false needsUnsafe := false
params := make([]*ast.Field, len(name.FuncType.Params))
nargs := make([]ast.Expr, len(name.FuncType.Params)) for i, param := range params {
var stmts []ast.Stmt if i > 0 {
for i, param := range name.FuncType.Params { sb.WriteString(", ")
// params is going to become the parameters of the }
// function literal.
// nargs is going to become the list of arguments made fmt.Fprintf(&sb, "_cgo%d ", i)
// by the call within the function literal.
// nparam is the parameter of the function literal that
// corresponds to param.
origArg := call.Call.Args[i]
nparam := ast.NewIdent(fmt.Sprintf("_cgo%d", i))
nargs[i] = nparam
// The Go version of the C type might use unsafe.Pointer,
// but the file might not import unsafe.
// Rewrite the Go type if necessary to use _cgo_unsafe.
ptype := p.rewriteUnsafe(param.Go) ptype := p.rewriteUnsafe(param.Go)
if ptype != param.Go { if ptype != param.Go {
needsUnsafe = true needsUnsafe = true
} }
sb.WriteString(gofmtLine(ptype))
}
params[i] = &ast.Field{ sb.WriteString(")")
Names: []*ast.Ident{nparam},
Type: ptype,
}
if !p.needsPointerCheck(f, param.Go, origArg) {
continue
}
// Run the cgo pointer checks on nparam. result := false
twoResults := false
// Change the function literal to call the real function // Check whether this call expects two results.
// with the parameter passed through _cgoCheckPointer. for _, ref := range f.Ref {
c := &ast.CallExpr{ if ref.Expr != &call.Call.Fun {
Fun: ast.NewIdent("_cgoCheckPointer"), continue
Args: []ast.Expr{
nparam,
},
} }
if ref.Context == ctxCall2 {
// Add optional additional arguments for an address sb.WriteString(" (")
// expression. result = true
c.Args = p.checkAddrArgs(f, c.Args, origArg) twoResults = true
stmt := &ast.ExprStmt{
X: c,
} }
stmts = append(stmts, stmt) break
} }
const cgoMarker = "__cgo__###__marker__" // Add the result type, if any.
fcall := &ast.CallExpr{
Fun: ast.NewIdent(cgoMarker),
Args: nargs,
}
ftype := &ast.FuncType{
Params: &ast.FieldList{
List: params,
},
}
if name.FuncType.Result != nil { if name.FuncType.Result != nil {
rtype := p.rewriteUnsafe(name.FuncType.Result.Go) rtype := p.rewriteUnsafe(name.FuncType.Result.Go)
if rtype != name.FuncType.Result.Go { if rtype != name.FuncType.Result.Go {
needsUnsafe = true needsUnsafe = true
} }
ftype.Results = &ast.FieldList{ if !twoResults {
List: []*ast.Field{ sb.WriteString(" ")
&ast.Field{
Type: rtype,
},
},
} }
sb.WriteString(gofmtLine(rtype))
result = true
} }
// If this call expects two results, we have to // Add the second result type, if any.
// adjust the results of the function we generated. if twoResults {
for _, ref := range f.Ref { if name.FuncType.Result == nil {
if ref.Expr == &call.Call.Fun && ref.Context == ctxCall2 { // An explicit void result looks odd but it
if ftype.Results == nil { // seems to be how cgo has worked historically.
// An explicit void argument sb.WriteString("_Ctype_void")
// looks odd but it seems to
// be how cgo has worked historically.
ftype.Results = &ast.FieldList{
List: []*ast.Field{
&ast.Field{
Type: ast.NewIdent("_Ctype_void"),
},
},
}
}
ftype.Results.List = append(ftype.Results.List,
&ast.Field{
Type: ast.NewIdent("error"),
})
} }
sb.WriteString(", error)")
} }
var fbody ast.Stmt sb.WriteString(" { ")
if ftype.Results == nil {
fbody = &ast.ExprStmt{ for i, param := range params {
X: fcall, arg := args[i]
if !p.needsPointerCheck(f, param.Go, arg) {
continue
} }
} else {
fbody = &ast.ReturnStmt{ // Check for &a[i].
Results: []ast.Expr{fcall}, if p.checkIndex(&sb, f, arg, i) {
continue
}
// Check for &x.
if p.checkAddr(&sb, arg, i) {
continue
} }
fmt.Fprintf(&sb, "_cgoCheckPointer(_cgo%d); ", i)
} }
lit := &ast.FuncLit{
Type: ftype, if result {
Body: &ast.BlockStmt{ sb.WriteString("return ")
List: append(stmts, fbody),
},
} }
text := strings.Replace(gofmt(lit), "\n", ";", -1)
repl := strings.Split(text, cgoMarker) // Now we are ready to call the C function.
f.Edit.Insert(f.offset(call.Call.Fun.Pos()), repl[0]) // To work smoothly with rewriteRef we leave the call in place
f.Edit.Insert(f.offset(call.Call.Fun.End()), repl[1]) // and just insert our new arguments between the function
// and the old arguments.
f.Edit.Insert(f.offset(call.Call.Fun.Pos()), sb.String())
sb.Reset()
sb.WriteString("(")
for i := range params {
if i > 0 {
sb.WriteString(", ")
}
fmt.Fprintf(&sb, "_cgo%d", i)
}
sb.WriteString("); }")
f.Edit.Insert(f.offset(call.Call.Lparen), sb.String())
return needsUnsafe return needsUnsafe
} }
...@@ -1001,19 +985,13 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool { ...@@ -1001,19 +985,13 @@ func (p *Package) hasPointer(f *File, t ast.Expr, top bool) bool {
} }
} }
// checkAddrArgs tries to add arguments to the call of // checkIndex checks whether arg the form &a[i], possibly inside type
// _cgoCheckPointer when the argument is an address expression. We // conversions. If so, and if a has no side effects, it writes
// pass true to mean that the argument is an address operation of // _cgoCheckPointer(_cgoNN, a) to sb and returns true. This tells
// something other than a slice index, which means that it's only // _cgoCheckPointer to check the complete contents of the slice.
// necessary to check the specific element pointed to, not the entire func (p *Package) checkIndex(sb *bytes.Buffer, f *File, arg ast.Expr, i int) bool {
// object. This is for &s.f, where f is a field in a struct. We can
// pass a slice or array, meaning that we should check the entire
// slice or array but need not check any other part of the object.
// This is for &s.a[i], where we need to check all of a. However, we
// only pass the slice or array if we can refer to it without side
// effects.
func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr {
// Strip type conversions. // Strip type conversions.
x := arg
for { for {
c, ok := x.(*ast.CallExpr) c, ok := x.(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) { if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
...@@ -1023,22 +1001,46 @@ func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr ...@@ -1023,22 +1001,46 @@ func (p *Package) checkAddrArgs(f *File, args []ast.Expr, x ast.Expr) []ast.Expr
} }
u, ok := x.(*ast.UnaryExpr) u, ok := x.(*ast.UnaryExpr)
if !ok || u.Op != token.AND { if !ok || u.Op != token.AND {
return args return false
} }
index, ok := u.X.(*ast.IndexExpr) index, ok := u.X.(*ast.IndexExpr)
if !ok { if !ok {
// This is the address of something that is not an return false
// index expression. We only need to examine the }
// single value to which it points. if p.hasSideEffects(f, index.X) {
// TODO: what if true is shadowed? return false
return append(args, ast.NewIdent("true")) }
}
if !p.hasSideEffects(f, index.X) { fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, %s); ", i, gofmtLine(index.X))
// Examine the entire slice.
return append(args, index.X) return true
} }
// Treat the pointer as unknown.
return args // checkAddr checks whether arg has the form &x, possibly inside type
// conversions. If so it writes _cgoCheckPointer(_cgoNN, true) to sb
// and returns true. This tells _cgoCheckPointer to check just the
// contents of the pointer being passed, not any other part of the
// memory allocation. This is run after checkIndex, which looks for
// the special case of &a[i], which requires different checks.
func (p *Package) checkAddr(sb *bytes.Buffer, arg ast.Expr, i int) bool {
// Strip type conversions.
px := &arg
for {
c, ok := (*px).(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
break
}
px = &c.Args[0]
}
if u, ok := (*px).(*ast.UnaryExpr); !ok || u.Op != token.AND {
return false
}
// Use "0 == 0" to do the right thing in the unlikely event
// that "true" is shadowed.
fmt.Fprintf(sb, "_cgoCheckPointer(_cgo%d, 0 == 0); ", i)
return true
} }
// hasSideEffects returns whether the expression x has any side // hasSideEffects returns whether the expression x has any side
......
...@@ -126,3 +126,9 @@ func gofmt(n interface{}) string { ...@@ -126,3 +126,9 @@ func gofmt(n interface{}) string {
} }
return gofmtBuf.String() return gofmtBuf.String()
} }
// gofmtLine returns the gofmt-formatted string for an AST node,
// ensuring that it is on a single line.
func gofmtLine(n interface{}) string {
return strings.Replace(gofmt(n), "\n", ";", -1)
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment