Commit 060501dc authored by Alexandru Moșoi's avatar Alexandru Moșoi Committed by Alexandru Moșoi

cmd/compile: constant fold modulo

Fixes #15079

Change-Id: Ib4dd9eab322da39234008e040100e75cb58761b3
Reviewed-on: https://go-review.googlesource.com/21501Reviewed-by: default avatarDavid Chase <drchase@google.com>
Run-TryBot: Alexandru Moșoi <alexandru@mosoi.ro>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 68325b56
...@@ -29,6 +29,11 @@ func b(i uint, j uint) uint { ...@@ -29,6 +29,11 @@ func b(i uint, j uint) uint {
return i / j return i / j
} }
//go:noinline
func c(i int) int {
return 7 / (i - i)
}
func main() { func main() {
if got := checkDivByZero(func() { b(7, 0) }); !got { if got := checkDivByZero(func() { b(7, 0) }); !got {
fmt.Printf("expected div by zero for b(7, 0), got no error\n") fmt.Printf("expected div by zero for b(7, 0), got no error\n")
...@@ -42,6 +47,10 @@ func main() { ...@@ -42,6 +47,10 @@ func main() {
fmt.Printf("expected div by zero for a(4, nil), got no error\n") fmt.Printf("expected div by zero for a(4, nil), got no error\n")
failed = true failed = true
} }
if got := checkDivByZero(func() { c(5) }); !got {
fmt.Printf("expected div by zero for c(5), got no error\n")
failed = true
}
if failed { if failed {
panic("tests failed") panic("tests failed")
......
...@@ -47,7 +47,7 @@ var szs []szD = []szD{ ...@@ -47,7 +47,7 @@ var szs []szD = []szD{
} }
var ops []op = []op{op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"}, var ops []op = []op{op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"},
op{"lsh", "<<"}, op{"rsh", ">>"}} op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"}}
// compute the result of i op j, cast as type t. // compute the result of i op j, cast as type t.
func ansU(i, j uint64, t, op string) string { func ansU(i, j uint64, t, op string) string {
...@@ -63,6 +63,10 @@ func ansU(i, j uint64, t, op string) string { ...@@ -63,6 +63,10 @@ func ansU(i, j uint64, t, op string) string {
if j != 0 { if j != 0 {
ans = i / j ans = i / j
} }
case "%":
if j != 0 {
ans = i % j
}
case "<<": case "<<":
ans = i << j ans = i << j
case ">>": case ">>":
...@@ -93,6 +97,10 @@ func ansS(i, j int64, t, op string) string { ...@@ -93,6 +97,10 @@ func ansS(i, j int64, t, op string) string {
if j != 0 { if j != 0 {
ans = i / j ans = i / j
} }
case "%":
if j != 0 {
ans = i % j
}
case "<<": case "<<":
ans = i << uint64(j) ans = i << uint64(j)
case ">>": case ">>":
...@@ -151,7 +159,7 @@ func main() { ...@@ -151,7 +159,7 @@ func main() {
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1) fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// avoid division by zero // avoid division by zero
if o.name != "div" || i != 0 { if o.name != "mod" && o.name != "div" || i != 0 {
fncCnst1.Execute(w, fd) fncCnst1.Execute(w, fd)
} }
...@@ -170,7 +178,7 @@ func main() { ...@@ -170,7 +178,7 @@ func main() {
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1) fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
// avoid division by zero // avoid division by zero
if o.name != "div" || i != 0 { if o.name != "mod" && o.name != "div" || i != 0 {
fncCnst1.Execute(w, fd) fncCnst1.Execute(w, fd)
} }
fncCnst2.Execute(w, fd) fncCnst2.Execute(w, fd)
...@@ -184,14 +192,14 @@ func main() { ...@@ -184,14 +192,14 @@ func main() {
vrf1, _ := template.New("vrf1").Parse(` vrf1, _ := template.New("vrf1").Parse(`
if got := {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa({{.Input}}); got != {{.Ans}} { if got := {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa({{.Input}}); got != {{.Ans}} {
fmt.Printf("{{.Name}}_{{.Type_}} {{.Number}}{{.Symbol}}{{.Input}} = %d, wanted {{.Ans}}\n",got) fmt.Printf("{{.Name}}_{{.Type_}} {{.Number}}%s{{.Input}} = %d, wanted {{.Ans}}\n", ` + "`{{.Symbol}}`" + `, got)
failed = true failed = true
} }
`) `)
vrf2, _ := template.New("vrf2").Parse(` vrf2, _ := template.New("vrf2").Parse(`
if got := {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa({{.Input}}); got != {{.Ans}} { if got := {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa({{.Input}}); got != {{.Ans}} {
fmt.Printf("{{.Name}}_{{.Type_}} {{.Input}}{{.Symbol}}{{.Number}} = %d, wanted {{.Ans}}\n",got) fmt.Printf("{{.Name}}_{{.Type_}} {{.Input}}%s{{.Number}} = %d, wanted {{.Ans}}\n", ` + "`{{.Symbol}}`" + `, got)
failed = true failed = true
} }
`) `)
...@@ -211,7 +219,7 @@ func main() { ...@@ -211,7 +219,7 @@ func main() {
// unsigned // unsigned
for _, j := range s.u { for _, j := range s.u {
if o.name != "div" || j != 0 { if o.name != "mod" && o.name != "div" || j != 0 {
fd.Ans = ansU(i, j, s.name, o.symbol) fd.Ans = ansU(i, j, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j) fd.Input = fmt.Sprintf("%d", j)
err = vrf1.Execute(w, fd) err = vrf1.Execute(w, fd)
...@@ -220,7 +228,7 @@ func main() { ...@@ -220,7 +228,7 @@ func main() {
} }
} }
if o.name != "div" || i != 0 { if o.name != "mod" && o.name != "div" || i != 0 {
fd.Ans = ansU(j, i, s.name, o.symbol) fd.Ans = ansU(j, i, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j) fd.Input = fmt.Sprintf("%d", j)
err = vrf2.Execute(w, fd) err = vrf2.Execute(w, fd)
...@@ -247,7 +255,7 @@ func main() { ...@@ -247,7 +255,7 @@ func main() {
fd.Number = fmt.Sprintf("%d", i) fd.Number = fmt.Sprintf("%d", i)
fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1) fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
for _, j := range s.i { for _, j := range s.i {
if o.name != "div" || j != 0 { if o.name != "mod" && o.name != "div" || j != 0 {
fd.Ans = ansS(i, j, s.name, o.symbol) fd.Ans = ansS(i, j, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j) fd.Input = fmt.Sprintf("%d", j)
err = vrf1.Execute(w, fd) err = vrf1.Execute(w, fd)
...@@ -256,7 +264,7 @@ func main() { ...@@ -256,7 +264,7 @@ func main() {
} }
} }
if o.name != "div" || i != 0 { if o.name != "mod" && o.name != "div" || i != 0 {
fd.Ans = ansS(j, i, s.name, o.symbol) fd.Ans = ansS(j, i, s.name, o.symbol)
fd.Input = fmt.Sprintf("%d", j) fd.Input = fmt.Sprintf("%d", j)
err = vrf2.Execute(w, fd) err = vrf2.Execute(w, fd)
......
...@@ -66,6 +66,16 @@ ...@@ -66,6 +66,16 @@
(Const32F [f2i(float64(i2f32(c) * i2f32(d)))]) (Const32F [f2i(float64(i2f32(c) * i2f32(d)))])
(Mul64F (Const64F [c]) (Const64F [d])) -> (Const64F [f2i(i2f(c) * i2f(d))]) (Mul64F (Const64F [c]) (Const64F [d])) -> (Const64F [f2i(i2f(c) * i2f(d))])
(Mod8 (Const8 [c]) (Const8 [d])) && d != 0-> (Const8 [int64(int8(c % d))])
(Mod16 (Const16 [c]) (Const16 [d])) && d != 0-> (Const16 [int64(int16(c % d))])
(Mod32 (Const32 [c]) (Const32 [d])) && d != 0-> (Const32 [int64(int32(c % d))])
(Mod64 (Const64 [c]) (Const64 [d])) && d != 0-> (Const64 [c % d])
(Mod8u (Const8 [c]) (Const8 [d])) && d != 0-> (Const8 [int64(uint8(c) % uint8(d))])
(Mod16u (Const16 [c]) (Const16 [d])) && d != 0-> (Const16 [int64(uint16(c) % uint16(d))])
(Mod32u (Const32 [c]) (Const32 [d])) && d != 0-> (Const32 [int64(uint32(c) % uint32(d))])
(Mod64u (Const64 [c]) (Const64 [d])) && d != 0-> (Const64 [int64(uint64(c) % uint64(d))])
(Lsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c << uint64(d)]) (Lsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c << uint64(d)])
(Rsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c >> uint64(d)]) (Rsh64x64 (Const64 [c]) (Const64 [d])) -> (Const64 [c >> uint64(d)])
(Rsh64Ux64 (Const64 [c]) (Const64 [d])) -> (Const64 [int64(uint64(c) >> uint64(d))]) (Rsh64Ux64 (Const64 [c]) (Const64 [d])) -> (Const64 [int64(uint64(c) >> uint64(d))])
...@@ -728,5 +738,5 @@ ...@@ -728,5 +738,5 @@
// A%B = A-(A/B*B). // A%B = A-(A/B*B).
// This implements % with two * and a bunch of ancillary ops. // This implements % with two * and a bunch of ancillary ops.
// One of the * is free if the user's code also computes A/B. // One of the * is free if the user's code also computes A/B.
(Mod64 <t> x (Const64 [c])) && smagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c]))) (Mod64 <t> x (Const64 [c])) && x.Op != OpConst64 && smagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c])))
(Mod64u <t> x (Const64 [c])) && umagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c]))) (Mod64u <t> x (Const64 [c])) && x.Op != OpConst64 && umagic64ok(c) -> (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c])))
...@@ -174,10 +174,22 @@ func rewriteValuegeneric(v *Value, config *Config) bool { ...@@ -174,10 +174,22 @@ func rewriteValuegeneric(v *Value, config *Config) bool {
return rewriteValuegeneric_OpLsh8x64(v, config) return rewriteValuegeneric_OpLsh8x64(v, config)
case OpLsh8x8: case OpLsh8x8:
return rewriteValuegeneric_OpLsh8x8(v, config) return rewriteValuegeneric_OpLsh8x8(v, config)
case OpMod16:
return rewriteValuegeneric_OpMod16(v, config)
case OpMod16u:
return rewriteValuegeneric_OpMod16u(v, config)
case OpMod32:
return rewriteValuegeneric_OpMod32(v, config)
case OpMod32u:
return rewriteValuegeneric_OpMod32u(v, config)
case OpMod64: case OpMod64:
return rewriteValuegeneric_OpMod64(v, config) return rewriteValuegeneric_OpMod64(v, config)
case OpMod64u: case OpMod64u:
return rewriteValuegeneric_OpMod64u(v, config) return rewriteValuegeneric_OpMod64u(v, config)
case OpMod8:
return rewriteValuegeneric_OpMod8(v, config)
case OpMod8u:
return rewriteValuegeneric_OpMod8u(v, config)
case OpMul16: case OpMul16:
return rewriteValuegeneric_OpMul16(v, config) return rewriteValuegeneric_OpMul16(v, config)
case OpMul32: case OpMul32:
...@@ -4409,11 +4421,136 @@ func rewriteValuegeneric_OpLsh8x8(v *Value, config *Config) bool { ...@@ -4409,11 +4421,136 @@ func rewriteValuegeneric_OpLsh8x8(v *Value, config *Config) bool {
} }
return false return false
} }
func rewriteValuegeneric_OpMod16(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod16 (Const16 [c]) (Const16 [d]))
// cond: d != 0
// result: (Const16 [int64(int16(c % d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst16 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst16 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst16)
v.AuxInt = int64(int16(c % d))
return true
}
return false
}
func rewriteValuegeneric_OpMod16u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod16u (Const16 [c]) (Const16 [d]))
// cond: d != 0
// result: (Const16 [int64(uint16(c) % uint16(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst16 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst16 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst16)
v.AuxInt = int64(uint16(c) % uint16(d))
return true
}
return false
}
func rewriteValuegeneric_OpMod32(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod32 (Const32 [c]) (Const32 [d]))
// cond: d != 0
// result: (Const32 [int64(int32(c % d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst32 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst32 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst32)
v.AuxInt = int64(int32(c % d))
return true
}
return false
}
func rewriteValuegeneric_OpMod32u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod32u (Const32 [c]) (Const32 [d]))
// cond: d != 0
// result: (Const32 [int64(uint32(c) % uint32(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst32 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst32 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst32)
v.AuxInt = int64(uint32(c) % uint32(d))
return true
}
return false
}
func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool { func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool {
b := v.Block b := v.Block
_ = b _ = b
// match: (Mod64 (Const64 [c]) (Const64 [d]))
// cond: d != 0
// result: (Const64 [c % d])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst64 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst64 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst64)
v.AuxInt = c % d
return true
}
// match: (Mod64 <t> x (Const64 [c])) // match: (Mod64 <t> x (Const64 [c]))
// cond: smagic64ok(c) // cond: x.Op != OpConst64 && smagic64ok(c)
// result: (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c]))) // result: (Sub64 x (Mul64 <t> (Div64 <t> x (Const64 <t> [c])) (Const64 <t> [c])))
for { for {
t := v.Type t := v.Type
...@@ -4423,7 +4560,7 @@ func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool { ...@@ -4423,7 +4560,7 @@ func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool {
break break
} }
c := v_1.AuxInt c := v_1.AuxInt
if !(smagic64ok(c)) { if !(x.Op != OpConst64 && smagic64ok(c)) {
break break
} }
v.reset(OpSub64) v.reset(OpSub64)
...@@ -4446,6 +4583,27 @@ func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool { ...@@ -4446,6 +4583,27 @@ func rewriteValuegeneric_OpMod64(v *Value, config *Config) bool {
func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool { func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
b := v.Block b := v.Block
_ = b _ = b
// match: (Mod64u (Const64 [c]) (Const64 [d]))
// cond: d != 0
// result: (Const64 [int64(uint64(c) % uint64(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst64 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst64 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst64)
v.AuxInt = int64(uint64(c) % uint64(d))
return true
}
// match: (Mod64u <t> n (Const64 [c])) // match: (Mod64u <t> n (Const64 [c]))
// cond: isPowerOfTwo(c) // cond: isPowerOfTwo(c)
// result: (And64 n (Const64 <t> [c-1])) // result: (And64 n (Const64 <t> [c-1]))
...@@ -4468,7 +4626,7 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool { ...@@ -4468,7 +4626,7 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
return true return true
} }
// match: (Mod64u <t> x (Const64 [c])) // match: (Mod64u <t> x (Const64 [c]))
// cond: umagic64ok(c) // cond: x.Op != OpConst64 && umagic64ok(c)
// result: (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c]))) // result: (Sub64 x (Mul64 <t> (Div64u <t> x (Const64 <t> [c])) (Const64 <t> [c])))
for { for {
t := v.Type t := v.Type
...@@ -4478,7 +4636,7 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool { ...@@ -4478,7 +4636,7 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
break break
} }
c := v_1.AuxInt c := v_1.AuxInt
if !(umagic64ok(c)) { if !(x.Op != OpConst64 && umagic64ok(c)) {
break break
} }
v.reset(OpSub64) v.reset(OpSub64)
...@@ -4498,6 +4656,58 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool { ...@@ -4498,6 +4656,58 @@ func rewriteValuegeneric_OpMod64u(v *Value, config *Config) bool {
} }
return false return false
} }
func rewriteValuegeneric_OpMod8(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod8 (Const8 [c]) (Const8 [d]))
// cond: d != 0
// result: (Const8 [int64(int8(c % d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst8 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst8 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst8)
v.AuxInt = int64(int8(c % d))
return true
}
return false
}
func rewriteValuegeneric_OpMod8u(v *Value, config *Config) bool {
b := v.Block
_ = b
// match: (Mod8u (Const8 [c]) (Const8 [d]))
// cond: d != 0
// result: (Const8 [int64(uint8(c) % uint8(d))])
for {
v_0 := v.Args[0]
if v_0.Op != OpConst8 {
break
}
c := v_0.AuxInt
v_1 := v.Args[1]
if v_1.Op != OpConst8 {
break
}
d := v_1.AuxInt
if !(d != 0) {
break
}
v.reset(OpConst8)
v.AuxInt = int64(uint8(c) % uint8(d))
return true
}
return false
}
func rewriteValuegeneric_OpMul16(v *Value, config *Config) bool { func rewriteValuegeneric_OpMul16(v *Value, config *Config) bool {
b := v.Block b := v.Block
_ = b _ = b
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment