如何在pytorch中沿着轴执行乘法?

qc6wkl3g  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(141)

我有两个TensorX和Y -- X的形状为(20,4,300),Y的形状为(20,300).如何执行乘法,使我得到形状为(20,4)的结果.在keras中相应的技术是

doc_product = Dot(axes=(2,1))([X,Y])

我想知道如何在pytorch也能做到这一点?

6ojccjat

6ojccjat1#

最通用的矩阵乘法函数是torch.einsum:它允许您指定要沿着其相乘的维度以及输出Tensor的维度顺序。
在您的情况下,它看起来像:

dot_product = torch.einsum('bij,bj->bi')

相关问题