pytorch 如何在自定义损失函数中包含模型参数

9bfwbjaz  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(76)

我正在使用PyTorch Lightning,我定义了我的模型如下:

class MyModel(MyBaseClass):

    def __init__(self, ..., **kwargs):
        super().__init__(**kwargs)

        self.model_parameter = nn.Parameter(
            torch.rand(...) 
        )

字符串
我使用一个自定义的损失函数,如下所示:

class MyCustomLoss(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, outputs, targets):
        loss = ...
        scalar_loss = torch.mean(loss)
        return scalar_loss


在我的配置文件中,我像下面这样设置class_path:

model:
    class_path: ...path_to_MyModel
    init_args:
    criterion:
        class_path: ...path_to_MyCustomLoss


但是,我需要一种方法来访问我的自定义损失函数中的model_parameter。我需要这些参数来计算我的损失。我如何在自定义损失函数中使用我的模型参数?

nzk0hqpo

nzk0hqpo1#

您需要将模型示例传递给您的自定义损失函数。一旦您的自定义损失函数可以访问模型示例,您就可以拉取model_parameters。它是如何工作的:

class MyCustomLoss(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, outputs, targets, model):
        # access model parameters
        model_parameter = model.model_parameter
        # proceed to calculate loss...

字符串
当你在训练步骤中调用损失函数时,传递模型示例:

class MyModel(MyBaseClass):

    def __init__(self, ..., **kwargs):
        super().__init__(**kwargs)
        self.model_parameter = nn.Parameter(torch.rand(...))
        self.criterion = MyCustomLoss(...)  # Your custom loss function

    def forward(self, x):
        # define the forward pass
        ...

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        # compute loss
        loss = self.criterion(outputs, targets, self)
        return loss

相关问题