Commit 1dc20e91 authored by Alexander Döring's avatar Alexander Döring Committed by Robert Griesemer

math/big: specialize Karatsuba implementation for squaring

Currently we use three different algorithms for squaring:
1. basic multiplication for small numbers
2. basic squaring for medium numbers
3. Karatsuba multiplication for large numbers

Change 3. to a version of Karatsuba multiplication specialized
for x == y.

Increasing the performance of 3. lets us lower the threshold
between 2. and 3.

Adapt TestCalibrate to the change that 3. isn't independent
of the threshold between 1. and 2. any more.

Fixes #23221.

benchstat old.txt new.txt
name           old time/op  new time/op  delta
NatSqr/1-4     29.6ns ± 7%  29.5ns ± 5%     ~     (p=0.103 n=50+50)
NatSqr/2-4     51.9ns ± 1%  51.9ns ± 1%     ~     (p=0.693 n=42+49)
NatSqr/3-4     64.3ns ± 1%  64.1ns ± 0%   -0.26%  (p=0.000 n=46+43)
NatSqr/5-4     93.5ns ± 2%  93.1ns ± 1%   -0.39%  (p=0.000 n=48+49)
NatSqr/8-4      131ns ± 1%   131ns ± 1%     ~     (p=0.870 n=46+49)
NatSqr/10-4     175ns ± 1%   175ns ± 1%   +0.38%  (p=0.000 n=49+47)
NatSqr/20-4     426ns ± 1%   429ns ± 1%   +0.84%  (p=0.000 n=46+48)
NatSqr/30-4     702ns ± 2%   699ns ± 1%   -0.38%  (p=0.011 n=46+44)
NatSqr/50-4    1.44µs ± 2%  1.43µs ± 1%   -0.54%  (p=0.010 n=48+48)
NatSqr/80-4    2.85µs ± 1%  2.87µs ± 1%   +0.68%  (p=0.000 n=47+47)
NatSqr/100-4   4.06µs ± 1%  4.07µs ± 1%   +0.29%  (p=0.000 n=46+45)
NatSqr/200-4   13.4µs ± 1%  13.5µs ± 1%   +0.73%  (p=0.000 n=48+48)
NatSqr/300-4   28.5µs ± 1%  28.2µs ± 1%   -1.22%  (p=0.000 n=46+48)
NatSqr/500-4   81.9µs ± 1%  67.0µs ± 1%  -18.25%  (p=0.000 n=48+48)
NatSqr/800-4    161µs ± 1%   140µs ± 1%  -13.29%  (p=0.000 n=47+48)
NatSqr/1000-4   245µs ± 1%   207µs ± 1%  -15.17%  (p=0.000 n=49+49)

go test -v -calibrate --run TestCalibrate
...
Calibrating threshold between basicSqr(x) and karatsubaSqr(x)
Looking for a timing difference for x between 200 - 500 words by 10 step
words = 200 deltaT =     -980ns (  -7%) is karatsubaSqr(x) better: false
words = 210 deltaT =     -773ns (  -5%) is karatsubaSqr(x) better: false
words = 220 deltaT =     -695ns (  -4%) is karatsubaSqr(x) better: false
words = 230 deltaT =     -570ns (  -3%) is karatsubaSqr(x) better: false
words = 240 deltaT =     -458ns (  -2%) is karatsubaSqr(x) better: false
words = 250 deltaT =      -63ns (   0%) is karatsubaSqr(x) better: false
words = 260 deltaT =      118ns (   0%) is karatsubaSqr(x) better: true  threshold  found
words = 270 deltaT =      377ns (   1%) is karatsubaSqr(x) better: true
words = 280 deltaT =      765ns (   3%) is karatsubaSqr(x) better: true
words = 290 deltaT =      673ns (   2%) is karatsubaSqr(x) better: true
words = 300 deltaT =      502ns (   1%) is karatsubaSqr(x) better: true
words = 310 deltaT =      629ns (   2%) is karatsubaSqr(x) better: true
words = 320 deltaT =    1.011µs (   3%) is karatsubaSqr(x) better: true
words = 330 deltaT =     1.36µs (   4%) is karatsubaSqr(x) better: true
words = 340 deltaT =    3.001µs (   8%) is karatsubaSqr(x) better: true
words = 350 deltaT =    3.178µs (   8%) is karatsubaSqr(x) better: true
...

Change-Id: I6f13c23d94d042539ac28e77fd2618cdc37a429e
Reviewed-on: https://go-review.googlesource.com/105075
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarRobert Griesemer <gri@golang.org>
parent 723f4286
......@@ -28,24 +28,32 @@ import (
var calibrate = flag.Bool("calibrate", false, "run calibration test")
const (
sqrModeMul = "mul(x, x)"
sqrModeBasic = "basicSqr(x)"
sqrModeKaratsuba = "karatsubaSqr(x)"
)
func TestCalibrate(t *testing.T) {
if *calibrate {
computeKaratsubaThresholds()
// compute basicSqrThreshold where overhead becomes negligible
minSqr := computeSqrThreshold(10, 30, 1, 3)
// compute karatsubaSqrThreshold where karatsuba is faster
maxSqr := computeSqrThreshold(300, 500, 10, 3)
if minSqr != 0 {
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
} else {
fmt.Println("no basicSqrThreshold found")
}
if maxSqr != 0 {
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
} else {
fmt.Println("no karatsubaSqrThreshold found")
}
if !*calibrate {
return
}
computeKaratsubaThresholds()
// compute basicSqrThreshold where overhead becomes negligible
minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
// compute karatsubaSqrThreshold where karatsuba is faster
maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
if minSqr != 0 {
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
} else {
fmt.Println("no basicSqrThreshold found")
}
if maxSqr != 0 {
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
} else {
fmt.Println("no karatsubaSqrThreshold found")
}
}
......@@ -109,16 +117,17 @@ func computeKaratsubaThresholds() {
}
}
func measureBasicSqr(words, nruns int, enable bool) time.Duration {
func measureSqr(words, nruns int, mode string) time.Duration {
// more runs for better statistics
initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
if enable {
// set thresholds to use basicSqr at this number of words
switch mode {
case sqrModeMul:
basicSqrThreshold = words + 1
case sqrModeBasic:
basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
} else {
// set thresholds to disable basicSqr for any number of words
basicSqrThreshold, karatsubaSqrThreshold = -1, -1
case sqrModeKaratsuba:
karatsubaSqrThreshold = words - 1
}
var testval int64
......@@ -133,18 +142,18 @@ func measureBasicSqr(words, nruns int, enable bool) time.Duration {
return time.Duration(testval)
}
func computeSqrThreshold(from, to, step, nruns int) int {
fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
var initPos bool
var threshold int
for i := from; i <= to; i += step {
baseline := measureBasicSqr(i, nruns, false)
testval := measureBasicSqr(i, nruns, true)
baseline := measureSqr(i, nruns, lower)
testval := measureSqr(i, nruns, upper)
pos := baseline > testval
delta := baseline - testval
percent := delta * 100 / baseline
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
if i == from {
initPos = pos
}
......
......@@ -388,12 +388,12 @@ func max(x, y int) int {
}
// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of karatsubaThreshold.
func karatsubaLen(n int) int {
// becoming about the value of threshold.
func karatsubaLen(n, threshold int) int {
i := uint(0)
for n > karatsubaThreshold {
for n > threshold {
n >>= 1
i++
}
......@@ -433,7 +433,7 @@ func (z nat) mul(x, y nat) nat {
// y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n)
k := karatsubaLen(n, karatsubaThreshold)
// k <= n
// multiply x0 and y0 via Karatsuba
......@@ -486,8 +486,8 @@ func (z nat) mul(x, y nat) nat {
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) >= 2*len(x)
// The (non-normalized) result is placed in z[0 : 2 * len(x)].
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
func basicSqr(z, x nat) {
n := len(x)
t := make(nat, 2*n) // temporary variable to hold the products
......@@ -503,11 +503,48 @@ func basicSqr(z, x nat) {
addVV(z, z, t) // combine the result
}
// karatsubaSqr squares x and leaves the result in z.
// len(x) must be a power of 2 and len(z) >= 6*len(x).
// The (non-normalized) result is placed in z[0 : 2*len(x)].
//
// The algorithm and the layout of z are the same as for karatsuba.
func karatsubaSqr(z, x nat) {
n := len(x)
if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
z = z[:2*n]
basicSqr(z, x)
return
}
n2 := n >> 1
x1, x0 := x[n2:], x[0:n2]
karatsubaSqr(z, x0)
karatsubaSqr(z[n:], x1)
// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
xd := z[2*n : 2*n+n2]
if subVV(xd, x1, x0) != 0 {
subVV(xd, x0, x1)
}
p := z[n*3:]
karatsubaSqr(p, xd)
r := z[n*4:]
copy(r, z[:n*2])
karatsubaAdd(z[n2:], r, n)
karatsubaAdd(z[n2:], r[n:], n)
karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
}
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
// the Karatsuba algorithm is used.
// we use the Karatsuba algorithm optimized for x == y.
var basicSqrThreshold = 20 // computed by calibrate_test.go
var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
// z = x*x
func (z nat) sqr(x nat) nat {
......@@ -536,7 +573,31 @@ func (z nat) sqr(x nat) nat {
return z.norm()
}
return z.mul(x, x)
// Use Karatsuba multiplication optimized for x == y.
// The algorithm and layout of z are the same as for mul.
// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
k := karatsubaLen(n, karatsubaSqrThreshold)
x0 := x[0:k]
z = z.make(max(6*k, 2*n))
karatsubaSqr(z, x0) // z = x0^2
z = z[0 : 2*n]
z[2*k:].clear()
if k < n {
var t nat
x0 := x0.norm()
x1 := x[k:]
t = t.mul(x0, x1)
addAt(z, t, k)
addAt(z, t, k) // z = 2*x1*x0*b + x0^2
t = t.sqr(x1)
addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
}
return z.norm()
}
// mulRange computes the product of all the unsigned integers in the
......
......@@ -648,26 +648,26 @@ func TestSticky(t *testing.T) {
}
}
func testBasicSqr(t *testing.T, x nat) {
func testSqr(t *testing.T, x nat) {
got := make(nat, 2*len(x))
want := make(nat, 2*len(x))
basicSqr(got, x)
basicMul(want, x, x)
got = got.sqr(x)
want = want.mul(x, x)
if got.cmp(want) != 0 {
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
}
}
func TestBasicSqr(t *testing.T) {
func TestSqr(t *testing.T) {
for _, a := range prodNN {
if a.x != nil {
testBasicSqr(t, a.x)
testSqr(t, a.x)
}
if a.y != nil {
testBasicSqr(t, a.y)
testSqr(t, a.y)
}
if a.z != nil {
testBasicSqr(t, a.z)
testSqr(t, a.z)
}
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment