我正在使用PyTorch Lightning,我定义了我的模型如下:
class MyModel(MyBaseClass):
def __init__(self, ..., **kwargs):
super().__init__(**kwargs)
self.model_parameter = nn.Parameter(
torch.rand(...)
)
字符串
我使用一个自定义的损失函数,如下所示:
class MyCustomLoss(pl.LightningModule):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, outputs, targets):
loss = ...
scalar_loss = torch.mean(loss)
return scalar_loss
型
在我的配置文件中,我像下面这样设置class_path:
model:
class_path: ...path_to_MyModel
init_args:
criterion:
class_path: ...path_to_MyCustomLoss
型
但是,我需要一种方法来访问我的自定义损失函数中的model_parameter
。我需要这些参数来计算我的损失。我如何在自定义损失函数中使用我的模型参数?
1条答案
按热度按时间nzk0hqpo1#
您需要将模型示例传递给您的自定义损失函数。一旦您的自定义损失函数可以访问模型示例,您就可以拉取
model_parameters
。它是如何工作的:字符串
当你在训练步骤中调用损失函数时,传递模型示例:
型