# 直接实现

class matrix:
    """Dense matrix backed by a list of row lists, with naive multiplication.

    Rows may be ragged: the column count is the length of the longest row,
    and cells missing from shorter rows read as 0. Apart from the basic
    check in ``__init__`` the data is not validated.
    """

    def __init__(self, data):
        """Wrap ``data``, a non-empty indexable sequence of rows.

        The rows are stored by reference, not copied.

        Raises:
            ValueError: if ``data`` is empty/falsy or has no ``__getitem__``.
        """
        if not data or not hasattr(data, '__getitem__'):
            raise ValueError("data not valid! %s" % data)
        self.data = data
        self.rows = len(data)
        # Ragged input is allowed; the widest row defines the column count.
        self.cols = max(len(row) for row in data)

    def __mul__(self, another):
        """Return the product ``self * another`` as a new matrix.

        Uses the schoolbook triple loop: cell (i, j) is the dot product of
        row i of ``self`` and column j of ``another``.

        Raises:
            ValueError: if ``self.cols`` differs from ``another.rows``.
        """
        if self.cols != another.rows:
            raise ValueError("not valid ddata ,only support mxn * nxp")
        ret = matrix([[0] * another.cols for _ in range(self.rows)])
        for i in range(self.rows):
            for j in range(another.cols):
                num = 0
                for k in range(self.cols):
                    num += self._getitem(i, k) * another._getitem(k, j)
                ret._setitem(i, j, num)
        return ret

    def _getitem(self, i, j):
        """Return the cell at row ``i``, column ``j``.

        Cells that fall beyond the end of a short (ragged) row read as 0.

        Raises:
            IndexError: if ``i >= self.rows`` or ``j >= self.cols``.
        """
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        try:
            return self.data[i][j]
        except IndexError:
            # Only a short row is an expected miss; any other error is a
            # real problem with the data and must not be masked as 0.
            return 0

    def _setitem(self, i, j, value):
        """Store ``value`` at row ``i``, column ``j``.

        A row shorter than ``self.cols`` is zero-padded in place to full
        width before the assignment.

        Raises:
            IndexError: if ``i >= self.rows`` or ``j >= self.cols``.
        """
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        row = self.data[i]
        if j >= len(row):
            row.extend([0] * (self.cols - len(row)))
        row[j] = value

    def __str__(self):
        """Return the dimensions followed by the raw row data."""
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)


A*B = C

\[
\begin{pmatrix}
a & b \\
c & d
\end{pmatrix}
*
\begin{pmatrix}
e & f \\
g & h
\end{pmatrix} =
\begin{pmatrix}
r & s \\
t & u
\end{pmatrix}
\]

r = ae + bg
s = af + bh
t = ce + dg
u = cf + dh

\(O(n^{\log_b a}) = O(n^{\log_2 8}) = O(n^3) > f(n)\)

# 斯特拉森算法

P1 = (a+d) * (e+h)
P2 = (c+d) * e
P3 = a * (f-h)
P4 = d * (g-e)
P5 = (a+b) * h
P6 = (c-a) * (e+f)
P7 = (b-d) * (g+h)


r = P1 + P4 - P5 + P7
s = P3 + P5
t = P2 + P4
u = P1 + P3 - P2 + P6

u = P1 + P3 -P2 + P6
= (a+d) * (e+h) + a * (f-h) -((c+d) * e) + (c-a) * (e+f)
=ae + ah + de + dh + af - ah -ce - de + ce + cf -ae -af
= dh + cf



#!/usr/bin/env python
from enum import Enum, IntEnum, unique
import sys

class matrix:
    """Square (n x n) matrix supporting ``+``, ``-`` and Strassen ``*``.

    The matrix is backed by a list of row lists. Rows may be ragged: cells
    missing from a short row read as 0. Apart from the checks in
    ``__init__`` the data is not validated.

    Multiplication uses Strassen's seven-product recursion whenever the
    size is even and greater than 2; sizes 1 and 2 and any odd size use the
    schoolbook product, so every square size multiplies correctly.
    """

    @unique
    class DIRECTION(IntEnum):
        """Names the four quadrants of a matrix split down the middle."""

        LEFT_TOP = 1
        LEFT_BOTTOM = 2
        RIGHT_TOP = 3
        RIGHT_BOTTOM = 4

    def __init__(self, data):
        """Wrap ``data``, a non-empty indexable sequence of rows.

        The rows are stored by reference, not copied.

        Raises:
            ValueError: if ``data`` is empty/falsy, has no ``__getitem__``,
                or the row count differs from the longest row's length.
        """
        if not data or not hasattr(data, '__getitem__'):
            raise ValueError("data not valid! %s" % data)
        self.data = data
        self.rows = len(data)
        self.cols = max(len(row) for row in data)
        if self.rows != self.cols:
            raise ValueError("only support nxn matrix, and n can continue divide by 2 util 1")

    def __add__(self, another):
        """Return the element-wise sum ``self + another`` as a new matrix.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        n = self.rows
        return matrix([
            [self._getitem(i, j) + another._getitem(i, j) for j in range(n)]
            for i in range(n)
        ])

    def __sub__(self, another):
        """Return the element-wise difference ``self - another``.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        n = self.rows
        return matrix([
            [self._getitem(i, j) - another._getitem(i, j) for j in range(n)]
            for i in range(n)
        ])

    def __mul__(self, another):
        """Return the matrix product ``self * another`` as a new matrix.

        Raises:
            ValueError: if the two matrices differ in size.
        """
        if self.rows != another.rows:
            raise ValueError("not valid ddata ,only support nxn * nxn")
        n = self.rows
        # Strassen needs four equal quadrants, which only exist for even n.
        # Halving an odd n would drop the last row/column, and n == 1 would
        # produce empty quadrants, so those sizes use the direct product.
        if n <= 2 or n % 2:
            return self._naive_mul(another)

        a = self._divide(matrix.DIRECTION.LEFT_TOP)
        b = self._divide(matrix.DIRECTION.RIGHT_TOP)
        c = self._divide(matrix.DIRECTION.LEFT_BOTTOM)
        d = self._divide(matrix.DIRECTION.RIGHT_BOTTOM)

        e = another._divide(matrix.DIRECTION.LEFT_TOP)
        f = another._divide(matrix.DIRECTION.RIGHT_TOP)
        g = another._divide(matrix.DIRECTION.LEFT_BOTTOM)
        h = another._divide(matrix.DIRECTION.RIGHT_BOTTOM)

        # Seven recursive products. Operand order matters because matrix
        # multiplication is not commutative.
        p1 = (a + d) * (e + h)
        p2 = (c + d) * e
        p3 = a * (f - h)
        p4 = d * (g - e)
        p5 = (a + b) * h
        p6 = (c - a) * (e + f)
        p7 = (b - d) * (g + h)

        r = p1 + p4 - p5 + p7  # = a*e + b*g
        s = p3 + p5            # = a*f + b*h
        t = p2 + p4            # = c*e + d*g
        u = p1 + p3 - p2 + p6  # = c*f + d*h

        ret = matrix([[0] * n for _ in range(n)])
        ret._merge(matrix.DIRECTION.LEFT_TOP, r)
        ret._merge(matrix.DIRECTION.RIGHT_TOP, s)
        ret._merge(matrix.DIRECTION.LEFT_BOTTOM, t)
        ret._merge(matrix.DIRECTION.RIGHT_BOTTOM, u)
        return ret

    def _naive_mul(self, another):
        """Return ``self * another`` using the schoolbook triple loop."""
        n = self.rows
        ret = matrix([[0] * n for _ in range(n)])
        for i in range(n):
            for j in range(n):
                num = 0
                for k in range(n):
                    num += self._getitem(i, k) * another._getitem(k, j)
                ret._setitem(i, j, num)
        return ret

    def _quadrant_origin(self, direction):
        """Return ``(row_start, col_start)`` of the quadrant ``direction``.

        Any value other than LEFT_TOP, LEFT_BOTTOM or RIGHT_TOP is treated
        as RIGHT_BOTTOM.
        """
        half_rows = self.rows // 2
        half_cols = self.cols // 2
        if direction == matrix.DIRECTION.LEFT_TOP:
            return 0, 0
        if direction == matrix.DIRECTION.LEFT_BOTTOM:
            return half_rows, 0
        if direction == matrix.DIRECTION.RIGHT_TOP:
            return 0, half_cols
        return half_rows, half_cols

    def _divide(self, direction):
        """Return a copy of the quadrant ``direction`` as a new matrix.

        The quadrant has ``self.rows // 2`` rows and columns.
        """
        half = self.rows // 2
        row_start, col_start = self._quadrant_origin(direction)
        return matrix([
            [self._getitem(i + row_start, j + col_start) for j in range(half)]
            for i in range(half)
        ])

    def _merge(self, direction, another):
        """Copy ``another`` into the quadrant ``direction`` of this matrix."""
        row_start, col_start = self._quadrant_origin(direction)
        for i in range(another.rows):
            for j in range(another.cols):
                self._setitem(i + row_start, j + col_start,
                              another._getitem(i, j))

    def _getitem(self, i, j):
        """Return the cell at row ``i``, column ``j``.

        Cells that fall beyond the end of a short (ragged) row read as 0.

        Raises:
            IndexError: if ``i >= self.rows`` or ``j >= self.cols``.
        """
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d, %s" % (
                i, j, self.data))
        try:
            return self.data[i][j]
        except IndexError:
            # Only a short row is an expected miss; any other error is a
            # real problem with the data and must not be masked as 0.
            return 0

    def _setitem(self, i, j, value):
        """Store ``value`` at row ``i``, column ``j``.

        A row shorter than ``self.cols`` is zero-padded in place to full
        width before the assignment.

        Raises:
            IndexError: if ``i >= self.rows`` or ``j >= self.cols``.
        """
        if i >= self.rows or j >= self.cols:
            raise IndexError("index out of boundary,i=%d,j=%d,value=%s, %s" % (
                i, j, str(value), self.data))
        row = self.data[i]
        if j >= len(row):
            row.extend([0] * (self.cols - len(row)))
        row[j] = value

    def __str__(self):
        """Return the dimensions followed by the raw row data."""
        return "(rows:%d, cols:%d)->%s" % (self.rows, self.cols, self.data)


Tags: