如何获得一个包含其他Tensor大小(或形状)的pytorchTensor而不转换为Python int?

amrnrhlw  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(54)

在将pytorch代码导出到ONNX的上下文中,我得到了以下警告:
TracerWarning:torch.tensor结果在跟踪中被注册为常量。如果您使用此函数从常量变量创建Tensor,并且每次调用此函数时都是相同的,则可以安全地忽略此警告。在任何其他情况下,这可能会导致跟踪不正确。
以下是令人反感的一行:

text_lengths = torch.tensor([text_inputs.shape[1]]).long().to(text_inputs.device)

字符串
text_inputs是形状为torch.Size([1, 81])torch.Tensor
这个警告是正确的,不能忽略,因为text_inputs的形状应该是动态的。
我需要text_lengths是一个torch.Tensor,它包含来自text_inputsshape的数字81。上面的“冒犯行”成功地做到了这一点,但我们实际上从pytorch到Python int再回到pytorch的往返,因为torch.Size对象中的元素是Python int s。这(1)有点奇怪,(2)可能在GPU -> CPU -> GPU方面效率低下,如上所述,这是ONNX导出上下文中的实际问题。
有没有其他方法可以让我在不“离开”torch世界的情况下,在torch计算中使用Tensor的形状?

vm0i2vca

vm0i2vca1#

使用torch.tensor将导致导出的图中有一个常量值。在跟踪时,torch.tensor.shapetorch.tensor.size的输出应该返回Tensor(而不是python整数),因此上面的代码应该按预期导出,只需使用

text_lengths = text_inputs.shape[1]

字符串
返回正确值的LongTensor
由于构造函数中的括号,代码似乎没有按照预期的方式运行。[text_inputs.shape[1]]中的形状在python列表中被解释为整数,因此在跟踪过程中保存为常量。
在创建Tensor时删除额外的括号将正确导出为在运行时由text_inputs形状定义的标量。

text_lengths = torch.LongTensor(text_inputs.size(1))

相关问题