pytorch保存模型像keras

bfrts1fy  于 8个月前  发布在  其他
关注(0)|答案(1)|浏览(89)

我想保存我在pytorch中训练的模型。我想保存训练后的模型和重量。
在喀拉斯,

from tensorflow import keras

model_json = model.to_json()
with open("bpnn_model.json", "w") as json_file:
    json_file.write(model_json)

# saving weight
model.save_weights("bpnn_model_weights.h5")

字符串
以节省重量和模型形状。和加载模型我使用

from tensorflow import keras

with open('bpnn_model.json', 'r') as file:
    model_json = file.read()

trained_model = keras.models.model_from_json(model_json)
trained_model.load_weights('bpnn_model_weights.h5')


这样我就可以轻松地重新构建模型,而无需在新代码中使用saving_model_code
在pytorch中,我尝试torch.save()/model.state_dict()来保存模型。但我得到的是Can't get attribute 'net' on <module '__main__' from D:\\...,我问chat-gpt,他告诉我需要在加载代码之前构建模型。或者需要从构建模型代码(如from your_model_file import Net)导入网络
我想要的是当我打开一个新的代码时,我只需要加载模型和权重,也许设置激活和优化器。我可以重新训练或拟合模型,不需要关心模型的形状。
注:我的英语不是很好,如果我写错了什么或很难理解,我道歉aboot。

weylhg0b

weylhg0b1#

你可以试试torch.jit

import torch
from YourModel import yourModel

model = yourModel()
...
#training
...
#save model
model.eval()
jit_model = torch.jit.trace(model, torch.randn(N, C, H, W))
torch.jit.save(jit_model, 'your_path.pth')

字符串
保存好JIT模型后,可以像这样加载:

import torch
...
#load model
model = torch.jit.load("./your_path.pth")

相关问题