我尝试设置一个可以学习的参数。我希望这个参数可以是一个阈值,可以限制输入。如果输入小于它,则设置为-1e9。如果不小于,则保持其数字。在下面的代码中,我想使用这个参数作为MASK的阈值,然后找到相应的位置将该像素设置为-1e9。`class A(nn.Module):definit(self,):super().init()self.threshold = nn.Parameter(torch.tensor(0.01,requires_grad=True))
def forward(self, x, mask=None):
x = torch.where(mask <= self.threshold,torch.tensor(-1e9).cuda(),x)
return x`
字符串
但问题是没有grad!有什么办法让我达到我想做的事情吗?
我希望这个门槛可以随着培训过程的更新而更新。
1条答案
按热度按时间eulz3vhy1#
字符串