pytorch Assert错误:如果capturable=False,则state_steps不应为CUDATensor

qv7cva1a  于 2022-12-18  发布在  其他
关注(0)|答案(3)|浏览(2029)

我在Google colab上加载上一个时期的模型权重时遇到此错误。我使用的是PyTorch版本1.12.0。我无法降级到更低版本,因为我使用的外部库需要Pytorch 1.12.0
谢谢!

hs1ihplo

hs1ihplo1#

它似乎与AdamAdamW优化器新引入的参数(可捕获)有关。
1.加载检查点optim.param_groups[0]['capturable'] = True后强制capturable = True。这似乎会使模型训练速度降低约10%(YMMV取决于设置)。
1.将PyTorch恢复到以前的版本(可能是1.11.0)。
图片来源:https://github.com/pytorch/pytorch/issues/80809#issuecomment-1173481031

qyzbxkaa

qyzbxkaa2#

如果您使用的是PyTorch 1.12.0和Cuda二进制文件11.6/11.7,请在shell或命令提示符下粘贴以下内容:

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

在更新的torch版本中删除了Adam Optimizer回归

编辑有一个新的Torch版本使用此安装

pip install torch==1.13.0+cu117 torchvision torchaudio torchtext --extra-index-url https://download.pytorch.org/whl/cu117
hk8txs48

hk8txs483#

你能告诉我你使用的是哪个优化器吗?我在使用AdamW优化器时遇到过这种情况。你可以通过使用load_state_dict加载优化器,然后使用.cpu()函数显式地将其Map到cpu来避免这种情况。

相关问题