你好,Stack Overflow社区,
我正在使用NumPy进行矩阵运算,我有一个关于NumPy如何处理矩阵乘法的问题,特别是在处理矩阵的非连续切片时。
考虑一个场景,我们有一个大矩阵,比如大小为[1000,1000],我们需要在这个矩阵的切片版本上执行矩阵乘法,步骤如[::10,::10]。我知道NumPy可能使用优化的BLAS例程,如GEMM
来进行矩阵乘法。然而,BLAS例程通常需要连续的内存布局才能有效运行。
我的问题是:NumPy如何在内部处理这种情况,其中乘法的输入矩阵由于切片而不连续?具体来说,我有兴趣了解NumPy是否:
1.自动将这些片重新分配到新的连续内存块,然后执行GEMM
。
1.有一种优化的方法来处理不连续的切片,而无需重新分配内存。
1.使用BLAS例程的任何特定变体或NumPy自己的实现来处理此类情况。
这些信息将帮助我更好地理解在NumPy中使用具有矩阵乘法步骤的切片的性能影响。
提前感谢您的见解!
1条答案
按热度按时间cclgggtu1#
np.matmul
做了大量的工作,试图弄清楚什么时候它可以把工作传递给BLAS。实现它的主要源文件是numpy/_core/src/umath/matmul.c.src
,具体来说,看看@TYPE@_matmul()
和is_blasable2d()
。具体来说,关于
is_blasable2d
的注解检查了:1.步幅不得混叠或重叠
1.较快(第二)轴必须是连续的
1.较慢的(第一个)轴步幅(单位步长)必须大于较快的轴尺寸
因此,由于第二个约束,即第二个轴不连续,您的示例将使用较慢的
_noblas
变体。作为一个健全性检查,我们看看运行时是否与上面的观察一致:
字符串
这似乎是正确的,第一个变体有一个不连续的第二个轴,速度要慢得多,大概是因为它没有使用BLAS。其他变体可能更快,因为它们被传递到BLAS。制作一个连续的副本需要一些时间,但最终的运行时间更快,所以在必要时这样做是值得的。