1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
''' 矩阵乘法的 Strassen 算法 '''
def matrix_add(matrix_a: list, matrix_b: list) -> list:
    '''
    Element-wise sum of two matrices.

    :param matrix_a: first addend, a list of rows
    :param matrix_b: second addend, expected to have the same shape as matrix_a
    :return: a new matrix matrix_c with matrix_c[i][j] = matrix_a[i][j] + matrix_b[i][j];
        an empty matrix_a yields []
    '''
    # The column count is taken from the first row of matrix_a (an empty matrix has none),
    # and matrix_b is indexed by position, so a smaller matrix_b raises IndexError.
    columns = len(matrix_a[0]) if matrix_a else 0
    return [
        [row_a[j] + matrix_b[i][j] for j in range(columns)]
        for i, row_a in enumerate(matrix_a)
    ]
def matrix_minus(matrix_a: list, matrix_b: list) -> list:
    '''
    Element-wise difference of two matrices.

    :param matrix_a: minuend, a list of rows
    :param matrix_b: subtrahend, expected to have the same shape as matrix_a
    :return: a new matrix matrix_c with matrix_c[i][j] = matrix_a[i][j] - matrix_b[i][j];
        an empty matrix_a yields []
    '''
    # The column count is taken from the first row of matrix_a (an empty matrix has none),
    # and matrix_b is indexed by position, so a smaller matrix_b raises IndexError.
    columns = len(matrix_a[0]) if matrix_a else 0
    return [
        [row_a[j] - matrix_b[i][j] for j in range(columns)]
        for i, row_a in enumerate(matrix_a)
    ]
def matrix_divide(matrix: list, row: int, column: int) -> list:
    '''
    Cut one quadrant out of a matrix.

    Rows are split at len(matrix) // 2 and columns at len(matrix[0]) // 2, so for
    even sizes each quadrant has half the parent's side; for odd sizes the top/left
    parts get the smaller half and the bottom/right parts the larger one.

    :param matrix: the parent matrix, a list of equally long rows
    :param row: which row half to take: 1 (top) or 2 (bottom)
    :param column: which column half to take: 1 (left) or 2 (right)
    :return: the quadrant as a new list of new row lists (not views into matrix)
    :raises ValueError: if row or column is not 1 or 2
    '''
    if row not in (1, 2) or column not in (1, 2):
        raise ValueError(f'row and column must each be 1 or 2, got row={row}, column={column}')
    rows = len(matrix)
    # Use the real column count so non-square matrices are split correctly.
    columns = len(matrix[0]) if rows else 0
    row_start, row_stop = (row - 1) * rows // 2, row * rows // 2
    col_start, col_stop = (column - 1) * columns // 2, column * columns // 2
    # Slicing copies each row, so the quadrant never aliases the parent matrix.
    return [line[col_start:col_stop] for line in matrix[row_start:row_stop]]
def matrix_merge(matrix_11: list, matrix_12: list, matrix_21: list, matrix_22: list) -> list:
    '''
    Stitch four sub-matrices back into one matrix.

    The top half is matrix_11 beside matrix_12 and the bottom half is matrix_21
    beside matrix_22; the two halves may have different heights.

    :param matrix_11: top-left sub-matrix
    :param matrix_12: top-right sub-matrix (same number of rows as matrix_11)
    :param matrix_21: bottom-left sub-matrix
    :param matrix_22: bottom-right sub-matrix (same number of rows as matrix_21)
    :return: the merged matrix as a new list of new row lists
    :raises ValueError: if a left block and its right neighbour differ in height
    '''
    if len(matrix_11) != len(matrix_12) or len(matrix_21) != len(matrix_22):
        raise ValueError('left and right sub-matrices must have the same number of rows')
    # Concatenating rows with + builds fresh lists, so the inputs are not aliased.
    top = [left + right for left, right in zip(matrix_11, matrix_12)]
    bottom = [left + right for left, right in zip(matrix_21, matrix_22)]
    return top + bottom
def _pad_to_square(matrix: list, size: int) -> list:
    '''Return a copy of matrix zero-padded on the right and bottom to size x size.'''
    padded = [list(line) + [0] * (size - len(line)) for line in matrix]
    padded.extend([0] * size for _ in range(size - len(matrix)))
    return padded


def _strassen_square(matrix_a: list, matrix_b: list) -> list:
    '''
    Strassen recursion for two n x n matrices where n is a power of two.

    :param matrix_a: left n x n operand
    :param matrix_b: right n x n operand
    :return: the n x n product
    '''
    n = len(matrix_a)
    if n == 1:
        return [[matrix_a[0][0] * matrix_b[0][0]]]
    # Extract each quadrant once instead of re-slicing it for every S_i / P_i.
    a11, a12 = matrix_divide(matrix_a, 1, 1), matrix_divide(matrix_a, 1, 2)
    a21, a22 = matrix_divide(matrix_a, 2, 1), matrix_divide(matrix_a, 2, 2)
    b11, b12 = matrix_divide(matrix_b, 1, 1), matrix_divide(matrix_b, 1, 2)
    b21, b22 = matrix_divide(matrix_b, 2, 1), matrix_divide(matrix_b, 2, 2)
    # The ten sums/differences S1..S10.
    s1 = matrix_minus(b12, b22)
    s2 = matrix_add(a11, a12)
    s3 = matrix_add(a21, a22)
    s4 = matrix_minus(b21, b11)
    s5 = matrix_add(a11, a22)
    s6 = matrix_add(b11, b22)
    s7 = matrix_minus(a12, a22)
    s8 = matrix_add(b21, b22)
    s9 = matrix_minus(a11, a21)
    s10 = matrix_add(b11, b12)
    # Seven recursive half-size products instead of the naive eight.
    p1 = _strassen_square(a11, s1)
    p2 = _strassen_square(s2, b22)
    p3 = _strassen_square(s3, b11)
    p4 = _strassen_square(a22, s4)
    p5 = _strassen_square(s5, s6)
    p6 = _strassen_square(s7, s8)
    p7 = _strassen_square(s9, s10)
    c11 = matrix_add(matrix_add(p5, p4), matrix_minus(p6, p2))
    c12 = matrix_add(p1, p2)
    c21 = matrix_add(p3, p4)
    c22 = matrix_add(matrix_minus(p5, p3), matrix_minus(p1, p7))
    return matrix_merge(c11, c12, c21, c22)


def strassen(matrix_a: list, matrix_b: list) -> list:
    '''
    Multiply two matrices with Strassen's algorithm.

    Accepts any conformable shapes: an m x n matrix_a times an n x p matrix_b.
    Operands that are not already square with a power-of-two side are zero-padded
    to the next power-of-two square, multiplied, and the m x p result is cut back out.

    :param matrix_a: left operand, a list of m rows of length n
    :param matrix_b: right operand, a list of n rows of length p
    :return: the m x p product as a new list of lists; [] if matrix_a is empty
    :raises ValueError: if the inner dimensions do not match
    '''
    rows = len(matrix_a)
    if rows == 0:
        return []
    inner = len(matrix_a[0])
    if inner != len(matrix_b):
        raise ValueError(
            f'cannot multiply a {rows}x{inner} matrix by a matrix with {len(matrix_b)} rows')
    if inner == 0:
        # matrix_b has no rows, so the product's column count cannot be recovered
        # from the list representation; every result row is empty.
        return [[] for _ in range(rows)]
    columns = len(matrix_b[0])
    size = 1
    while size < max(rows, inner, columns):
        size *= 2
    if rows == inner == columns == size:
        # Already a power-of-two square: no padding needed.
        return _strassen_square(matrix_a, matrix_b)
    product = _strassen_square(_pad_to_square(matrix_a, size), _pad_to_square(matrix_b, size))
    return [line[:columns] for line in product[:rows]]
import numpy

if __name__ == '__main__':
    size = 16
    # Distinct entries (not all ones) so a misplaced or transposed quadrant would show up.
    a = [[i * size + j for j in range(size)] for i in range(size)]
    b = [[(i - j) % 7 for j in range(size)] for i in range(size)]
    c = strassen(a, b)
    print('打印结果矩阵')
    print(c)
    for line in c:
        print(line)
    n_a = numpy.array(a)
    n_b = numpy.array(b)
    print(n_a)
    print(n_b)
    expected = n_a.dot(n_b)
    print(expected)
    # Explicitly cross-check the Strassen result against numpy's reference product.
    print('matches numpy:', numpy.array_equal(numpy.array(c), expected))
|