我在看一个关于PyTorch的教程和编码,在函数torch.randint
上卡住了。根据文档:
torch.randint(低=0,高,大小,*,生成器=无,输出=无,数据类型=无,布局=torch.strided,设备=无,要求_grad=假)→Tensor
这里,size
是:
size(tuple)-定义输出Tensor形状的元组。
YouTuber写道
random_idx = torch.randint(0, len(train_data), size=[1]).item()
但是[1]
不是一个元组,它是一个列表。这怎么可能呢?我也用元组测试了它,它工作得很好,我在互联网上找到的每一个randint()
的用法都为size
提供了一个元组。例如,size = (1,2)
或size = (1,1)
。
我搜索了torch.randint
的源代码,但没有找到。我搜索了GitHub、PyTorch文档,甚至试图在本地PyTorch库中找到它。
2条答案
按热度按时间ukqbszuj1#
文档指出这应该是
tuple
,但实际上,randint()
的定义是:其中,
_size
定义为:因此,在实践中,要求
size
参数的类型为Size
、List
(int类型)或Tuple
(int类型),在本例中,它们的行为基本相同**EDIT:**如上所述,实际上,在Python中,类型只是一种指示,所以如果你使用任何类型的变量,只要函数本身不引发错误,就不会有任何问题。至于为什么函数会相应地运行并返回预期的结果,这是因为答案的第一部分:)
n3ipq98p2#
列表和元组之间唯一的区别是元组一旦创建就不能修改。如果你希望数组是不可变的,你应该把它改为元组yes,否则,列表和元组之间的所有读操作都是相同的,所以你的代码会很好地工作:)。