分治法举例之矩阵乘法

分治法举例之矩阵乘法

前言

矩阵按定义直接实现是比较直接简单的。时间复杂度也可以直接得出来是(O(n^3))

直接实现

class matrix:
    '''
    为了简单起见,没有对数据做校验。
    假设矩阵就是nxn的。
    '''
    def __init__(self, data):
        if not data or not hasattr(data, '__getitem__'):
            raise ValueError("data not valid! %s" % data)
        self.data = data
        self.rows = len(data)
        self.cols = max(map(lambda row: len(row), data))

    def __mul__(self, another):
        if self.cols != another.rows:
            raise ValueError("not valid ddata ,only support mxn * nxp")
        ret = matrix([[0 for _ in range(another.cols)] for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(another.cols):
                num = 0
                for k in range(self.cols):
                    num += self._getitem(i, k) * another._getitem(k, j)
                ret._setitem(i, j, num)
        return ret

    def _getitem(self, i, j):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        try:
            return self.data[i][j]
        except Exception:
            return 0

    def _setitem(self, i, j, value):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        if j >= len(self.data[i]):
            fill = self.cols - len(self.data[i])
            self.data[i].extend([0 for _ in range(fill)])
        self.data[i][j] = value

    def __str__(self):
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)

直接应用分治法

很简单的想法是把矩阵分成n/2的4块。于是
A*B = C
变成
[
\begin{pmatrix}
a & b \
c & d
\end{pmatrix}
*
\begin{pmatrix}
e & f \
g & h

\end{pmatrix}

\begin{pmatrix}
r & s \
t & u
\end{pmatrix}
]

r = ae + bg
s = af + bh
t = ce + dg
u = cf + dh

于是有(T(n) = 8T(n/2)+O(n^2))
(O(n^{log_ba}) = O(n^{log_28})=O(n^3) > f(n))
满足主定理的第一种情况,于是(T(n) = \Theta(n^3)),并没有比直接实现快。

斯特拉森算法

斯特拉森于1969年提出的算法,运用分治策略并加上一些处理技巧设计出的一种矩阵乘法。
他巧妙地8变成了7。于是达到了(T(n) = \Theta(n^{log_27})\approx\Theta(n^{2.81}))
这还不是目前理论上最好的,暂时最好的达到了(T(n) =\Theta(n^{2.376}))

看一下他是怎么玩的。

P1 = (a+d) * (e+h)
P2 = (c+d) * e
P3 = a * (f-h)
P4 = d * (g-e)
P5 = (a+b) * h
P6 = (c-a) * (e+f)
P7 = (b-d) * (g+h)

利用此7个式子即可得到原来的r,s,t,u
r = P1 + P4 – P5 + P7
s = P3 + P5
t = P2 + P4
u = P1 + P3 -P2 + P6
验证一下u看看

u = P1 + P3 -P2 + P6
= (a+d) * (e+h) + a * (f-h) -((c+d) * e) + (c-a) * (e+f)
=ae + ah + de + dh + af - ah -ce - de + ce + cf -ae -af
= dh + cf

正确

代码实现如下:

#!/usr/bin/env python
from enum import Enum, IntEnum, unique
import sys

class matrix:
    '''
    为了简单起见,没有对数据做校验。
    假设矩阵就是nxn的。
    '''
    def __init__(self, data):
        if not data or not hasattr(data, '__getitem__'):
            raise ValueError("data not valid! %s" % data)
        self.data = data
        self.rows = len(data)
        self.cols = max(map(lambda row: len(row), data))
        if self.rows != self.cols:
            raise ValueError("only support nxn matrix, and n can continue divide by 2 util 1")

    def __add__(self, another):
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        ret = matrix([[0]*self.rows for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(self.rows):
                ret._setitem(i, j, self._getitem(i, j) + another._getitem(i, j))

        return ret

    def __sub__(self, another):
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        ret = matrix([[0]*self.rows for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(self.rows):
                ret._setitem(i, j, self._getitem(i, j) - another._getitem(i, j))

        return ret

    def __mul__(self, another):
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        ret = matrix([[0]*self.rows for _ in range(self.rows)])
        if self.rows == 2:
            for i in range(self.rows):
                for j in range(another.cols):
                    num = 0
                    for k in range(self.cols):
                        num += self._getitem(i, k) * another._getitem(k, j)
                    ret._setitem(i, j, num)
        else:
            a = self._divide(matrix.DIRECTION.LEFT_TOP)
            b = self._divide(matrix.DIRECTION.RIGHT_TOP)
            c = self._divide(matrix.DIRECTION.LEFT_BOTTOM)
            d = self._divide(matrix.DIRECTION.RIGHT_BOTTOM)

            e = another._divide(matrix.DIRECTION.LEFT_TOP)
            f = another._divide(matrix.DIRECTION.RIGHT_TOP)
            g = another._divide(matrix.DIRECTION.LEFT_BOTTOM)
            h = another._divide(matrix.DIRECTION.RIGHT_BOTTOM)

            p1 = (a+d)*(e+h)
            p2 = (c+d)*e
            p3 = a * (f-h)
            p4 = d * (g-e)
            p5 = (a+b)*h
            p6 = (c-a)*(e+f)
            p7 = (b-d)*(g+h)

            r = p1 + p4 - p5 + p7
            s = p3 + p5
            t = p2 + p4
            u = p1 + p3 - p2 + p6

            ret._merge(matrix.DIRECTION.LEFT_TOP, r)
            ret._merge(matrix.DIRECTION.RIGHT_TOP, s)
            ret._merge(matrix.DIRECTION.LEFT_BOTTOM, t)
            ret._merge(matrix.DIRECTION.RIGHT_BOTTOM, u)

        return ret

    @unique
    class DIRECTION (IntEnum):
        LEFT_TOP = 1
        LEFT_BOTTOM = 2
        RIGHT_TOP = 3
        RIGHT_BOTTOM = 4

    def _divide(self, direction):
        ret = matrix([[0]*int(self.rows/2) for _ in range(int(self.rows/2))])
        row_start = col_start = 0
        if direction == matrix.DIRECTION.LEFT_TOP:
            row_start = 0
            col_start = 0
        elif direction == matrix.DIRECTION.LEFT_BOTTOM:
            row_start = int(self.rows/2)
            col_start = 0
        elif direction == matrix.DIRECTION.RIGHT_TOP:
            row_start = 0
            col_start = int(self.cols/2)
        else:
            row_start = int(self.rows/2)
            col_start = int(self.cols/2)

        for i in range(ret.rows):
            for j in range(ret.cols):
                item = self._getitem(i+row_start, j+col_start)
                ret._setitem(i, j, item)

        return ret

    def _merge(self, direction, another):
        row_start = col_start = 0
        if direction == matrix.DIRECTION.LEFT_TOP:
            row_start = 0
            col_start = 0
        elif direction == matrix.DIRECTION.LEFT_BOTTOM:
            row_start = int(self.rows/2)
            col_start = 0
        elif direction == matrix.DIRECTION.RIGHT_TOP:
            row_start = 0
            col_start = int(self.cols/2)
        else:
            row_start = int(self.rows/2)
            col_start = int(self.cols/2)

        for i in range(another.rows):
            for j in range(another.cols):
                item = another._getitem(i, j)
                self._setitem(i+row_start, j+col_start, item)



    def _getitem(self, i, j):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        try:
            return self.data[i][j]
        except Exception:
            return 0

    def _setitem(self, i, j, value):
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        if j >= len(self.data[i]):
            fill = self.cols - len(self.data[i])
            self.data[i].extend([0 for _ in range(fill)])
        self.data[i][j] = value

    def __str__(self):
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)

测试结果

方法 规模 时间
直接计算 8×8 0.054
拉特斯森 8×8 0.095
直接计算 16×16 0.063
拉特斯森 16×16 0.117
直接计算 32×32 0.090
拉特斯森 32×32 0.454
直接计算 64×64 0.419
拉特斯森 64×64 2.953
直接计算 128×128 2.946
拉特斯森 128×128 20.547
直接计算 256×256 24.835
拉特斯森 256×256 2:15.15
直接计算 512×512 3:15.98

总结

估计是我实现的问题,比预期结果要差,看视频里说的是到32就差不多了。
不过从上也可以看出来,拉特斯森增长的速度没有直接计算的快,迟早性能会更好。

另外,直观上感觉矩阵乘法是没法优化。但事实上可以。
这说明真没有那么多做不到的事,可能只是现在的你做不来,说不定有人能做到,说不定做到的那个人就是未来的你。

发表评论

电子邮件地址不会被公开。 必填项已用*标注