我试图理解这段代码抛出错误的原因。
x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
(x.add_(y)).size()
RuntimeError: output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]
字符串
参考网址:https://pytorch.org/docs/stable/notes/broadcasting.html#in-place-semantics
根据广播规则,因为每个维度都有一个1,所以上面的Tensor可以加到第二个维度上,从而给予一个(3,3,7)Tensor。
Numpy似乎正确地执行了生成(3,3,7)Tensor的操作。
import numpy as np
x=np.empty((1,3,1))
y=np.empty((3,1,7))
(x + y).shape
Output: (3, 3, 7)
型
有人能解释一下,为什么Pytorch会产生错误,即使遵循了规则?
1条答案
按热度按时间clj7thdc1#
我意识到了上面代码的问题。所以,函数add_()执行x = x+y,这导致了问题,因为x和y形状之间不匹配。因为,Tensor被分配为固定的内存大小,执行上面的操作应该出错,除非x具有与结果Tensor相同的形状。