Commit a3b3284d authored by Matthew Dempsky's avatar Matthew Dempsky

cmd/compile: prevent untyped types from reaching walk

We already require expressions to have already been typechecked before
reaching walk. Moreover, all untyped expressions should have been
converted to their default type by walk.

However, in practice, we've been somewhat sloppy and inconsistent
about ensuring this. In particular, a lot of AST rewrites ended up
leaving untyped bool expressions scattered around. These likely aren't
harmful in practice, but it seems worth cleaning up.

The two most common cases addressed by this CL are:

1) When generating OIF and OFOR nodes, we would often typecheck the
conditional expression, but not apply defaultlit to force it to the
expression's default type.

2) When rewriting string comparisons into more fundamental primitives,
we were simply overwriting r.Type with the desired type, which didn't
propagate the type to nested subexpressions. These are fixed by
utilizing finishcompare, which correctly handles this (and is already
used by other comparison lowering rewrites).

Lastly, walkexpr is extended to assert that it's not called on untyped
expressions.

Fixes #23834.

Change-Id: Icbd29648a293555e4015d3b06a95a24ccbd3f790
Reviewed-on: https://go-review.googlesource.com/98337Reviewed-by: default avatarRobert Griesemer <gri@golang.org>
parent ed8b7a77
...@@ -434,6 +434,7 @@ func walkrange(n *Node) *Node { ...@@ -434,6 +434,7 @@ func walkrange(n *Node) *Node {
typecheckslice(n.Left.Ninit.Slice(), Etop) typecheckslice(n.Left.Ninit.Slice(), Etop)
n.Left = typecheck(n.Left, Erv) n.Left = typecheck(n.Left, Erv)
n.Left = defaultlit(n.Left, nil)
n.Right = typecheck(n.Right, Etop) n.Right = typecheck(n.Right, Etop)
typecheckslice(body, Etop) typecheckslice(body, Etop)
n.Nbody.Prepend(body...) n.Nbody.Prepend(body...)
...@@ -529,6 +530,7 @@ func memclrrange(n, v1, v2, a *Node) bool { ...@@ -529,6 +530,7 @@ func memclrrange(n, v1, v2, a *Node) bool {
n.Nbody.Append(v1) n.Nbody.Append(v1)
n.Left = typecheck(n.Left, Erv) n.Left = typecheck(n.Left, Erv)
n.Left = defaultlit(n.Left, nil)
typecheckslice(n.Nbody.Slice(), Etop) typecheckslice(n.Nbody.Slice(), Etop)
n = walkstmt(n) n = walkstmt(n)
return true return true
......
...@@ -308,6 +308,7 @@ func walkselectcases(cases *Nodes) []*Node { ...@@ -308,6 +308,7 @@ func walkselectcases(cases *Nodes) []*Node {
cond := nod(OEQ, chosen, nodintconst(int64(i))) cond := nod(OEQ, chosen, nodintconst(int64(i)))
cond = typecheck(cond, Erv) cond = typecheck(cond, Erv)
cond = defaultlit(cond, nil)
r = nod(OIF, cond, nil) r = nod(OIF, cond, nil)
r.Nbody.AppendNodes(&cas.Nbody) r.Nbody.AppendNodes(&cas.Nbody)
......
...@@ -217,6 +217,7 @@ func walkswitch(sw *Node) { ...@@ -217,6 +217,7 @@ func walkswitch(sw *Node) {
if sw.Left == nil { if sw.Left == nil {
sw.Left = nodbool(true) sw.Left = nodbool(true)
sw.Left = typecheck(sw.Left, Erv) sw.Left = typecheck(sw.Left, Erv)
sw.Left = defaultlit(sw.Left, nil)
} }
if sw.Left.Op == OTYPESW { if sw.Left.Op == OTYPESW {
...@@ -314,21 +315,16 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node { ...@@ -314,21 +315,16 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node {
low := nod(OGE, s.exprname, rng[0]) low := nod(OGE, s.exprname, rng[0])
high := nod(OLE, s.exprname, rng[1]) high := nod(OLE, s.exprname, rng[1])
a.Left = nod(OANDAND, low, high) a.Left = nod(OANDAND, low, high)
a.Left = typecheck(a.Left, Erv)
a.Left = defaultlit(a.Left, nil)
a.Left = walkexpr(a.Left, nil) // give walk the opportunity to optimize the range check
} else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE { } else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
a.Left = nod(OEQ, s.exprname, n.Left) // if name == val a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
a.Left = typecheck(a.Left, Erv)
a.Left = defaultlit(a.Left, nil)
} else if s.kind == switchKindTrue { } else if s.kind == switchKindTrue {
a.Left = n.Left // if val a.Left = n.Left // if val
} else { } else {
// s.kind == switchKindFalse // s.kind == switchKindFalse
a.Left = nod(ONOT, n.Left, nil) // if !val a.Left = nod(ONOT, n.Left, nil) // if !val
a.Left = typecheck(a.Left, Erv)
a.Left = defaultlit(a.Left, nil)
} }
a.Left = typecheck(a.Left, Erv)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right) // goto l a.Nbody.Set1(n.Right) // goto l
cas = append(cas, a) cas = append(cas, a)
...@@ -750,6 +746,7 @@ func (s *typeSwitch) walk(sw *Node) { ...@@ -750,6 +746,7 @@ func (s *typeSwitch) walk(sw *Node) {
def = blk def = blk
} }
i.Left = typecheck(i.Left, Erv) i.Left = typecheck(i.Left, Erv)
i.Left = defaultlit(i.Left, nil)
cas = append(cas, i) cas = append(cas, i)
// Load hash from type or itab. // Load hash from type or itab.
...@@ -869,6 +866,7 @@ func (s *typeSwitch) walkCases(cc []caseClause) *Node { ...@@ -869,6 +866,7 @@ func (s *typeSwitch) walkCases(cc []caseClause) *Node {
a := nod(OIF, nil, nil) a := nod(OIF, nil, nil)
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash))) a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
a.Left = typecheck(a.Left, Erv) a.Left = typecheck(a.Left, Erv)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right) a.Nbody.Set1(n.Right)
cas = append(cas, a) cas = append(cas, a)
} }
...@@ -880,6 +878,7 @@ func (s *typeSwitch) walkCases(cc []caseClause) *Node { ...@@ -880,6 +878,7 @@ func (s *typeSwitch) walkCases(cc []caseClause) *Node {
a := nod(OIF, nil, nil) a := nod(OIF, nil, nil)
a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash))) a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
a.Left = typecheck(a.Left, Erv) a.Left = typecheck(a.Left, Erv)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half])) a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:])) a.Rlist.Set1(s.walkCases(cc[half:]))
return a return a
......
...@@ -476,6 +476,10 @@ func walkexpr(n *Node, init *Nodes) *Node { ...@@ -476,6 +476,10 @@ func walkexpr(n *Node, init *Nodes) *Node {
Fatalf("missed typecheck: %+v", n) Fatalf("missed typecheck: %+v", n)
} }
if n.Type.IsUntyped() {
Fatalf("expression has untyped type: %+v", n)
}
if n.Op == ONAME && n.Class() == PAUTOHEAP { if n.Op == ONAME && n.Class() == PAUTOHEAP {
nn := nod(OIND, n.Name.Param.Heapaddr, nil) nn := nod(OIND, n.Name.Param.Heapaddr, nil)
nn = typecheck(nn, Erv) nn = typecheck(nn, Erv)
...@@ -1234,10 +1238,7 @@ opswitch: ...@@ -1234,10 +1238,7 @@ opswitch:
if (Op(n.Etype) == OEQ || Op(n.Etype) == ONE) && Isconst(n.Right, CTSTR) && n.Left.Op == OADDSTR && n.Left.List.Len() == 2 && Isconst(n.Left.List.Second(), CTSTR) && strlit(n.Right) == strlit(n.Left.List.Second()) { if (Op(n.Etype) == OEQ || Op(n.Etype) == ONE) && Isconst(n.Right, CTSTR) && n.Left.Op == OADDSTR && n.Left.List.Len() == 2 && Isconst(n.Left.List.Second(), CTSTR) && strlit(n.Right) == strlit(n.Left.List.Second()) {
// TODO(marvin): Fix Node.EType type union. // TODO(marvin): Fix Node.EType type union.
r := nod(Op(n.Etype), nod(OLEN, n.Left.List.First(), nil), nodintconst(0)) r := nod(Op(n.Etype), nod(OLEN, n.Left.List.First(), nil), nodintconst(0))
r = typecheck(r, Erv) n = finishcompare(n, r, init)
r = walkexpr(r, init)
r.Type = n.Type
n = r
break break
} }
...@@ -1337,10 +1338,7 @@ opswitch: ...@@ -1337,10 +1338,7 @@ opswitch:
remains -= step remains -= step
i += step i += step
} }
r = typecheck(r, Erv) n = finishcompare(n, r, init)
r = walkexpr(r, init)
r.Type = n.Type
n = r
break break
} }
} }
...@@ -1374,9 +1372,6 @@ opswitch: ...@@ -1374,9 +1372,6 @@ opswitch:
r = nod(ONOT, r, nil) r = nod(ONOT, r, nil)
r = nod(OOROR, nod(ONE, llen, rlen), r) r = nod(OOROR, nod(ONE, llen, rlen), r)
} }
r = typecheck(r, Erv)
r = walkexpr(r, nil)
} else { } else {
// sys_cmpstring(s1, s2) :: 0 // sys_cmpstring(s1, s2) :: 0
r = mkcall("cmpstring", types.Types[TINT], init, conv(n.Left, types.Types[TSTRING]), conv(n.Right, types.Types[TSTRING])) r = mkcall("cmpstring", types.Types[TINT], init, conv(n.Left, types.Types[TSTRING]), conv(n.Right, types.Types[TSTRING]))
...@@ -1384,12 +1379,7 @@ opswitch: ...@@ -1384,12 +1379,7 @@ opswitch:
r = nod(Op(n.Etype), r, nodintconst(0)) r = nod(Op(n.Etype), r, nodintconst(0))
} }
r = typecheck(r, Erv) n = finishcompare(n, r, init)
if !n.Type.IsBoolean() {
Fatalf("cmp %v", n.Type)
}
r.Type = n.Type
n = r
case OADDSTR: case OADDSTR:
n = addstr(n, init) n = addstr(n, init)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment