保存Keras模型中的自定义字段

mum43rcc  于 7个月前  发布在  其他
关注(0)|答案(2)|浏览(68)

考虑一下我有一个经过训练的Keras序列模型的情况。我保存模型,

keras.saving.save_model(model, path, save_format="...")

字符串
然而,在保存之前,我在模型中设置了一个自定义的list[str]属性:

setattr(model, "custom_attr", ["one", "two", "three"])


最后,当我使用keras.saving.load_model重新加载模型对象(来自另一个项目)时,我希望通过model.custom_attr使用我的自定义属性。然而,这不起作用,因为重新加载模型后custom_attr不再存在。

有没有办法做到这一点

我查了一下,似乎可以在重新加载模型时指定custom_objects参数,但该方法似乎仅限于自定义模型类中定义的自定义层或自定义损失函数。我的设置完全不同,因为我有一个普通的Sequential模型。

34gzjxbg

34gzjxbg1#

我想一个解决方案是用pickle模块单独保存您的自定义属性:

import pickle

# Your custom attribute
custom_attr = ["one", "two", "three"]

# Save the custom attribute to a file
with open("custom_attr.pkl", "wb") as file:
    pickle.dump(custom_attr, file)

字符串
在重新加载模型之后,也加载这个pickle并设置属性setattr(model, custom_attr)

mpbci0fu

mpbci0fu2#

我通过子类化Sequential类并向构造函数添加一个参数来解决这个问题:

class SequentialWithCustomAttr(keras.Sequential):
    def __init__(self, custom_attr=[], layers=None, *args, **kwargs):
        super().__init__(layers=layers, trainable=kwargs["trainable"], name=kwargs["name"])
        self.custom_attr = custom_attr

    def get_config(self):
        config = super().get_config()
        config.update({"custom_attr": self.custom_attr})

        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        custom_attr = config.pop("custom_attr", None)

        # Deserialize layers one by one by using their configs.
        layers_confs = config.pop("layers", None)
        layers = list(map(keras.saving.deserialize_keras_object, layers_confs))

        # Create an instance of class SequentialWithCustomAttr.
        model = cls(custom_attr=custom_attr, layers=layers, **config)

        return model

字符串

相关问题