pytorch 为什么timm视觉Transformer位置嵌入初始化为零?

fv2wmkja  于 6个月前  发布在  其他
关注(0)|答案(2)|浏览(83)

我正在查看visual transformers的timm实现,对于位置嵌入,他用零初始化了他的位置嵌入,如下所示:

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

字符串
参见此处:https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L309
我不知道这实际上是如何嵌入任何有关的立场时,它后来被添加到补丁?

x = x + self.pos_embed


任何反馈都很感激。

hts6caw3

hts6caw31#

位置嵌入是一个包含在计算图中的参数,并在训练过程中进行更新。因此,如果您初始化为零并不重要;它们是在训练过程中学习的。

eyh26e7m

eyh26e7m2#

如果你现在检查仓库,它确实是使用从正态分布中采样的权重进行初始化的。使用零初始化权重,然后在init_weights函数中重新初始化是很常见的。这可能是你看到的文件版本的情况。
当使用可学习的位置嵌入时,它们确实应该随机初始化。这有助于网络学习令牌的位置信息。我以前没有见过零初始化,这充其量可能会影响模型收敛。

相关问题