大整數乘法中的分治思想(TOOM-COOK的一種使用方法)
演算法分析與設計學習中,接觸到一道大整數乘法問題,分享出來,原題目如下:
演算法分析在用分治法求兩個n位大整數u和v的乘積時,將u和v都分割為長度為n/3的3段。證明可以用5次n/3位整數的乘法求得uv的值。按此思想設計大整數乘積的分治方法,並分析演算法的計算複雜性。
先參考一道較為簡單的題目:設有兩個n位二進位制數X,Y,求它們的乘積XY。 分析:按照一般演算法,根據小學數學乘法規律,兩個數中每位數都需要相應做乘法,則需要的時間複雜度是O(n^2)。
此時,我們考慮將X,Y分為高位、低位,即 X = | A| B | , Y = | C| D| A,B,C,D均為n/2位,求XY的問題可以轉換為:
X * Y = (A * 2^(n/2) +B ) *(C * 2^(n/2) +D )
= AC * 2^n + (AD + BC) * 2^(n/2)+BD
則,可以看出利用分治法後,進行了4次n/2位的乘法運算。但是並沒有涉及演算法效能優化,時間複雜度依然是O(n^2)。
這個演算法進行效能提升可以使用如下方法: 先計算
U = (A + B)(C + D), V = AC, W = BD
則 Z = XY = V * 2^n +(U - V- W ) * 2^(n/2) + W
上面過程中,由於只使用了U、V、W涉及的3次乘法,比沒有優化的少了一種,時間複雜度就降低到了
Karatsuba演算法虛擬碼實現如下
procedure karatsuba(num1, num2) if (num1 < 10) or (num2 < 10) return num1*num2 /* calculates the size of the numbers */ m = max(size_base10(num1), size_base10(num2)) m2 = m/2 /* split the digit sequences about the middle */ high1, low1 = split_at(num1, m2) high2, low2 = split_at(num2, m2) /* 3 calls made to numbers approximately half the size */ z0 = karatsuba(low1,low2) z1 = karatsuba((low1+high1),(low2+high2)) z2 = karatsuba(high1,high2) return (z2*10^(2*m2))+((z1-z2-z0)*10^(m2))+(z0)
具體例子可以參照這個圖片
那參照上述步驟,在解決最開始提到的5次n/3位整數乘法中也可以設法將兩個大整數U、V分割為
U = |A| B| C V =|D| E| F
然後利用(A +B +C )(D +E +F)的方法,在分別細分合並其中幾項,來簡化複雜度,以求達到5次乘法的要求。
PS,這個方法是一開始我的想法,實現後發現只能簡化到6次乘法實現,所以最後尋找別的演算法。
最後發現有TOOM-COOK方法來做這個工作,而題目涉及的分為3等份的過程就是TOOM-COOK中當n = 3 的特例,也稱為TOOM3。
那先來看看TOOM-COOK的一般實現,即不規定分為多少份,設分為m份。
設有兩個大整數U,V,利用分治思維將U、V分為如下部分
U = |U-(m-1)|……|U2 |U1 |U0
V =|V-(m-1)|……|V2 |V1 |V0
設X = 10^(n/m)
則可以將u和v及其乘積w=uv表示為
將U,V和W都看作關於變數X的多項式,可以得到
此時我們可以取2m-1個不同的數x1,x2,…,x2m-1代入上多項式,可得W(xi) 與 W0,W1,W2……W(2m-2)的關係,轉換為矩陣為
那設B為紅框部分的矩陣可以推得
此時,取x1、 x2、 x3、 x4…. 為不同的數,再結合以下兩個式子, 可以得出W0 、W1、 W2、 W3、 W4與U0、 U1、U2 ,V0、 V1、 V2 的關係 最後移位相加得到最終結果,移位相加可以參考如下的圖,因為設的是10進位制的數,所以每次移位就是10^(n/m)
那其實說到這裡,相信有人已經明白了TOOM-COOK演算法的整體核心思想,整體實現流程可以簡單概括為
那回歸到本題,在具體的3等分n位大整數U、V問題中,可以詳細描述為以下細節 首先,分治劃分U、V
U = |U2 |U1 |U0 V =|V2 |V1 |V0
同樣的,設X = 10^(n/3),那U、V及其乘積W = UV可以表示為 分別取X為5個不同的數,即
X1 = 0,X2 = 1,X3 = -1, X4 = 2,X5 =-2
代入多項式中 第一個結論,是W(Xi)與U0,U1,U2,U3,V0,V1,V2的關係,暫且將5種X取值下的W(Xi)標記為a,b,c,d,e 第二個結論,是W(Xi)即a,b,c,d,e與W = UV中各項引數W0,W1,W2,W3,W4的,其中可以看到矩陣B(TOOM-COOK一般方法中提到) 至此,建立了W0,W1,W2,W3,W4與U1、U2 ,V0、 V1、 V2 的關係,經過運算後,可以求出兩個式子 ① ② 再帶入W=UV=W0+W1X+W2X2 +W3X3+W4X4 中,可以求出W 到了這一步,相信大家都已經對這道題有了更深的理解,下面我們來看一道具體的應用。
123456*987654
按照程式實現流程,劃分UV後,分別求得a,b,c,d,e以及W0,W1,W2,W3,W4如下
UV=W0+W1X+W2X2+W3X3+W4X4即W0, W1 , W2, W3, W4移位相加可得W結果 對比運算,可知結果無誤
最後再說一下時間複雜度,a,b,c,d,e所涉及的一共5次乘法,加減不計,所以得到的遞迴方程如下: 求解得複雜度為
TOOM-COOK實現思路和演算法可以參考大數乘法問題
/**
* Java8中的 Toom-Cook multiplication 3路乘法
*/
private static BigInteger multiplyToomCook3(BigInteger a, BigInteger b) {
int alen = a.mag.length;
int blen = b.mag.length;
int largest = Math.max(alen, blen);
// k is the size (in ints) of the lower-order slices.
int k = (largest+2)/3; // Equal to ceil(largest/3)
// r is the size (in ints) of the highest-order slice.
int r = largest - 2*k;
// Obtain slices of the numbers. a2 and b2 are the most significant
// bits of the numbers a and b, and a0 and b0 the least significant.
BigInteger a0, a1, a2, b0, b1, b2;
a2 = a.getToomSlice(k, r, 0, largest);
a1 = a.getToomSlice(k, r, 1, largest);
a0 = a.getToomSlice(k, r, 2, largest);
b2 = b.getToomSlice(k, r, 0, largest);
b1 = b.getToomSlice(k, r, 1, largest);
b0 = b.getToomSlice(k, r, 2, largest);
BigInteger v0, v1, v2, vm1, vinf, t1, t2, tm1, da1, db1;
v0 = a0.multiply(b0);
da1 = a2.add(a0);
db1 = b2.add(b0);
vm1 = da1.subtract(a1).multiply(db1.subtract(b1));
da1 = da1.add(a1);
db1 = db1.add(b1);
v1 = da1.multiply(db1);
v2 = da1.add(a2).shiftLeft(1).subtract(a0).multiply(
db1.add(b2).shiftLeft(1).subtract(b0));
vinf = a2.multiply(b2);
// The algorithm requires two divisions by 2 and one by 3.
// All divisions are known to be exact, that is, they do not produce
// remainders, and all results are positive. The divisions by 2 are
// implemented as right shifts which are relatively efficient, leaving
// only an exact division by 3, which is done by a specialized
// linear-time algorithm.
t2 = v2.subtract(vm1).exactDivideBy3();
tm1 = v1.subtract(vm1).shiftRight(1);
t1 = v1.subtract(v0);
t2 = t2.subtract(t1).shiftRight(1);
t1 = t1.subtract(tm1).subtract(vinf);
t2 = t2.subtract(vinf.shiftLeft(1));
tm1 = tm1.subtract(t2);
// Number of bits to shift left.
int ss = k*32;
BigInteger result = vinf.shiftLeft(ss).add(t2).shiftLeft(ss).add(t1).shiftLeft(ss).add(tm1).shiftLeft(ss).add(v0);
if (a.signum != b.signum) {
return result.negate();
} else {
return result;
}
}