Residual Attention Network for Image Classification

喜欢ヅ旅行 2022-09-16 03:48 95阅读 0赞

Residual Attention Network for Image Classification

文章目录

      • Residual Attention Network for Image Classification
        • 参考
        • 个人理解
        • Attention Module
          • Soft Mask Branch
          • Spatial Attention and Channel Attention
        • 感觉宣兵夺主的原因
        • 总结

参考

  • ResidualAttentionNetwork-pytorch
  • 原文:Residual Attention Network for Image Classification
  • [论文笔记] Residual Attention Network

个人理解

  • 这篇,怎么说呢,效果很好但是提出的这个Mask Branch部分多少有点宣兵夺主的感觉(指层数,这个会在文章末尾说明)。
  • 基本思想就是利用Mask Branch来将背景过滤掉,然后使模型更加专注于图片中的主要信息。
  • 主要贡献呢,1是上面的这个思想,2是采用ResidualBlock避免直接堆叠带来的性能退化。

Attention Module

  • 大致结构示意图如下:

    • image-20211014200404850
    • 上面的部分是Trunk Branch,即主干部分,就是几个ResidualBlock的堆叠,作用是提取Feature Map。
    • 下面的部分是Mask Branch,即Mask分支,其结构下文会说(Soft Mask Branch部分),作用是遮挡住无效的背景信息。
    • 文中给出{p = 1, t = 2, r = 1},代表前面有p个ResidualBlock,以此类推,后面一节也会用到这个
Soft Mask Branch
  • image-20211014201212296
  • 结构长这样,就是一堆上下采样加残差单元的组合,实际上的实现也差不多,然后文中给出了r=1,即代表121的组合。
  • 需要注意的是中间这段加了个大号残差,有点像CSPDarknet的那个味道(话说回来,这玩意是在这篇后面提出来的~)
  • 值得关注的是这里有多次上/下采样(表现为maxpooling/interpolation),这种操作加上多层conv的叠加会使得感受野增大(conv的叠加增加感受野在VGG这篇里面曾经提到过,记得是3个3x3的感受野=1个7x7,而且又可以减少参数量啥的)

    • image-20211014202038897.png
Spatial Attention and Channel Attention
  • 这个部分略微有点诡异,我试着讲讲个人理解,首先这段讲的应该是下图中的这个函数选什么的问题(Mask Branch尾部)

    • image-20211014202552457
  • f1f2f3分别代表着混合注意力通道注意力空间注意力

    • image-20211014202333191.png
    • f1是sigmoid,f2是L2正则,f3是不知道啥东西,文中说是归一化+sigmoid(exp里面是归一化,类似BN,然后整体就是个sigmoid这样)
    • 然后我的理解是,f1仅仅将每个像素点用sigmoid映射到[0,1],f2就是指考虑考虑通道的L2正则,即可以看成将某个像素点在channel维度上的比重算出来,f3同理,只不过算的是HxW维度的这样
    • 然后效果是第一个最好~
    • image-20211014203153817

感觉宣兵夺主的原因

  • 这里我先给两张图,第一张是Attention-56/92的结构图,第二张是在imageNet上的效果图

    • image-20211014203259607.png
    • image-20211014203407068.png
  • 首先Attention-56的命名是怎么来的?56是Trunk层数,看起来很合理是吧?因为Trunk,主干嘛,而且56层,也不多。然后看看第二张图,和Attention-56对比的是ResNet-152,为什么呢?因为的性能强劲,56就足以和152匹敌?其实不是的~
  • 我找了份代码实现**ResidualAttentionNetwork-pytorch**,研究了下里面的Attention结构是咋回事,取的是这部分代码

    • class AttentionModule_pre(nn.Module):

      1. def __init__(self, in_channels, out_channels, size1, size2, size3):
      2. super(AttentionModule_pre, self).__init__()
      3. self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
      4. self.trunk_branches = nn.Sequential(
      5. ResidualBlock(in_channels, out_channels),
      6. ResidualBlock(in_channels, out_channels)
      7. )
      8. self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
      9. self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
      10. self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)
      11. self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
      12. self.softmax2_blocks = ResidualBlock(in_channels, out_channels)
      13. self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)
      14. self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
      15. self.softmax3_blocks = nn.Sequential(
      16. ResidualBlock(in_channels, out_channels),
      17. ResidualBlock(in_channels, out_channels)
      18. )
      19. self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
      20. self.softmax4_blocks = ResidualBlock(in_channels, out_channels)
      21. self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
      22. self.softmax5_blocks = ResidualBlock(in_channels, out_channels)
      23. self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
      24. self.softmax6_blocks = nn.Sequential(
      25. nn.BatchNorm2d(out_channels),
      26. nn.ReLU(inplace=True),
      27. nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
      28. nn.BatchNorm2d(out_channels),
      29. nn.ReLU(inplace=True),
      30. nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
      31. nn.Sigmoid()
      32. )
      33. self.last_blocks = ResidualBlock(in_channels, out_channels)
      34. def forward(self, x):
      35. x = self.first_residual_blocks(x)
      36. out_trunk = self.trunk_branches(x)
      37. out_mpool1 = self.mpool1(x)
      38. out_softmax1 = self.softmax1_blocks(out_mpool1)
      39. out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
      40. out_mpool2 = self.mpool2(out_softmax1)
      41. out_softmax2 = self.softmax2_blocks(out_mpool2)
      42. out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
      43. out_mpool3 = self.mpool3(out_softmax2)
      44. out_softmax3 = self.softmax3_blocks(out_mpool3)
      45. #
      46. out_interp3 = self.interpolation3(out_softmax3)
      47. # print(out_skip2_connection.data)
      48. # print(out_interp3.data)
      49. out = out_interp3 + out_skip2_connection
      50. out_softmax4 = self.softmax4_blocks(out)
      51. out_interp2 = self.interpolation2(out_softmax4)
      52. out = out_interp2 + out_skip1_connection
      53. out_softmax5 = self.softmax5_blocks(out)
      54. out_interp1 = self.interpolation1(out_softmax5)
      55. out_softmax6 = self.softmax6_blocks(out_interp1)
      56. out = (1 + out_softmax6) * out_trunk
      57. out_last = self.last_blocks(out)
      58. return out_last
  • 光看代码可能看不出来,我根据forward画了图,如下所示:

    • 可以看出,Trunk Branch仅仅2个ResidualBlock就完事了(左边),而Mask Branch就不是这么简单了,足足6个(按照文中的描述,也应该有4个这样子)
    • image-20211014205444874
    • 然后再回到第一张表格,Attention-56是怎么算的呢?

      • 首先有1+1+1+3=6个Residual unit,即6个ResidualBlock =>18层conv
      • 然后Attention Module部分的Trunk是1+2+1=4个ResidualBlock,一共1+1+1=3个Attention Module,故有12个ResidualBlock=>36层conv
      • 然后加上刚开始的conv和最后的fc,刚好就是56层~
    • 那么实际上应该是多少层呢?

      • 我们不按上面的图算,就按照文中的比例算,即用这个{p = 1, t = 2, r = 1}算,那么可以得到Soft Mask Branch部分一共是4个ResidualBlock+2个conv=14层conv,然后加上左右两个p=1个的ResidualBlock的话就是14+2*3=20层,远大于Trunk Branch的12层吼
      • Attention-56实际上最深应该是20*3+18+2 = 80层,当然远小于152,效果还是有的,只是觉得有些宣兵夺主

总结

  • 不管Mask 的设计有多么诡异,这都不妨碍这是一片优秀的论文,为CNN的注意力机制作出了巨大的贡献

发表评论

表情:
评论列表 (有 0 条评论,95人围观)

还没有评论,来说两句吧...

相关阅读