pytorch 为什么宽矩阵的乘法比方阵慢?

vsnjm48y  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(114)

在尝试提高代码性能时,我注意到以下几点:

>>> a, b = torch.randn(1000,1000), torch.randn(1000,1000)
>>> c, d = torch.randn(10000, 100), torch.randn(100, 1000)
>>> e, f = torch.randn(100000, 10), torch.randn(10, 1000)
  
>>> %timeit torch.mm(a, b)
17 ms ± 303 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit torch.mm(c, d)
24.4 ms ± 575 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit torch.mm(e, f)
138 ms ± 590 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

字符串
从理论上讲,上面的每个矩阵运算需要10^9次乘法,但实际上有很大的差异!随着矩阵变得越来越矩形,性能开始下降。我认为缓存未命中是一个原因,但似乎这些乘法都是缓存友好的。为什么乘方阵更快?

kulphzqa

kulphzqa1#

我有点困惑,你说这两个操作都需要10^9次操作。O(n^3)只对方阵成立。
注意,我将忽略像斯特拉森这样的改进的矩阵算法,它大大降低了这个界限。我将用你在学校或大学里学到的简单实现来证明我的观点。
对于具有形状(MxN)和(NXM)的矩阵的这种朴素矩阵乘法,以下规则适用于运算次数:

  • 加法的次数为M^2 *(2N - 1)
  • 乘法的次数为M^2 * N

| M| N|添加[1 e9]|乘法[1 e9]|操作[1 e9]|
| --|--|--|--|--|
| 1000 | 1000 |0.1999| 0.1秒|0.2999|
| 10000 | 100 |零点一九|1.0版|2.99|
| 100000 | 10 | 19 |十点| 29 |
如前所述,此表忽略了更高效的算法以及缓存位置等计算障碍。但它显示了一般情况下方阵乘法需要更少的操作,这反映在您测量的时间上(然而,torch.mm已经非常优化,这是时间不直接与表中的操作数量成比例的原因)

相关问题