Commit 8c85e230 authored by Josh Bleecher Snyder's avatar Josh Bleecher Snyder

cmd/compile: recognize integer ranges in switch statements

Consider a switch statement like:

switch x {
case 1:
  // ...
case 2, 3, 4, 5, 6:
  // ...
case 5:
  // ...
}

Prior to this CL, the generated code treated
2, 3, 4, 5, and 6 independently in a binary search.
With this CL, the generated code checks whether
2 <= x && x <= 6.
walkinrange then optimizes that range check
into a single unsigned comparison.

Experiments suggest that the best min range size
is 2, using binary size as a proxy for optimization.

Binary sizes before/after this CL:

cmd/compile: 14209728 / 14165360
cmd/go:       9543100 /  9539004

Change-Id: If2f7fb97ca80468fa70351ef540866200c4c996c
Reviewed-on: https://go-review.googlesource.com/26770
Run-TryBot: Josh Bleecher Snyder <josharian@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarMatthew Dempsky <mdempsky@google.com>
parent 5c3edc46
...@@ -880,17 +880,27 @@ func (p *printer) stmtfmt(n *Node) *printer { ...@@ -880,17 +880,27 @@ func (p *printer) stmtfmt(n *Node) *printer {
case OXCASE: case OXCASE:
if n.List.Len() != 0 { if n.List.Len() != 0 {
p.f("case %v: %v", hconv(n.List, FmtComma), n.Nbody) p.f("case %v", hconv(n.List, FmtComma))
} else { } else {
p.f("default: %v", n.Nbody) p.s("default")
} }
p.f(": %v", n.Nbody)
case OCASE: case OCASE:
if n.Left != nil { switch {
p.f("case %v: %v", n.Left, n.Nbody) case n.Left != nil:
} else { // single element
p.f("default: %v", n.Nbody) p.f("case %v", n.Left)
case n.List.Len() > 0:
// range
if n.List.Len() != 2 {
Fatalf("bad OCASE list length", n.List)
}
p.f("case %v..%v", n.List.First(), n.List.Second())
default:
p.s("default")
} }
p.f(": %v", n.Nbody)
case OBREAK, case OBREAK,
OCONTINUE, OCONTINUE,
......
...@@ -16,7 +16,10 @@ const ( ...@@ -16,7 +16,10 @@ const (
switchKindType // switch a.(type) {...} switchKindType // switch a.(type) {...}
) )
const binarySearchMin = 4 // minimum number of cases for binary search const (
binarySearchMin = 4 // minimum number of cases for binary search
integerRangeMin = 2 // minimum size of integer ranges
)
// An exprSwitch walks an expression switch. // An exprSwitch walks an expression switch.
type exprSwitch struct { type exprSwitch struct {
...@@ -291,7 +294,18 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node { ...@@ -291,7 +294,18 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node {
lno := setlineno(n) lno := setlineno(n)
a := Nod(OIF, nil, nil) a := Nod(OIF, nil, nil)
if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE { if rng := n.List.Slice(); rng != nil {
// Integer range.
// exprname is a temp or a constant,
// so it is safe to evaluate twice.
// In most cases, this conjunction will be
// rewritten by walkinrange into a single comparison.
low := Nod(OGE, s.exprname, rng[0])
high := Nod(OLE, s.exprname, rng[1])
a.Left = Nod(OANDAND, low, high)
a.Left = typecheck(a.Left, Erv)
a.Left = walkexpr(a.Left, nil) // give walk the opportunity to optimize the range check
} else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
a.Left = Nod(OEQ, s.exprname, n.Left) // if name == val a.Left = Nod(OEQ, s.exprname, n.Left) // if name == val
a.Left = typecheck(a.Left, Erv) a.Left = typecheck(a.Left, Erv)
} else if s.kind == switchKindTrue { } else if s.kind == switchKindTrue {
...@@ -312,7 +326,13 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node { ...@@ -312,7 +326,13 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node {
// find the middle and recur // find the middle and recur
half := len(cc) / 2 half := len(cc) / 2
a := Nod(OIF, nil, nil) a := Nod(OIF, nil, nil)
mid := cc[half-1].node.Left n := cc[half-1].node
var mid *Node
if rng := n.List.Slice(); rng != nil {
mid = rng[1] // high end of range
} else {
mid = n.Left
}
le := Nod(OLE, s.exprname, mid) le := Nod(OLE, s.exprname, mid)
if Isconst(mid, CTSTR) { if Isconst(mid, CTSTR) {
// Search by length and then by value; see caseClauseByConstVal. // Search by length and then by value; see caseClauseByConstVal.
...@@ -368,9 +388,51 @@ func casebody(sw *Node, typeswvar *Node) { ...@@ -368,9 +388,51 @@ func casebody(sw *Node, typeswvar *Node) {
n.List.Set(nil) n.List.Set(nil)
cas = append(cas, n) cas = append(cas, n)
default: default:
// expand multi-valued cases // Expand multi-valued cases and detect ranges of integer cases.
for _, n1 := range n.List.Slice() { if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
cas = append(cas, Nod(OCASE, n1, jmp)) // Can't use integer ranges. Expand each case into a separate node.
for _, n1 := range n.List.Slice() {
cas = append(cas, Nod(OCASE, n1, jmp))
}
break
}
// Find integer ranges within runs of constants.
s := n.List.Slice()
j := 0
for j < len(s) {
// Find a run of constants.
var run int
for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
}
if run-j >= integerRangeMin {
// Search for integer ranges in s[j:run].
// Typechecking is done, so all values are already in an appropriate range.
search := s[j:run]
sort.Sort(constIntNodesByVal(search))
for beg, end := 0, 1; end <= len(search); end++ {
if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
continue
}
if end-beg >= integerRangeMin {
// Record range in List.
c := Nod(OCASE, nil, jmp)
c.List.Set2(search[beg], search[end-1])
cas = append(cas, c)
} else {
// Not large enough for range; record separately.
for _, n := range search[beg:end] {
cas = append(cas, Nod(OCASE, n, jmp))
}
}
beg = end
}
j = run
}
// Advance to next constant, adding individual non-constant
// or as-yet-unhandled constant cases as we go.
for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
cas = append(cas, Nod(OCASE, s[j], jmp))
}
} }
} }
...@@ -418,7 +480,7 @@ func casebody(sw *Node, typeswvar *Node) { ...@@ -418,7 +480,7 @@ func casebody(sw *Node, typeswvar *Node) {
func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
var cc caseClauses var cc caseClauses
for _, n := range clauses { for _, n := range clauses {
if n.Left == nil { if n.Left == nil && n.List.Len() == 0 {
// default case // default case
if cc.defjmp != nil { if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking") Fatalf("duplicate default case not detected during typechecking")
...@@ -427,6 +489,9 @@ func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { ...@@ -427,6 +489,9 @@ func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
continue continue
} }
c := caseClause{node: n, ordinal: len(cc.list)} c := caseClause{node: n, ordinal: len(cc.list)}
if n.List.Len() > 0 {
c.isconst = true
}
switch consttype(n.Left) { switch consttype(n.Left) {
case CTFLT, CTINT, CTRUNE, CTSTR: case CTFLT, CTINT, CTRUNE, CTSTR:
c.isconst = true c.isconst = true
...@@ -533,14 +598,34 @@ func (s *exprSwitch) checkDupCases(cc []caseClause) { ...@@ -533,14 +598,34 @@ func (s *exprSwitch) checkDupCases(cc []caseClause) {
if ct := consttype(c.node.Left); ct < 0 || ct == CTBOOL { if ct := consttype(c.node.Left); ct < 0 || ct == CTBOOL {
continue continue
} }
val := c.node.Left.Val().Interface() if c.node.Left != nil {
prev, dup := seen[val] // Single constant.
if !dup { val := c.node.Left.Val().Interface()
seen[val] = c.node prev, dup := seen[val]
if !dup {
seen[val] = c.node
continue
}
setlineno(c.node)
Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line())
continue
}
if c.node.List.Len() == 2 {
// Range of integers.
low := c.node.List.Index(0).Int64()
high := c.node.List.Index(1).Int64()
for i := low; i <= high; i++ {
prev, dup := seen[i]
if !dup {
seen[i] = c.node
continue
}
setlineno(c.node)
Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line())
}
continue continue
} }
setlineno(c.node) Fatalf("bad caseClause node in checkDupCases: %v", c.node)
Yyerror("duplicate case %v in switch\n\tprevious case at %v", prev.Left, prev.Line())
} }
return return
} }
...@@ -784,8 +869,24 @@ type caseClauseByConstVal []caseClause ...@@ -784,8 +869,24 @@ type caseClauseByConstVal []caseClause
func (x caseClauseByConstVal) Len() int { return len(x) } func (x caseClauseByConstVal) Len() int { return len(x) }
func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] } func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x caseClauseByConstVal) Less(i, j int) bool { func (x caseClauseByConstVal) Less(i, j int) bool {
v1 := x[i].node.Left.Val().U // n1 and n2 might be individual constants or integer ranges.
v2 := x[j].node.Left.Val().U // We have checked for duplicates already,
// so ranges can be safely represented by any value in the range.
n1 := x[i].node
var v1 interface{}
if s := n1.List.Slice(); s != nil {
v1 = s[0].Val().U
} else {
v1 = n1.Left.Val().U
}
n2 := x[j].node
var v2 interface{}
if s := n2.List.Slice(); s != nil {
v2 = s[0].Val().U
} else {
v2 = n2.Left.Val().U
}
switch v1 := v1.(type) { switch v1 := v1.(type) {
case *Mpflt: case *Mpflt:
...@@ -821,3 +922,11 @@ func (x caseClauseByType) Less(i, j int) bool { ...@@ -821,3 +922,11 @@ func (x caseClauseByType) Less(i, j int) bool {
} }
return c1.ordinal < c2.ordinal return c1.ordinal < c2.ordinal
} }
type constIntNodesByVal []*Node
func (x constIntNodesByVal) Len() int { return len(x) }
func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x constIntNodesByVal) Less(i, j int) bool {
return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
}
...@@ -436,7 +436,7 @@ const ( ...@@ -436,7 +436,7 @@ const (
// statements // statements
OBLOCK // { List } (block of code) OBLOCK // { List } (block of code)
OBREAK // break OBREAK // break
OCASE // case Left: Nbody (select case after processing; Left==nil means default) OCASE // case Left or List[0]..List[1]: Nbody (select case after processing; Left==nil and List==nil means default)
OXCASE // case List: Nbody (select case before processing; List==nil means default) OXCASE // case List: Nbody (select case before processing; List==nil means default)
OCONTINUE // continue OCONTINUE // continue
ODEFER // defer Left (Left must be call) ODEFER // defer Left (Left must be call)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment