pytorch scatter_函数学习笔记

x33g5p2x  于2021-11-10 转载在 其他  
字(2.0k)|赞(0)|评价(0)|浏览(236)

在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(src)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:

target.scatter(dim, index, src)

其中各变量及参数的说明如下:

  • target:即目标张量,将在该张量上进行映射
  • src:即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐行进行列填充
  • index: 按照轴方向,在target张量中需要填充的位置

dim 0:

把a按顺序(一行一行遍历)给b的索引(index)赋值,index是行编号

实际就是把a的行按照新的顺序赋值给b,顺序就是index行编号。

列子1:

import torch

a = (torch.arange(10) + 1).reshape(5, 2).float()
print(a)
print("-------------------------------------")
b = torch.zeros(5, 3)

b_ = b.scatter(dim=0, index=torch.LongTensor([[4, 2], [3, 0], [2, 0], [1, 0], [0, 0]]), src=a)

print(b_)

结果:

tensor([[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.],
        [ 9., 10.]])

tensor([[ 9., 10.,  0.],
        [ 7.,  0.,  0.],
        [ 5.,  2.,  0.],
        [ 3.,  0.,  0.],
        [ 1.,  0.,  0.]])

例子2:

import torch
a = (torch.arange(10)+1).reshape(2,5).float()
print(a)
print("-------------------------------------")
b = torch.zeros(3, 5)
b_= b.scatter(dim=0, index=torch.LongTensor([[0, 2]]),src=a)


# b 0行  a 1列 1行
# b 2行  a 2列 1行
print(b_)
print("-------------------------------------")

b_= b.scatter(dim=0, index=torch.LongTensor([[0, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)

# 7是因为两个 第2行 第4列 值发生覆盖了。
# 第1行 第1列
# 第3行 第2列
# 第2行 第3列
# 第2行 第4列
# 第3行 第5列

# 第3行 第1列
# 第1行 第2列
# 第3行 第3列
# 第2行 第4列
# 第1行 第5列

print(b_)

结果:

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])

tensor([[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.]])

tensor([[ 1.,  7.,  0.,  0., 10.],
        [ 0.,  0.,  3.,  9.,  0.],
        [ 6.,  2.,  8.,  0.,  5.]])

dim1:

把a按顺序(一行一行遍历)给b的索引(index)赋值,index是列编号

实际就是把a的行重新设置,赋值到b的行上,新的位置,就是index索引(列编号位置)。

import torch
a = (torch.arange(10)+1).reshape(2,5).float()
print(a)
print("-------------------------------------")
b = torch.zeros(3, 5)
b_= b.scatter(dim=1, index=torch.LongTensor([[0, 2]]),src=a)

#b 0列 第1行, a 0列 第1行
#b 2列 第1行 ,a 1列 第1行
print(b_)
print("-------------------------------------")

b_= b.scatter(dim=1, index=torch.LongTensor([[0, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)

#把a的第1行按顺序 放在b的第1行上,顺序是index

#4的来源:
0, 2, 1, [1], 2

#把a的第2行按顺序 放在b的第2行上,顺序是index

print(b_)

结果:

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])

tensor([[1., 0., 2., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

tensor([[ 1.,  5.,  2.,  0.,  0.],
        [10.,  9.,  8.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

相关文章