分治法举例之矩阵乘法
分治法举例之矩阵乘法
前言
矩阵乘法按定义直接实现比较简单,时间复杂度也可以直接得出,是 \(O(n^3)\)。
直接实现
class matrix:
    """Dense m x n matrix stored as a list of row lists.

    For simplicity the input is barely validated. Rows may be shorter than
    the widest row: missing trailing entries read as 0 and are filled in
    with zeros the first time such a position is written.
    """

    def __init__(self, data):
        """Wrap ``data`` (a non-empty, indexable sequence of rows).

        The rows are stored by reference, not copied.

        Raises:
            ValueError: if ``data`` is empty or not indexable.
        """
        if not data or not hasattr(data, '__getitem__'):
            # Wrap in a 1-tuple: a bare tuple on the right of % would be
            # unpacked as format arguments and raise TypeError instead.
            raise ValueError("data not valid! %s" % (data,))
        self.data = data
        self.rows = len(data)
        # The widest row defines the column count; shorter rows are
        # treated as zero-padded by _getitem/_setitem.
        self.cols = max(map(len, data))

    def __mul__(self, another):
        """Return the matrix product ``self * another`` by the definition.

        Uses the textbook triple loop, i.e. rows * cols * another.cols
        scalar multiplications.

        Raises:
            ValueError: if ``self.cols != another.rows``.
        """
        if self.cols != another.rows:
            raise ValueError("not valid ddata ,only support mxn * nxp")
        inner = range(self.cols)
        product = [
            [
                sum(self._getitem(i, k) * another._getitem(k, j)
                    for k in inner)
                for j in range(another.cols)
            ]
            for i in range(self.rows)
        ]
        return matrix(product)

    def _getitem(self, i, j):
        """Return the element at row ``i``, column ``j``.

        Positions beyond the end of a short row read as 0.

        Raises:
            IndexError: if ``(i, j)`` is outside ``rows x cols``; negative
                indices are rejected rather than wrapping around.
        """
        if not (0 <= i < self.rows and 0 <= j < self.cols):
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        row = self.data[i]
        return row[j] if j < len(row) else 0

    def _setitem(self, i, j, value):
        """Store ``value`` at row ``i``, column ``j``.

        A short row is first zero-extended to the full column count.

        Raises:
            IndexError: if ``(i, j)`` is outside ``rows x cols``; negative
                indices are rejected rather than wrapping around.
        """
        if not (0 <= i < self.rows and 0 <= j < self.cols):
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        row = self.data[i]
        if j >= len(row):
            row.extend([0] * (self.cols - len(row)))
        row[j] = value

    def __str__(self):
        """Return ``(rows:R, cols:C)->DATA``."""
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)
直接应用分治法
很简单的想法是把矩阵分成n/2的4块。于是
A*B = C
变成
\[
\begin{pmatrix}
a & b \\
c & d
\end{pmatrix}
*
\begin{pmatrix}
e & f \\
g & h
\end{pmatrix}
=
\begin{pmatrix}
r & s \\
t & u
\end{pmatrix}
\]
r = ae + bg
s = af + bh
t = ce + dg
u = cf + dh
于是有 \(T(n) = 8T(n/2) + O(n^2)\),即 \(a = 8\)、\(b = 2\)、\(f(n) = O(n^2)\)。
\(n^{\log_b a} = n^{\log_2 8} = n^3\),比 \(f(n)\) 多项式地大,
满足主定理的第一种情况,于是 \(T(n) = \Theta(n^3)\),并没有比直接实现快。
斯特拉森算法
斯特拉森于1969年提出的算法,运用分治策略并加上一些处理技巧设计出的一种矩阵乘法。
他巧妙地把 8 次子矩阵乘法变成了 7 次,于是达到了 \(T(n) = \Theta(n^{\log_2 7}) \approx \Theta(n^{2.81})\)
这还不是目前理论上最好的,暂时最好的达到了(T(n) =\Theta(n^{2.376}))
看一下他是怎么玩的。
P1 = (a+d) * (e+h)
P2 = (c+d) * e
P3 = a * (f-h)
P4 = d * (g-e)
P5 = (a+b) * h
P6 = (c-a) * (e+f)
P7 = (b-d) * (g+h)
利用此7个式子即可得到原来的r,s,t,u
r = P1 + P4 - P5 + P7
s = P3 + P5
t = P2 + P4
u = P1 + P3 -P2 + P6
验证一下u看看
u = P1 + P3 -P2 + P6
= (a+d) * (e+h) + a * (f-h) -((c+d) * e) + (c-a) * (e+f)
=ae + ah + de + dh + af - ah -ce - de + ce + cf -ae -af
= dh + cf
正确
代码实现如下:
#!/usr/bin/env python
from enum import Enum, IntEnum, unique
import sys
class matrix:
    """Square (n x n) matrix supporting ``+``, ``-`` and ``*``.

    ``*`` uses Strassen's divide-and-conquer scheme (7 recursive block
    products instead of 8) while the size is even and larger than 2, and
    falls back to the definition-based product otherwise.

    For simplicity the input is barely validated. Rows may be shorter than
    the widest row: missing trailing entries read as 0.
    """

    @unique
    class DIRECTION(IntEnum):
        """Names the four quadrants of a matrix split at its midpoints."""
        LEFT_TOP = 1
        LEFT_BOTTOM = 2
        RIGHT_TOP = 3
        RIGHT_BOTTOM = 4

    def __init__(self, data):
        """Wrap ``data`` (a non-empty, indexable sequence of rows).

        The rows are stored by reference, not copied.

        Raises:
            ValueError: if ``data`` is empty or not indexable, or if the
                number of rows differs from the length of the widest row.
        """
        if not data or not hasattr(data, '__getitem__'):
            # Wrap in a 1-tuple: a bare tuple on the right of % would be
            # unpacked as format arguments and raise TypeError instead.
            raise ValueError("data not valid! %s" % (data,))
        self.data = data
        self.rows = len(data)
        self.cols = max(map(len, data))
        if self.rows != self.cols:
            raise ValueError("only support nxn matrix, and n can continue divide by 2 util 1")

    def __add__(self, another):
        """Return the element-wise sum as a new matrix.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        return self._elementwise(another, lambda x, y: x + y)

    def __sub__(self, another):
        """Return the element-wise difference as a new matrix.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        return self._elementwise(another, lambda x, y: x - y)

    def __mul__(self, another):
        """Return the matrix product ``self * another``.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        n = self.rows
        # Strassen needs four equally sized quadrants, so it only applies
        # to even sizes. 1x1, 2x2 and odd sizes use the direct product;
        # without this, halving would eventually ask for an empty block.
        if n <= 2 or n % 2:
            return self._multiply_direct(another)

        quadrant = matrix.DIRECTION
        a = self._divide(quadrant.LEFT_TOP)
        b = self._divide(quadrant.RIGHT_TOP)
        c = self._divide(quadrant.LEFT_BOTTOM)
        d = self._divide(quadrant.RIGHT_BOTTOM)
        e = another._divide(quadrant.LEFT_TOP)
        f = another._divide(quadrant.RIGHT_TOP)
        g = another._divide(quadrant.LEFT_BOTTOM)
        h = another._divide(quadrant.RIGHT_BOTTOM)

        # The seven Strassen products. Operand order matters because
        # matrix multiplication is not commutative.
        p1 = (a + d) * (e + h)
        p2 = (c + d) * e
        p3 = a * (f - h)
        p4 = d * (g - e)
        p5 = (a + b) * h
        p6 = (c - a) * (e + f)
        p7 = (b - d) * (g + h)

        ret = matrix([[0] * n for _ in range(n)])
        ret._merge(quadrant.LEFT_TOP, p1 + p4 - p5 + p7)      # ae + bg
        ret._merge(quadrant.RIGHT_TOP, p3 + p5)               # af + bh
        ret._merge(quadrant.LEFT_BOTTOM, p2 + p4)             # ce + dg
        ret._merge(quadrant.RIGHT_BOTTOM, p1 + p3 - p2 + p6)  # cf + dh
        return ret

    def _elementwise(self, another, combine):
        """Return a new matrix of ``combine(x, y)`` over matching cells."""
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        span = range(self.rows)
        return matrix([
            [combine(self._getitem(i, j), another._getitem(i, j))
             for j in span]
            for i in span
        ])

    def _multiply_direct(self, another):
        """Return the product computed straight from the definition."""
        inner = range(self.cols)
        return matrix([
            [
                sum(self._getitem(i, k) * another._getitem(k, j)
                    for k in inner)
                for j in range(another.cols)
            ]
            for i in range(self.rows)
        ])

    def _quadrant_origin(self, direction):
        """Return ``(row_start, col_start)`` of the given quadrant.

        Any value other than LEFT_TOP, LEFT_BOTTOM or RIGHT_TOP is treated
        as RIGHT_BOTTOM.
        """
        half_rows = self.rows // 2
        half_cols = self.cols // 2
        if direction == matrix.DIRECTION.LEFT_TOP:
            return 0, 0
        if direction == matrix.DIRECTION.LEFT_BOTTOM:
            return half_rows, 0
        if direction == matrix.DIRECTION.RIGHT_TOP:
            return 0, half_cols
        return half_rows, half_cols

    def _divide(self, direction):
        """Return a copy of one quadrant as a new (n//2 x n//2) matrix.

        Raises:
            ValueError: if the matrix is 1x1, since the quadrant would be
                empty.
        """
        half = self.rows // 2
        row_start, col_start = self._quadrant_origin(direction)
        return matrix([
            [self._getitem(row_start + i, col_start + j)
             for j in range(half)]
            for i in range(half)
        ])

    def _merge(self, direction, another):
        """Copy ``another`` into the given quadrant of this matrix in place."""
        row_start, col_start = self._quadrant_origin(direction)
        for i in range(another.rows):
            for j in range(another.cols):
                self._setitem(row_start + i, col_start + j,
                              another._getitem(i, j))

    def _getitem(self, i, j):
        """Return the element at row ``i``, column ``j``.

        Positions beyond the end of a short row read as 0.

        Raises:
            IndexError: if ``(i, j)`` is outside ``rows x cols``; negative
                indices are rejected rather than wrapping around.
        """
        if not (0 <= i < self.rows and 0 <= j < self.cols):
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        row = self.data[i]
        return row[j] if j < len(row) else 0

    def _setitem(self, i, j, value):
        """Store ``value`` at row ``i``, column ``j``.

        A short row is first zero-extended to the full column count.

        Raises:
            IndexError: if ``(i, j)`` is outside ``rows x cols``; negative
                indices are rejected rather than wrapping around.
        """
        if not (0 <= i < self.rows and 0 <= j < self.cols):
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        row = self.data[i]
        if j >= len(row):
            row.extend([0] * (self.cols - len(row)))
        row[j] = value

    def __str__(self):
        """Return ``(rows:R, cols:C)->DATA``."""
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)
测试结果
方法 | 规模 | 时间 |
---|---|---|
直接计算 | 8×8 | 0.054 |
斯特拉森 | 8×8 | 0.095 |
直接计算 | 16×16 | 0.063 |
斯特拉森 | 16×16 | 0.117 |
直接计算 | 32×32 | 0.090 |
斯特拉森 | 32×32 | 0.454 |
直接计算 | 64×64 | 0.419 |
斯特拉森 | 64×64 | 2.953 |
直接计算 | 128×128 | 2.946 |
斯特拉森 | 128×128 | 20.547 |
直接计算 | 256×256 | 24.835 |
斯特拉森 | 256×256 | 2:15.15 |
直接计算 | 512×512 | 3:15.98 |
总结
估计是我实现的问题,比预期结果要差,看视频里说的是到32就差不多了。
不过从上面的数据也可以看出来,斯特拉森算法耗时的增长速度比直接计算慢,规模足够大时性能迟早会更好。
另外,直观上感觉矩阵乘法是没法优化。但事实上可以。
这说明真没有那么多做不到的事,可能只是现在的你做不来,说不定有人能做到,说不定做到的那个人就是未来的你。