numpy 如何按特定键值对Tensor进行“批量”排序?

axkjgtzd  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(57)

我需要通过第一列的键值对一批二维矩阵的行进行排序:
原始批处理矩阵(3DTensor):

torch.tensor([[[2, 0], 
               [0, 1],
               [1, 2]],

              [[1, 2], 
               [0, 0], 
               [2, 1]]])

字符串
所需Tensor:

torch.tensor([[[0, 1],
               [1, 2],
               [2, 0]],

              [[0, 0],
               [1, 2],  
               [2, 1]]])


已经知道how to handle one of the batch,和another answer通过for循环解决问题,这不是并行的。那么如何处理整个批处理?

rvpgvaaj

rvpgvaaj1#

这可能有点令人困惑,但很有意义:

(my_tensor[:,torch.argsort(my_tensor[:,:,0], dim=1)])\
[torch.arange(len(my_tensor)),torch.arange(len(my_tensor))]

字符串
在第一行中,我们提取了torch.argsort的排序Tensor,并将其应用于my_tensor,得到了(2, 2, 3, 2)的形状Tensor。由于我们希望每个元素只根据其第一列进行排序,所以我们只对前两个维度的对角线感兴趣,可以通过切片来提取它(第二行代码)。

相关问题