pytorch FID和自定义特征提取器

wi3ka0sx  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(71)

我想使用自定义要素提取器来计算FID
根据https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html,我可以使用nn.Module作为feature
下面的代码有什么问题?

import torch
_ = torch.manual_seed(123)
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3

net = inception_v3()
checkpoint = torch.load('checkpoint.pt')
net.load_state_dict(checkpoint['state_dict'])
net.eval()

fid = FrechetInceptionDistance(feature=net)
# generate two slightly overlapping image intensity distributions
imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
fid.update(imgs_dist1, real=True)
fid.update(imgs_dist2, real=False)
result = fid.compute()

print(result)

个字符

nbysray5

nbysray51#

问题是你将输入转换为dtype=torch.uint8。模型需要一个浮点Tensor。

相关问题