服务器之家

服务器之家 > 正文

Java实现矩阵乘法以及优化的方法实例

时间:2021-08-04 09:49     来源/作者:GGG_Yu

传统的矩阵乘法实现

  首先,两个矩阵能够相乘,必须满足一个前提:前一个矩阵的行数等于后一个矩阵的列数。

  第一个矩阵的第m行和第二个矩阵的第n列的乘积和即为乘积矩阵第m行第n列的值,可用如下图像表示这个过程。

Java实现矩阵乘法以及优化的方法实例

矩阵乘法过程展示

c[1][1] = a[1][0] * b[0][1] + a[1][1] * b[1][1] + a[1][2] * b[2][1] + a[1][3] * b[3][1] + a[1][4] * b[4][1]

  而用java实现该过程的传统方法就是按照该规则实现一个三重循环,把各项乘积累加:

?
1
2
3
4
5
6
7
8
9
10
11
12
public int[][] multiply(int[][] mat1, int[][] mat2){
    int m = mat1.length, n = mat2[0].length;
    int[][] mat = new int[m][n];
    for(int i = 0; i < m; i++){
        for(int j = 0; j < n; j++){
            for(int k = 0; k < mat1[0].length; k++){
                mat[i][j] += mat1[i][k] * mat2[k][j];
            }
        }
    }
    return mat;
}

  可以看出该方法的时间复杂度为o(n3),当矩阵维数比较大的时候程序就很容易超时。

优化方法(strassen算法)

  strassen算法是由volker strassen在1966年提出的第一个时间复杂度低于o(n³)的矩阵乘法算法,其主要思想是通过分治来实现矩阵乘法的快速运算,计算过程如图所示:

Java实现矩阵乘法以及优化的方法实例

将一次矩阵乘法拆分成多个乘法与加法的结合

  为什么这个方法会更快呢,我们知道,按照传统的矩阵乘法:

c11 = a11 * b11 + a12 * b21
c12 = a11 * b12 + a12 * b22
c21 = a21 * b11 + a22 * b21
c22 = a21 * b12 + a22 * b22

  我们需要8次矩阵乘法和4次矩阵加法,正是这8次乘法最耗时;而strassen方法只需要7次矩阵乘法,尽管代价是矩阵加法次数变为18次,但是基于数量级考虑,18次加法仍然快于1次乘法。

  当然,strassen算法的代码实现也比传统算法复杂许多,这里附上另一个大神写的java实现(原文链接:):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
public class matrix {
    private final matrix[] _matrixarray;
    private final int n;
    private int element;
    public matrix(int n) {
        this.n = n;
        if (n != 1) {
            this._matrixarray = new matrix[4];
            for (int i = 0; i < 4; i++) {
                this._matrixarray[i] = new matrix(n / 2);
            }
        } else {
            this._matrixarray = null;
        }
    }
    private matrix(int n, boolean needinit) {
        this.n = n;
        if (n != 1) {
            this._matrixarray = new matrix[4];
        } else {
            this._matrixarray = null;
        }
    }
    public void set(int i, int j, int a) {
        if (n == 1) {
            element = a;
        } else {
            int size = n / 2;
            this._matrixarray[(i / size) * 2 + (j / size)].set(i % size, j % size, a);
        }
    }
    public matrix multi(matrix m) {
        matrix result = null;
        if (n == 1) {
            result = new matrix(1);
            result.set(0, 0, (element * m.element));
        } else {
            result = new matrix(n, false);
            result._matrixarray[0] = p5(m).add(p4(m)).minus(p2(m)).add(p6(m));
            result._matrixarray[1] = p1(m).add(p2(m));
            result._matrixarray[2] = p3(m).add(p4(m));
            result._matrixarray[3] = p5(m).add(p1(m)).minus(p3(m)).minus(p7(m));
        }
        return result;
    }
    public matrix add(matrix m) {
        matrix result = null;
        if (n == 1) {
            result = new matrix(1);
            result.set(0, 0, (element + m.element));
        } else {
            result = new matrix(n, false);
            result._matrixarray[0] = this._matrixarray[0].add(m._matrixarray[0]);
            result._matrixarray[1] = this._matrixarray[1].add(m._matrixarray[1]);
            result._matrixarray[2] = this._matrixarray[2].add(m._matrixarray[2]);
            result._matrixarray[3] = this._matrixarray[3].add(m._matrixarray[3]);;
        }
        return result;
    }
    public matrix minus(matrix m) {
        matrix result = null;
        if (n == 1) {
            result = new matrix(1);
            result.set(0, 0, (element - m.element));
        } else {
            result = new matrix(n, false);
            result._matrixarray[0] = this._matrixarray[0].minus(m._matrixarray[0]);
            result._matrixarray[1] = this._matrixarray[1].minus(m._matrixarray[1]);
            result._matrixarray[2] = this._matrixarray[2].minus(m._matrixarray[2]);
            result._matrixarray[3] = this._matrixarray[3].minus(m._matrixarray[3]);;
        }
        return result;
    }
    protected matrix p1(matrix m) {
        return _matrixarray[0].multi(m._matrixarray[1]).minus(_matrixarray[0].multi(m._matrixarray[3]));
    }
    protected matrix p2(matrix m) {
        return _matrixarray[0].multi(m._matrixarray[3]).add(_matrixarray[1].multi(m._matrixarray[3]));
    }
    protected matrix p3(matrix m) {
        return _matrixarray[2].multi(m._matrixarray[0]).add(_matrixarray[3].multi(m._matrixarray[0]));
    }
    protected matrix p4(matrix m) {
        return _matrixarray[3].multi(m._matrixarray[2]).minus(_matrixarray[3].multi(m._matrixarray[0]));
    }
    protected matrix p5(matrix m) {
        return (_matrixarray[0].add(_matrixarray[3])).multi(m._matrixarray[0].add(m._matrixarray[3]));
    }
    protected matrix p6(matrix m) {
        return (_matrixarray[1].minus(_matrixarray[3])).multi(m._matrixarray[2].add(m._matrixarray[3]));
    }
    protected matrix p7(matrix m) {
        return (_matrixarray[0].minus(_matrixarray[2])).multi(m._matrixarray[0].add(m._matrixarray[1]));
    }
    public int get(int i, int j) {
        if (n == 1) {
            return element;
        } else {
            int size = n / 2;
            return this._matrixarray[(i / size) * 2 + (j / size)].get(i % size, j % size);
        }
    }
    public void display() {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                system.out.print(get(i, j));
                system.out.print(" ");
            }
            system.out.println();
        }
    }
    
    public static void main(string[] args) {
        matrix m = new matrix(2);
        matrix n = new matrix(2);
        m.set(0, 0, 1);
        m.set(0, 1, 3);
        m.set(1, 0, 5);
        m.set(1, 1, 7);
        n.set(0, 0, 8);
        n.set(0, 1, 4);
        n.set(1, 0, 6);
        n.set(1, 1, 2);
        matrix res = m.multi(n);
        res.display();
    }
}

总结

到此这篇关于java实现矩阵乘法以及优化的文章就介绍到这了,更多相关java矩阵乘法及优化内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/GGG_Yu/article/details/109693318

标签:

相关文章

热门资讯

2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
2021德云社封箱演出完整版 2021年德云社封箱演出在线看
2021德云社封箱演出完整版 2021年德云社封箱演出在线看 2021-03-15
返回顶部