pytorch 微调SSD Light torchvision

gc0ot86w  于 6个月前  发布在  其他
关注(0)|答案(2)|浏览(62)

我想在PyTorch中微调一个对象检测器。为此,我使用了这个教程:
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
然而,FastRCNN模型不适合我的用例,所以我对SSDLight进行了微调。我写了这段代码来设置一个新的分类头:

from functools import partial
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead
    
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)

in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer  = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
num_classes = 2
model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)

字符串
由于我的模型性能不好,我想问问社区,上面的代码是否正确?
先谢了。

j2qf4p5b

j2qf4p5b1#

如果您目标是创建一个带有自定义num_classes的模型,那么您可以:
1.在torchvision的初始化中设置新的自定义类。
1.显式加载默认的预训练模型。
1.匹配形状,并丢弃不同形状的权重。
1.将调整后的预训练权重加载到模型中,然后就可以进行重新训练了。
具体如下:

num_classes = 2
# Step 1.
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)
checkpoint = torch.load(default_pretrained_model_path) # in windows, you could check the model here C:\Users\user\.cache\torch\hub\checkpoints

# Step 2, load the model state_dict and the default model's state_dict
mstate_dict = model.state_dict()
cstate_dict = torch.load(args.weights)

# Step 3.
for k in mstate_dict.keys():
    if mstate_dict[k].shape != cstate_dict[k].shape:
        print('key {} will be removed, orishape: {}, training shape: {}'.format(k, cstate_dict[k].shape, mstate_dict[k].shape))
        cstate_dict.pop(k)
# Step 4.
model.load_state_dict(cstate_dict, strict=False)

字符串
希望有帮助,干杯~

eblbsuwk

eblbsuwk2#

这是我第一次做这种事情,但我得到了很好的结果:

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=num_classes, weights_backbone='DEFAULT', trainable_backbone_layers=0)

字符串
所以我只用现有的 Backbone.js ,从零开始,不要训练 Backbone 。与问题中的想法和Briliantn的answer相比,达到类似点所需的训练至少少10倍(可能是因为这两种方法都没有在开始时冻结 Backbone.js ).使用冻结的 Backbone.js ,你可以增加批量大小,这可以加快训练速度。一旦模型停止改进,我就解冻 Backbone.js ,并训练它更多。

相关问题