Commit 30088ac9 authored by Keith Randall's avatar Keith Randall

cmd/compile: make CSE faster

To refine a set of possibly equivalent values, the old CSE algorithm
picked one value, compared it against all the others, and made two sets
out of the results (the values that match the picked value and the
values that didn't).  Unfortunately, this leads to O(n^2) behavior. The
picked value ends up being equal to no other values, we make size 1 and
size n-1 sets, and then recurse on the size n-1 set.

Instead, sort the set by the equivalence classes of its arguments.  Then
we just look for spots in the sorted list where the equivalence classes
of the arguments change.  This lets us do a multi-way split for O(n lg
n) time.

This change makes cmpDepth unnecessary.

The refinement portion used to call the type comparator.  That is
unnecessary as the type was already part of the initial partition.

Lowers time of 16361 from 8 sec to 3 sec.
Lowers time of 15112 from 282 sec to 20 sec. That's kind of unfair, as
CL 30257 changed it from 21 sec to 282 sec. But that CL fixed other bad
compile times (issue #17127) by large factors, so net still a big win.

Fixes #15112
Fixes #16361

Change-Id: I351ce111bae446608968c6d48710eeb6a3d8e527
Reviewed-on: https://go-review.googlesource.com/30354Reviewed-by: default avatarTodd Neal <todd@tneal.org>
parent bd06d482
......@@ -9,10 +9,6 @@ import (
"sort"
)
const (
cmpDepth = 1
)
// cse does common-subexpression elimination on the Function.
// Values are just relinked, nothing is deleted. A subsequent deadcode
// pass is required to actually remove duplicate expressions.
......@@ -60,7 +56,8 @@ func cse(f *Func) {
valueEqClass[v.ID] = -v.ID
}
}
for i, e := range partition {
var pNum ID = 1
for _, e := range partition {
if f.pass.debug > 1 && len(e) > 500 {
fmt.Printf("CSE.large partition (%d): ", len(e))
for j := 0; j < 3; j++ {
......@@ -70,20 +67,22 @@ func cse(f *Func) {
}
for _, v := range e {
valueEqClass[v.ID] = ID(i)
valueEqClass[v.ID] = pNum
}
if f.pass.debug > 2 && len(e) > 1 {
fmt.Printf("CSE.partition #%d:", i)
fmt.Printf("CSE.partition #%d:", pNum)
for _, v := range e {
fmt.Printf(" %s", v.String())
}
fmt.Printf("\n")
}
pNum++
}
// Find an equivalence class where some members of the class have
// non-equivalent arguments. Split the equivalence class appropriately.
// Repeat until we can't find any more splits.
// Split equivalence classes at points where they have
// non-equivalent arguments. Repeat until we can't find any
// more splits.
var splitPoints []int
for {
changed := false
......@@ -91,39 +90,51 @@ func cse(f *Func) {
// we process new additions as they arrive, avoiding O(n^2) behavior.
for i := 0; i < len(partition); i++ {
e := partition[i]
v := e[0]
// all values in this equiv class that are not equivalent to v get moved
// into another equiv class.
// To avoid allocating while building that equivalence class,
// move the values equivalent to v to the beginning of e
// and other values to the end of e.
allvals := e
eqloop:
for j := 1; j < len(e); {
w := e[j]
equivalent := true
for i := 0; i < len(v.Args); i++ {
if valueEqClass[v.Args[i].ID] != valueEqClass[w.Args[i].ID] {
equivalent = false
// Sort by eq class of arguments.
sort.Sort(partitionByArgClass{e, valueEqClass})
// Find split points.
splitPoints = append(splitPoints[:0], 0)
for j := 1; j < len(e); j++ {
v, w := e[j-1], e[j]
eqArgs := true
for k, a := range v.Args {
b := w.Args[k]
if valueEqClass[a.ID] != valueEqClass[b.ID] {
eqArgs = false
break
}
}
if !equivalent || v.Type.Compare(w.Type) != CMPeq {
// w is not equivalent to v.
// move it to the end and shrink e.
e[j], e[len(e)-1] = e[len(e)-1], e[j]
e = e[:len(e)-1]
valueEqClass[w.ID] = ID(len(partition))
changed = true
continue eqloop
if !eqArgs {
splitPoints = append(splitPoints, j)
}
// v and w are equivalent. Keep w in e.
j++
}
partition[i] = e
if len(e) < len(allvals) {
partition = append(partition, allvals[len(e):])
if len(splitPoints) == 1 {
continue // no splits, leave equivalence class alone.
}
// Move another equivalence class down in place of e.
partition[i] = partition[len(partition)-1]
partition = partition[:len(partition)-1]
i--
// Add new equivalence classes for the parts of e we found.
splitPoints = append(splitPoints, len(e))
for j := 0; j < len(splitPoints)-1; j++ {
f := e[splitPoints[j]:splitPoints[j+1]]
if len(f) == 1 {
// Don't add singletons.
valueEqClass[f[0].ID] = -f[0].ID
continue
}
for _, v := range f {
valueEqClass[v.ID] = pNum
}
pNum++
partition = append(partition, f)
}
changed = true
}
if !changed {
......@@ -253,7 +264,7 @@ func partitionValues(a []*Value, auxIDs auxmap) []eqclass {
j := 1
for ; j < len(a); j++ {
w := a[j]
if cmpVal(v, w, auxIDs, cmpDepth) != CMPeq {
if cmpVal(v, w, auxIDs) != CMPeq {
break
}
}
......@@ -274,7 +285,7 @@ func lt2Cmp(isLt bool) Cmp {
type auxmap map[interface{}]int32
func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp {
func cmpVal(v, w *Value, auxIDs auxmap) Cmp {
// Try to order these comparison by cost (cheaper first)
if v.Op != w.Op {
return lt2Cmp(v.Op < w.Op)
......@@ -308,18 +319,6 @@ func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp {
return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux])
}
if depth > 0 {
for i := range v.Args {
if v.Args[i] == w.Args[i] {
// skip comparing equal args
continue
}
if ac := cmpVal(v.Args[i], w.Args[i], auxIDs, depth-1); ac != CMPeq {
return ac
}
}
}
return CMPeq
}
......@@ -334,7 +333,7 @@ func (sv sortvalues) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
func (sv sortvalues) Less(i, j int) bool {
v := sv.a[i]
w := sv.a[j]
if cmp := cmpVal(v, w, sv.auxIDs, cmpDepth); cmp != CMPeq {
if cmp := cmpVal(v, w, sv.auxIDs); cmp != CMPeq {
return cmp == CMPlt
}
......@@ -354,3 +353,25 @@ func (sv partitionByDom) Less(i, j int) bool {
w := sv.a[j]
return sv.sdom.domorder(v.Block) < sv.sdom.domorder(w.Block)
}
type partitionByArgClass struct {
a []*Value // array of values
eqClass []ID // equivalence class IDs of values
}
func (sv partitionByArgClass) Len() int { return len(sv.a) }
func (sv partitionByArgClass) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
func (sv partitionByArgClass) Less(i, j int) bool {
v := sv.a[i]
w := sv.a[j]
for i, a := range v.Args {
b := w.Args[i]
if sv.eqClass[a.ID] < sv.eqClass[b.ID] {
return true
}
if sv.eqClass[a.ID] > sv.eqClass[b.ID] {
return false
}
}
return false
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment