视觉TRANSFORMERS(ViT)
论文: AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
Github:
GitHub - google-research/vision_transformer
GitHub - lucidrains/vit-pytorch: Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
论文提出将 Transformer网络结构应用于视觉算法,即Vision Transformer (ViT)。最终取得了ImageNet 88:55%准确性, ImageNet-ReaL90:72%准确性, CIFAR-10094:55%准确性,
VTAB77:63%准确性。
网络结构:
标准的transformer的输入是1维的token embedding。为了处理二维图像,我们将尺寸为 H*W*C 的图像reshape为拉平的2维图块,尺寸为N *(P2*C) 。其中(P,P) 为图块的大小, N=H*W/P2 是图块的数量,会影响输入序列的长度。Transformer在所有图层上使用恒定的隐矢量D,比如D=1024,将patch块拉平,并使用可训练的线性投影映射到D的大小,此时输出大小为1*N*D的向量,随机初始化分类token,维度为1*1*D,将此1*N*D的向量和1*1*D的分类token向量进行concat操作,得到1*(N+1)*D的输出向量,
将此输出称为patch embedding。随机初始化大小为1*(N+1)*D的位置编码向量。而后,patch embedding 和position embedding相加,得到最终的 embedding向量。最终的 embedding向量经过Tranformer 编码器处理,输出的向量在(N+1)的维度上进行mean pooling操作,然后输入MLP HEAD进行线性投影得到最终的分类结果。
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
Tranformer编码器由multi-head self-attention(MSA)和MLP块的层组成。在每个块之前应用Layernorm(LN),在每个块之后应用残差连接。MLP包含具有GELU非线性的两全连接层。
缺点:
图片进行分patch操作后,只有在每个patch内部有信息交互,而patch和patch之间,只有在最后的MLP HEAD层才有交互,patch之间的信息交互太少。导致最终的translation equivariance(平移等变性)和locality(局部感知性)较弱。只有在大规模的数据集下才能有比较好的效果。比如在 ImageNet-21k,JFT-300M这样的数据量下才可以超越resnet。而在比较小的数据量下,效果却没有resnet好。
还没有评论,来说两句吧...