Commit 2c0c68d6 authored by Matthew Dempsky's avatar Matthew Dempsky

cmd/compile: fix miscompilation of "defer delete(m, k)"

Previously, for slow map key types (i.e., any type other than a 32-bit
or 64-bit plain memory type), we would rewrite

    defer delete(m, k)

into

    ktmp := k
    defer delete(m, &ktmp)

However, if the defer statement was inside a loop, we would end up
reusing the same ktmp value for all of the deferred deletes.

We already rewrite

    defer print(x, y, z)

into

    defer func(a1, a2, a3) {
        print(a1, a2, a3)
    }(x, y, z)

This CL generalizes this rewrite to also apply for slow map deletes.

This could be extended to apply even more generally to other builtins,
but as discussed on #24259, there are cases where we must *not* do
this (e.g., "defer recover()"). However, if we elect to do this more
generally, this CL should still make that easier.

Lastly, while here, fix a few isues in wrapCall (nee walkprintfunc):

1) lookupN appends the generation number to the symbol anyway, so "%d"
was being literally included in the generated function names.

2) walkstmt will be called when the function is compiled later anyway,
so no need to do it now.

Fixes #24259.

Change-Id: I70286867c64c69c18e9552f69e3f4154a0fc8b04
Reviewed-on: https://go-review.googlesource.com/99017
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 558769a6
......@@ -618,25 +618,7 @@ func (o *Order) stmt(n *Node) {
// Special: order arguments to inner call but not call itself.
case ODEFER, OPROC:
t := o.markTemp()
switch n.Left.Op {
// Delete will take the address of the key.
// Copy key into new temp and do not clean it
// (it persists beyond the statement).
case ODELETE:
o.exprList(n.Left.List)
if mapfast(n.Left.List.First().Type) == mapslow {
t1 := o.markTemp()
np := n.Left.List.Addr(1) // map key
*np = o.copyExpr(*np, (*np).Type, false)
o.popTemp(t1)
}
default:
o.call(n.Left)
}
o.call(n.Left)
o.out = append(o.out, n)
o.cleanTemp(t)
......
......@@ -242,9 +242,18 @@ func walkstmt(n *Node) *Node {
case ODEFER:
Curfn.Func.SetHasDefer(true)
fallthrough
case OPROC:
switch n.Left.Op {
case OPRINT, OPRINTN:
n.Left = walkprintfunc(n.Left, &n.Ninit)
n.Left = wrapCall(n.Left, &n.Ninit)
case ODELETE:
if mapfast(n.Left.List.First().Type) == mapslow {
n.Left = wrapCall(n.Left, &n.Ninit)
} else {
n.Left = walkexpr(n.Left, &n.Ninit)
}
case OCOPY:
n.Left = copyany(n.Left, &n.Ninit, true)
......@@ -273,21 +282,6 @@ func walkstmt(n *Node) *Node {
walkstmtlist(n.Nbody.Slice())
walkstmtlist(n.Rlist.Slice())
case OPROC:
switch n.Left.Op {
case OPRINT, OPRINTN:
n.Left = walkprintfunc(n.Left, &n.Ninit)
case OCOPY:
n.Left = copyany(n.Left, &n.Ninit, true)
default:
n.Left = walkexpr(n.Left, &n.Ninit)
}
// make room for size & fn arguments.
adjustargs(n, 2*Widthptr)
case ORETURN:
walkexprlist(n.List.Slice(), &n.Ninit)
if n.List.Len() == 0 {
......@@ -3847,45 +3841,43 @@ func candiscard(n *Node) bool {
return true
}
// rewrite
// print(x, y, z)
// Rewrite
// go builtin(x, y, z)
// into
// func(a1, a2, a3) {
// print(a1, a2, a3)
// go func(a1, a2, a3) {
// builtin(a1, a2, a3)
// }(x, y, z)
// and same for println.
// for print, println, and delete.
var walkprintfunc_prgen int
var wrapCall_prgen int
// The result of walkprintfunc MUST be assigned back to n, e.g.
// n.Left = walkprintfunc(n.Left, init)
func walkprintfunc(n *Node, init *Nodes) *Node {
// The result of wrapCall MUST be assigned back to n, e.g.
// n.Left = wrapCall(n.Left, init)
func wrapCall(n *Node, init *Nodes) *Node {
if n.Ninit.Len() != 0 {
walkstmtlist(n.Ninit.Slice())
init.AppendNodes(&n.Ninit)
}
t := nod(OTFUNC, nil, nil)
var printargs []*Node
for i, n1 := range n.List.Slice() {
var args []*Node
for i, arg := range n.List.Slice() {
buf := fmt.Sprintf("a%d", i)
a := namedfield(buf, n1.Type)
a := namedfield(buf, arg.Type)
t.List.Append(a)
printargs = append(printargs, a.Left)
args = append(args, a.Left)
}
oldfn := Curfn
Curfn = nil
walkprintfunc_prgen++
sym := lookupN("print·%d", walkprintfunc_prgen)
wrapCall_prgen++
sym := lookupN("wrap·", wrapCall_prgen)
fn := dclfunc(sym, t)
a := nod(n.Op, nil, nil)
a.List.Set(printargs)
a.List.Set(args)
a = typecheck(a, Etop)
a = walkstmt(a)
fn.Nbody.Set1(a)
funcbody()
......
......@@ -875,3 +875,24 @@ func BenchmarkMapDelete(b *testing.B) {
b.Run("Int64", runWith(benchmarkMapDeleteInt64, 100, 1000, 10000))
b.Run("Str", runWith(benchmarkMapDeleteStr, 100, 1000, 10000))
}
func TestDeferDeleteSlow(t *testing.T) {
ks := []complex128{0, 1, 2, 3}
m := make(map[interface{}]int)
for i, k := range ks {
m[k] = i
}
if len(m) != len(ks) {
t.Errorf("want %d elements, got %d", len(ks), len(m))
}
func() {
for _, k := range ks {
defer delete(m, k)
}
}()
if len(m) != 0 {
t.Errorf("want 0 elements, got %d", len(m))
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment