NumPy如何在内部处理非连续切片的矩阵乘法?

tzdcorbm  于 4个月前  发布在  其他
关注(0)|答案(1)|浏览(33)

你好,Stack Overflow社区,
我正在使用NumPy进行矩阵运算,我有一个关于NumPy如何处理矩阵乘法的问题,特别是在处理矩阵的非连续切片时。
考虑一个场景,我们有一个大矩阵,比如大小为[1000,1000],我们需要在这个矩阵的切片版本上执行矩阵乘法,步骤如[::10,::10]。我知道NumPy可能使用优化的BLAS例程,如GEMM来进行矩阵乘法。然而,BLAS例程通常需要连续的内存布局才能有效运行。
我的问题是:NumPy如何在内部处理这种情况,其中乘法的输入矩阵由于切片而不连续?具体来说,我有兴趣了解NumPy是否:
1.自动将这些片重新分配到新的连续内存块,然后执行GEMM
1.有一种优化的方法来处理不连续的切片,而无需重新分配内存。
1.使用BLAS例程的任何特定变体或NumPy自己的实现来处理此类情况。
这些信息将帮助我更好地理解在NumPy中使用具有矩阵乘法步骤的切片的性能影响。
提前感谢您的见解!

cclgggtu

cclgggtu1#

np.matmul做了大量的工作,试图弄清楚什么时候它可以把工作传递给BLAS。实现它的主要源文件是numpy/_core/src/umath/matmul.c.src,具体来说,看看@TYPE@_matmul()is_blasable2d()
具体来说,关于is_blasable2d的注解检查了:
1.步幅不得混叠或重叠
1.较快(第二)轴必须是连续的
1.较慢的(第一个)轴步幅(单位步长)必须大于较快的轴尺寸
因此,由于第二个约束,即第二个轴不连续,您的示例将使用较慢的_noblas变体。
作为一个健全性检查,我们看看运行时是否与上面的观察一致:

import numpy as np

arr = np.zeros((1000, 1000))

%timeit arr[::2,::2] @ arr[::2,::2]     # takes ~300ms
%timeit arr[::2,:500] @ arr[::2,:500]   # takes ~  4ms
%timeit arr[:500,:500] @ arr[:500,:500] # takes ~  4ms

# as pointed out by hpaulj, the following takes ~  5ms
%timeit np.ascontiguousarray(arr[::2,::2]) @ np.ascontiguousarray(arr[::2,::2])

字符串
这似乎是正确的,第一个变体有一个不连续的第二个轴,速度要慢得多,大概是因为它没有使用BLAS。其他变体可能更快,因为它们被传递到BLAS。制作一个连续的副本需要一些时间,但最终的运行时间更快,所以在必要时这样做是值得的。

相关问题