测试numpy数组是否只包含零

bwntbbo3  于 2023-04-06  发布在  其他
关注(0)|答案(8)|浏览(122)

我们用零初始化一个numpy数组如下:

np.zeros((N,N+1))

但是我们如何检查给定的n*n numpy数组矩阵中的所有元素是否为零。
如果所有的值都为零,该方法只需要返回True。

3xiyfsfu

3xiyfsfu1#

这里发布的其他答案也可以,但要使用的最清晰、最有效的函数是numpy.any()

>>> all_zeros = not np.any(a)

>>> all_zeros = not a.any()
  • 这比numpy.all(a==0)更好,因为它使用更少的RAM(它不需要由a==0项创建的临时数组)。
  • 而且,它比numpy.count_nonzero(a)更快,因为它可以在找到第一个非零元素时立即返回。
    ***编辑:**正如@Rachel在评论中指出的那样,np.any()不再使用“短路”逻辑,因此您不会看到小型阵列的速度优势。
qlzsbp2j

qlzsbp2j2#

查看numpy.count_nonzero。

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5
fjnneemd

fjnneemd3#

我会在这里使用np.all,如果你有一个数组a:

>>> np.all(a==0)
fafcakar

fafcakar4#

正如另一个答案所说,如果你知道0是数组中唯一可能的false元素,那么你可以利用truthy/falsy评估。数组中的所有元素都是false,如果其中没有任何true元素。*

>>> a = np.zeros(10)
>>> not np.any(a)
True

然而,答案声称any比其他选项更快,部分原因是短路。截至2018年,Numpy的allany不会短路
如果你经常做这类事情,那么很容易使用numba制作自己的短路版本:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

即使在不短路的情况下,这些版本往往比Numpy的版本更快。count_nonzero是最慢的。
用于检查性能的一些输入:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

检查:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
  • 有用的allany等价关系:
np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))
fhg3lkii

fhg3lkii5#

这会有用的。

def check(arr):
    if np.all(arr == 0):
        return True
    return False
tpgth1q7

tpgth1q76#

如果数组中的所有元素都大于或等于0,我认为使用sum是最快的方法。

test = np.ones((128, 128, 128))
%%timeit
not np.any(test)
>>> 1.46 ms ± 9.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
np.sum(test) == 0
>>> 646 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cbeh67ev

cbeh67ev7#

如果你想把1 e-15分类为零:

def all_zero( numpy_array ):
    return np.allclose( numpy_array, np.zeros_like(numpy_array) )
nx7onnlm

nx7onnlm8#

如果你正在测试所有的零,以避免在另一个numpy函数上出现警告,那么在try,except块中 Package 这一行将保存在你感兴趣的操作之前进行零测试的时间。

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0

相关问题