在导入pytorch的过程中,是否有任何配置可以强制初始化NumPy模块?

3j86kqsm  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(44)

我有一些东西,

python3.11/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: module compiled against API version 0x10 but this version of numpy is 0xf (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)

字符串
但我肯定需要numpy,因此我需要API版本匹配,所以我不希望这只是一个警告,我希望它是一个错误。我不想手动检测NumPy版本,因为pytorch和numpy版本可能会有所不同。不同的numpy可以共享相同的API版本。它应该只在API不匹配时出错。
如果当前安装的numpy太旧,我最好也在pip级别失败或安装新的numpy。有没有办法做到这一点,而不是在pip命令行手动硬编码numpy==xxx?

update似乎我可以尽早使用warnings.filterwarnings("error", category=UserWarning, message=warning_pattern)(在任何pytorch导入之前)将警告转换为错误。

然而,看起来还没有办法在PIP期间将numpy更新到兼容版本。

9udxz4iz

9udxz4iz1#

可以像这样显式地检查(任何模块的)版本:

import numpy as np

v = np.__version__
if v != "1.26.1":
    print('wrong version')
else:
    print('correct version:', v)

字符串
并返回以下内容:

correct version: 1.26.1


或者你可以像这样显式地抛出一个错误

v = np.__version__
if v == "1.26.2":
    print('correct numpy version:', v)
else:
    raise Exception(f'wrong numpy version: {v}')


上述的替代方案是在独立的python包(虚拟环境)中的requirements.txt文件。
最后,您可以像这样调用特定版本的安装:

import subprocess
import sys

subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'numpy==1.26.1'])

相关问题