numpy 如何让嵌套的for循环在python中执行得更快?

flseospp  于 5个月前  发布在  Python
关注(0)|答案(2)|浏览(73)

以下是我的脚本:

for a in range(-100, 101):
    for b in range(-100, 101):
        for c in range(-100, 101):
            for d in range(-100, 101):
                if abs(2**a*3**b*5**c*7**d-0.3048) <= 10**(-6):
                    print('a=',a, ', b=', b, ', c=', c,', d=', d,', the number=', 2**a*3**b*5**c*7**d, ', error=', abs(2**a*3**b*5**c*7**d-.3048))

字符串
在python中执行上面的脚本花了27分15秒。我知道它经历了201^4次表达式计算,但我需要更快地运行这些计算(因为我想尝试range(-200,201)等等)。
我想知道是否有可能使上面的代码执行得更快。我认为使用numpy数组会有所帮助,但不确定如何应用它,以及它是否真的有效。

1szpjjfi

1szpjjfi1#

对于这些类型的计算,您可以尝试numba JIT:

from numba import njit

@njit
def fn():
    for a in range(-100, 101):
        for b in range(-100, 101):
            for c in range(-100, 101):
                for d in range(-100, 101):
                    n = (2.0**a) * (3.0**b) * (5.0**c) * (7.0**d)
                    v = n - 0.3048
                    if abs(v) <= 1e-06:
                        print(
                            "a=",
                            a,
                            ", b=",
                            b,
                            ", c=",
                            c,
                            ", d=",
                            d,
                            ", the number=",
                            n,
                            ", error=",
                            abs(n - 3.048),
                        )

fn()

字符串
在我的机器(AMD 5700 X)上运行这段代码需要大约57秒(包括编译步骤)。相比之下,如果没有@njit(只是普通的Python),这只需要4分钟。

a= -78 , b= -89 , c= -14 , d= 89 , the number= 0.3047994427888104 , error= 2.7432005572111895
a= -78 , b= -57 , c= 50 , d= 18 , the number= 0.30479915330101043 , error= 2.7432008466989894
a= -69 , b= -85 , c= 87 , d= 0 , the number= 0.3047993420932106 , error= 2.7432006579067894
a= -63 , b= 42 , c= -99 , d= 80 , the number= 0.3048005478488736 , error= 2.7431994521511265
a= -63 , b= 74 , c= -35 , d= 9 , the number= 0.3048002583600241 , error= 2.743199741639976
a= -54 , b= 14 , c= -62 , d= 62 , the number= 0.3048007366419375 , error= 2.7431992633580626
a= -54 , b= 46 , c= 2 , d= -9 , the number= 0.30480044715290866 , error= 2.7431995528470914
a= -54 , b= 78 , c= 66 , d= -80 , the number= 0.3048001576641548 , error= 2.7431998423358452
a= -45 , b= -14 , c= -25 , d= 44 , the number= 0.30480092543511833 , error= 2.7431990745648815
a= -45 , b= 18 , c= 39 , d= -27 , the number= 0.3048006359459102 , error= 2.7431993640540897
a= -36 , b= -10 , c= 76 , d= -45 , the number= 0.30480082473902875 , error= 2.7431991752609712
a= 5 , b= -44 , c= -72 , d= 82 , the number= 0.30479914163960603 , error= 2.743200858360394
a= 14 , b= -72 , c= -35 , d= 64 , the number= 0.304799330431799 , error= 2.743200669568201
a= 14 , b= -40 , c= 29 , d= -7 , the number= 0.3047990409441057 , error= 2.743200959055894
a= 23 , b= -100 , c= 2 , d= 46 , the number= 0.30479951922410875 , error= 2.7432004807758914
a= 23 , b= -68 , c= 66 , d= -25 , the number= 0.30479922973623635 , error= 2.7432007702637637
a= 29 , b= 91 , c= -56 , d= -16 , the number= 0.30480014600271205 , error= 2.743199853997288
a= 38 , b= 31 , c= -83 , d= 37 , the number= 0.30480062428444915 , error= 2.743199375715551
a= 38 , b= 63 , c= -19 , d= -34 , the number= 0.30480033479552704 , error= 2.743199665204473
a= 47 , b= 3 , c= -46 , d= 19 , the number= 0.30480081307756046 , error= 2.7431991869224395
a= 47 , b= 35 , c= 18 , d= -52 , the number= 0.30480052358845894 , error= 2.743199476411541
a= 56 , b= 7 , c= 55 , d= -70 , the number= 0.3048007123815079 , error= 2.7431992876184923
a= 65 , b= -21 , c= 92 , d= -88 , the number= 0.3048009011746738 , error= 2.7431990988253263
a= 97 , b= -27 , c= -93 , d= 57 , the number= 0.3047990292827057 , error= 2.7432009707172944

real    0m57,939s
user    0m0,009s
sys     0m0,009s


看看你的代码,你可以使用parallel rangeprange)来进一步加快速度:

from numba import njit, prange

@njit(parallel=True)
def fn():
    for a in prange(-100, 101):
        i_a = 2.0**a
        for b in prange(-100, 101):
            i_b = i_a * 3.0**b
            for c in prange(-100, 101):
                i_c = i_b * 5.0**c
                for d in prange(-100, 101):
                    n = i_c * (7.0**d)
                    v = n - 0.3048
                    if abs(v) <= 1e-06:
                        print(
                            "a=",
                            a,
                            ", b=",
                            b,
                            ", c=",
                            c,
                            ", d=",
                            d,
                            ", the number=",
                            n,
                            ", error=",
                            abs(n - 3.048),
                        )

fn()


在我的8 C/16 T机器上只需~2.7秒。
@EDIT:添加了存储中间结果。谢谢@yotheguitou

puruo6ea

puruo6ea2#

几分钟后就开始了。
主要的速度改进只是预先计算所有的权力。我怀疑itertools实际上给了我任何东西。
你可能不是故意在一个位置使用.3048,在打印消息中使用3.048。我把两者都改成了.3048。也许你是指另一个。

import itertools
    aa = {i: 2 ** i for i in range(-100, 101)}
    bb = {i: 3 ** i for i in range(-100, 101)}
    cc = {i: 5 ** i for i in range(-100, 101)}
    dd = {i: 7 ** i for i in range(-100, 101)}

    for (a, avalue), (b, bvalue), (c, cvalue), (d, dvalue) in itertools.product(aa.items(), bb.items(), cc.items(), dd.items()):
        if abs(avalue * bvalue * cvalue * dvalue - .3048) <= 1e-6:
            value = avalue * bvalue * cvalue * dvalue
            print('a=',a, ', b=', b, ', c=', c,', d=', d,', the number=', value, ', error=', abs(value - .3048))

字符串

相关问题