基于存储在另一个数组或列表中的索引拆分numpy多维数组

bksxznpy  于 2021-09-08  发布在  Java
关注(0)|答案(1)|浏览(212)

我有一个numpy多维数组,形状=(12,2,3,3)

import numpy as np
arr = np.arange(12*2*3*3).reshape((12,2,3,3))

我需要根据第二维度选择这些元素,其中Dindice存储在另一个列表中

indices = [0,1,0,0,1,1,0,1,1,0,1,1]

在一个数组中,其余的在另一个数组中。任何一种情况下的输出都应该是另一个形状数组(12,3,3)

arr2 = np.empty((arr.shape[0],*arr.shape[-2:]))

我可以用for循环来做

for i, ii in enumerate(indices):
    arr2[i] = arr[i, indices[ii],...]

然而,我正在寻找一个单一的班轮。
当我尝试使用列表作为索引编制索引时

test = arr[:,indices,...]

我明白了 test 形状为(12,12,3,3)而不是(12,3,3)。你能帮我吗?

lhcgjxsq

lhcgjxsq1#

你可以用 np.arange 对于索引第一个维度:

test = arr[np.arange(arr.shape[0]),indices,...]

或者只是python range 功能:

test = arr[range(arr.shape[0]),indices,...]

相关问题