Pytorch错误:RuntimeError:输出的形状[1,3,1]与广播形状[3,3,7]不匹配

zpjtge22  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(93)

我试图理解这段代码抛出错误的原因。

x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
(x.add_(y)).size()
RuntimeError: output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]

字符串
参考网址:https://pytorch.org/docs/stable/notes/broadcasting.html#in-place-semantics
根据广播规则,因为每个维度都有一个1,所以上面的Tensor可以加到第二个维度上,从而给予一个(3,3,7)Tensor。
Numpy似乎正确地执行了生成(3,3,7)Tensor的操作。

import numpy as np
x=np.empty((1,3,1))
y=np.empty((3,1,7))
(x + y).shape
Output: (3, 3, 7)


有人能解释一下,为什么Pytorch会产生错误,即使遵循了规则?

clj7thdc

clj7thdc1#

我意识到了上面代码的问题。所以,函数add_()执行x = x+y,这导致了问题,因为x和y形状之间不匹配。因为,Tensor被分配为固定的内存大小,执行上面的操作应该出错,除非x具有与结果Tensor相同的形状。

相关问题