如何扩展3d numpy数组的值?

avkwfej4  于 2023-04-06  发布在  其他
关注(0)|答案(2)|浏览(124)

假设我有一个3d数组(3x3x1),如下所示:

[[[149]
  [121]
  [189]]

 [[ 32]
  [225]
  [ 44]]

 [[ 33]
  [133]
  [ 11]]]

如何扩展所有值,使它们在最深的值(3x3x3)中相同,如下所示:

[[[149 149 149]
  [121 121 121]
  [189 189 189]]

 [[ 32  32  32]
  [225 225 225]
  [ 44  44  44]]

 [[ 33  33  33]
  [133 133 133]
  [ 11  11  11]]]

我已经试过了:

for i in range(len(array)):
  for j in range(len(array[i])):
    array[i][j] = np.array(list(array[i][j]) * 3)
print(array)

但它给了我一个错误:

could not broadcast input array from shape (3,) into shape (1,)

出于一般化的目的,我如何使用m x n x p形状格式实现这一点?

x6yk4ghg

x6yk4ghg1#

有多种选择。例如:np.repeatnp.tilenp.broadcast_to

import numpy as np

arr = np.array([
    [[149], [121], [189]],
    [[ 32], [225], [ 44]],
    [[ 33], [133], [ 11]]
])

out = np.repeat(arr, 3, axis=2)
# or
out = np.broadcast_to(arr, (3, 3, 3))
# or
out = np.tile(arr, (1, 1, 3))

选择最适合你的。
另外请注意,形状为(3, 3, 1)的数组可能会自动用作形状为(3, 3, 3)的数组,而无需由于broadcasting而手动重复。

cwxwcias

cwxwcias2#

也有np.c_[]的方法:

import numpy as np

arr = np.array([
    [[149], [121], [189]],
    [[ 32], [225], [ 44]],
    [[ 33], [133], [ 11]]
])

print(np.c_[arr, arr, arr])

相关问题