尝试在Pytorch中设置可学习的阈值

rjzwgtxy  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(63)

我尝试设置一个可以学习的参数。我希望这个参数可以是一个阈值,可以限制输入。如果输入小于它,则设置为-1e9。如果不小于,则保持其数字。在下面的代码中,我想使用这个参数作为MASK的阈值,然后找到相应的位置将该像素设置为-1e9。`class A(nn.Module):definit(self,):super().init()self.threshold = nn.Parameter(torch.tensor(0.01,requires_grad=True))

def forward(self, x, mask=None):
   
    x = torch.where(mask <= self.threshold,torch.tensor(-1e9).cuda(),x)
   
    return x`

字符串
但问题是没有grad!有什么办法让我达到我想做的事情吗?
我希望这个门槛可以随着培训过程的更新而更新。

eulz3vhy

eulz3vhy1#

import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(0.01, requires_grad=True))

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.ones_like(x)

        thresholded_values = torch.sigmoid(self.threshold * mask)
        x = torch.where(thresholded_values > 0.5, x, -1e9 * mask)

        return x

字符串

相关问题