将整数Tensor转换为仅在该索引处有1的二进制Tensor

enyaitl3  于 2021-08-20  发布在  Java
关注(0)|答案(1)|浏览(335)

此问题已在此处找到答案

将Tensor转换为索引的一个热编码Tensor(2个答案)
两天前关门了。
在pytorch中,将整数Tensor转换为仅在整数索引处有1的二元Tensor是一件容易的事吗?
例如

tensor([[1,3,2,6]])

将成为

tensor([[0,1,0,0,0,0,0],
        [0,0,0,1,0,0,0],
        [0,0,1,0,0,0,0],
        [0,0,0,0,0,0,1]])
cwxwcias

cwxwcias1#

t = tensor([[1,3,2,6]])

rows = t.shape[1]
cols = t.max() + 1

output = torch.zeros(rows, cols) # initializes zeros array of desired dimensions
output[list(range(rows)), t.tolist()] = 1 # sets cells to 1

为了澄清最后一个操作,您可以传入一个行号和列号列表,它会将所有这些元素设置为相等后的值。因此,在本例中,我们希望将以下位置设置为1:

(0,1), (1,3), (2,2), (3,6)

我们将其表示为:

output[[0,1,2,3], [1,3,2,6]] = 1

你会看到这些列表排列成a)一个递增的列表,直到行总数,和b)我们的原始Tensor

相关问题