
心已赠人 2022-09-12 13:46


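  • Preface
  • Source code download
  • Improvements in YoloX (not exhaustive)
  • How YoloX works
    • I. Overall structure
    • II. Network structure
      • 1. The backbone: CSPDarknet
      • 2. Building the FPN feature pyramid for enhanced feature extraction
      • 3. Obtaining predictions with the Yolo Head
    • III. Decoding the predictions
      • 1. Obtaining prediction boxes and scores
      • 2. Score filtering and non-maximum suppression
    • IV. Training
      • 1. What is needed to compute the loss
      • 2. Necessary conditions for positive-sample feature points
      • 3. SimOTA dynamic matching of positive samples
      • 4. Computing the loss
  • Training your own YoloX model
    • I. Preparing the dataset
    • II. Processing the dataset
    • III. Training the network
    • IV. Predicting with the trained model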







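2. Classification/regression layer: Decoupled Head. Earlier versions of Yolo used a coupled head, in which classification and regression were performed together in a single 1x1 convolution; YoloX argues that this hurts detection. In YoloX, the Yolo Head is split into two branches that are computed separately and only combined at prediction time.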


4. Anchor Free: no prior (anchor) boxes are used.

5. SimOTA: positive samples are assigned dynamically to targets of different sizes.
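Being anchor free means each feature point on each of the three output layers predicts exactly one box (the head code sets `n_anchors = 1`). The small sketch below counts the resulting candidate predictions per image, assuming a square input and the strides 8, 16 and 32 used by the head; `count_predictions` is a hypothetical helper, not part of the YoloX code.

```python
def count_predictions(input_size, strides=(8, 16, 32), boxes_per_point=1):
    """Count candidate predictions for a square input (hypothetical helper).

    A stride-s layer has (input_size // s) ** 2 feature points, and an
    anchor-free head predicts `boxes_per_point` boxes (here 1) per point.
    """
    per_level = [(input_size // s) ** 2 * boxes_per_point for s in strides]
    return per_level, sum(per_level)


per_level, total = count_predictions(640)
print(per_level, total)
```

For a 640x640 input this gives 80x80 + 40x40 + 20x20 points, i.e. 8400 predictions, each of which is later filtered by score and NMS.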






As with earlier versions of Yolo, YoloX can still be divided into three parts: CSPDarknet, FPN, and the Yolo Head.



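The Yolo Head is the classifier and regressor of YoloX. CSPDarknet and FPN give us three enhanced effective feature layers. Each feature layer has a width, a height and a number of channels, so we can regard the feature map as a collection of feature points, each of which has as many features as the layer has channels. What the Yolo Head actually does is examine each feature point and decide whether an object corresponds to it. Earlier versions of Yolo used a coupled head, in which classification and regression were performed together in a single 1x1 convolution; YoloX argues that this hurts detection. In YoloX, the Yolo Head is split into two branches that are computed separately and only combined at prediction time.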

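In short, the whole YoloX network does three things: feature extraction, feature enhancement, and predicting the object (if any) that corresponds to each feature point.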
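Treating a feature map as a set of feature points is just a reshape: a (batch, channels, height, width) tensor becomes (batch, height*width, channels), one row per feature point. This is the same flatten-and-permute that `decode_outputs` applies to the head outputs. A minimal PyTorch sketch with an arbitrary 20x20, 1024-channel layer; the name `to_feature_points` is hypothetical.

```python
import torch


def to_feature_points(feature_map: torch.Tensor) -> torch.Tensor:
    """Reshape (B, C, H, W) into (B, H*W, C): one C-dim vector per feature point."""
    return feature_map.flatten(start_dim=2).permute(0, 2, 1)


feat = torch.randn(1, 1024, 20, 20)
points = to_feature_points(feat)
print(points.shape)

# Row y * W + x holds the features of the point in row y, column x
assert torch.equal(points[0, 3 * 20 + 5], feat[0, :, 3, 5])
```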




```python
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        Conv = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        if self.use_add:
            y = y + x
        return y
```



```python
class CSPLayer(nn.Module):
    def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden_channels = int(out_channels * expansion)  # hidden channels
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
        self.m = nn.Sequential(*module_list)

    def forward(self, x):
        x_1 = self.conv1(x)
        x_2 = self.conv2(x)
        x_1 = self.m(x_1)
        x = torch.cat((x_1, x_2), dim=1)
        return self.conv3(x)
```


```python
class Focus(nn.Module):
    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        patch_top_left = x[..., ::2, ::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
        return self.conv(x)
```
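The `Focus` layer takes every second pixel in four phase-shifted patterns and stacks them on the channel axis, so a (B, C, H, W) input becomes (B, 4C, H/2, W/2) before the convolution, without discarding any pixel. A small sketch of just the slicing step on a 4x4 image; `focus_slice` is a hypothetical name and the following `BaseConv` is omitted.

```python
import torch


def focus_slice(x: torch.Tensor) -> torch.Tensor:
    """Space-to-depth slicing used by Focus, without the following BaseConv."""
    patch_top_left = x[..., ::2, ::2]
    patch_bot_left = x[..., 1::2, ::2]
    patch_top_right = x[..., ::2, 1::2]
    patch_bot_right = x[..., 1::2, 1::2]
    return torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right), dim=1)


x = torch.arange(16.0).view(1, 1, 4, 4)  # pixel values 0..15, laid out row by row
y = focus_slice(x)
print(y.shape)    # 4 channels of 2x2
print(y[0, 0])    # the top-left pixel of every 2x2 cell: 0, 2, 8, 10
```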

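4. The SiLU activation function is used. SiLU can be seen as an improvement on Sigmoid and ReLU: it is unbounded above, bounded below, smooth, and non-monotonic. SiLU is reported to work better than ReLU in deep models, and it can be regarded as a smooth version of ReLU.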
$$f(x) = x \cdot \text{sigmoid}(x) = \frac{x}{1 + e^{-x}}$$

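```python
# Keras/TensorFlow version of SiLU (a PyTorch version appears in the full backbone code that follows)
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer


class SiLU(Layer):
    def __init__(self, **kwargs):
        super(SiLU, self).__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        # f(x) = x * sigmoid(x)
        return inputs * K.sigmoid(inputs)

    def get_config(self):
        config = super(SiLU, self).get_config()
        return config

    def compute_output_shape(self, input_shape):
        # Element-wise activation: the shape is unchanged
        return input_shape
```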
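The SiLU properties listed above can be checked numerically. This short standard-library sketch evaluates f(x) = x / (1 + e^(-x)) at a few points: it is 0 at 0, approaches x for large positive x, is bounded below, and is not monotonic (it decreases between roughly x = -3 and x = -1.3 before rising again).

```python
import math


def silu(x: float) -> float:
    """SiLU: x * sigmoid(x) == x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


for x in (-3.0, -1.5, -1.0, 0.0, 1.0, 5.0):
    print(f"silu({x:+.1f}) = {silu(x):+.4f}")

# Bounded below: on a grid over [-10, 10] the minimum is about -0.278, near x = -1.28
grid = [i / 100 for i in range(-1000, 1001)]
print(min(silu(x) for x in grid))
```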


```python
class SPPBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
        x = self.conv2(x)
        return x
```
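In `SPPBottleneck`, every max-pooling layer uses stride 1 and padding `ks // 2`, so odd kernel sizes leave the spatial size unchanged and the input plus its three pooled copies can be concatenated on the channel axis. A minimal shape check with plain `nn.MaxPool2d` layers; the 1x1 convolutions are left out and the 512-channel, 20x20 input is an arbitrary choice.

```python
import torch
from torch import nn

kernel_sizes = (5, 9, 13)
pools = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])

x = torch.randn(1, 512, 20, 20)  # stands in for the output of conv1
pooled = [m(x) for m in pools]
out = torch.cat([x] + pooled, dim=1)

print([tuple(p.shape[-2:]) for p in pooled])  # every pooled copy stays 20x20
print(out.shape)                              # 512 * (3 + 1) channels
```

This is why `conv2_channels = hidden_channels * (len(kernel_sizes) + 1)` in the class above.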


```python
import torch
from torch import nn


class SiLU(nn.Module):
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


def get_activation(name="silu", inplace=True):
    if name == "silu":
        module = SiLU()
    elif name == "relu":
        module = nn.ReLU(inplace=inplace)
    elif name == "lrelu":
        module = nn.LeakyReLU(0.1, inplace=inplace)
    else:
        raise AttributeError("Unsupported act type: {}".format(name))
    return module


class Focus(nn.Module):
    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        patch_top_left = x[..., ::2, ::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
        return self.conv(x)


class BaseConv(nn.Module):
    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
        super().__init__()
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class DWConv(nn.Module):
    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)
        self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)

    def forward(self, x):
        x = self.dconv(x)
        return self.pconv(x)


class SPPBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
        x = self.conv2(x)
        return x


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        Conv = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        if self.use_add:
            y = y + x
        return y


class CSPLayer(nn.Module):
    def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden_channels = int(out_channels * expansion)  # hidden channels
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
        self.m = nn.Sequential(*module_list)

    def forward(self, x):
        x_1 = self.conv1(x)
        x_2 = self.conv2(x)
        x_1 = self.m(x_1)
        x = torch.cat((x_1, x_2), dim=1)
        return self.conv3(x)


class CSPDarknet(nn.Module):
    def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
        super().__init__()
        assert out_features, "please provide output features of Darknet"
        self.out_features = out_features
        Conv = DWConv if depthwise else BaseConv
        base_channels = int(wid_mul * 64)  # 64
        base_depth = max(round(dep_mul * 3), 1)  # 3
        # stem
        self.stem = Focus(3, base_channels, ksize=3, act=act)
        # dark2
        self.dark2 = nn.Sequential(
            Conv(base_channels, base_channels * 2, 3, 2, act=act),
            CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act),
        )
        # dark3
        self.dark3 = nn.Sequential(
            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
            CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act),
        )
        # dark4
        self.dark4 = nn.Sequential(
            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
            CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act),
        )
        # dark5
        self.dark5 = nn.Sequential(
            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
            CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
        )

    def forward(self, x):
        outputs = {}
        x = self.stem(x)
        outputs["stem"] = x
        x = self.dark2(x)
        outputs["dark2"] = x
        x = self.dark3(x)
        outputs["dark3"] = x
        x = self.dark4(x)
        outputs["dark4"] = x
        x = self.dark5(x)
        outputs["dark5"] = x
        return {k: v for k, v in outputs.items() if k in self.out_features}
```
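The depth and width multipliers of `CSPDarknet` only change the number of Bottleneck blocks and the channel counts. The bookkeeping sketch below mirrors the arithmetic in `__init__` for a 640x640 input; `darknet_layout` is a hypothetical helper. With `dep_mul=1.0, wid_mul=1.0` it reproduces the (80,80,256), (40,40,512), (20,20,1024) layers used in the rest of this article, and with `dep_mul=0.33, wid_mul=0.50` (the YOLOX-s setting) the channels are halved and each stage has fewer Bottlenecks.

```python
def darknet_layout(dep_mul, wid_mul, input_size=640):
    """Mirror CSPDarknet's channel/depth arithmetic (hypothetical helper).

    Returns {stage: (H, W, C, n_bottlenecks)} for dark3, dark4 and dark5.
    """
    base_channels = int(wid_mul * 64)
    base_depth = max(round(dep_mul * 3), 1)
    # Bottleneck counts of the CSPLayer in each stage, as in __init__
    stage_depths = {"dark2": base_depth, "dark3": base_depth * 3,
                    "dark4": base_depth * 3, "dark5": base_depth}
    size = input_size // 2      # the Focus stem halves H and W
    channels = base_channels
    layout = {}
    for name, depth in stage_depths.items():
        size //= 2              # each stage starts with a stride-2 3x3 conv
        channels *= 2           # ... which also doubles the channels
        layout[name] = (size, size, channels, depth)
    return {k: layout[k] for k in ("dark3", "dark4", "dark5")}


print(darknet_layout(1.0, 1.0))
print(darknet_layout(0.33, 0.50))
```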




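  1. The feat3=(20,20,1024) feature layer goes through one 1x1 convolution that adjusts its channels, giving P5=(20,20,512). P5 is upsampled by a factor of 2 and concatenated with the feat2=(40,40,512) feature layer, and a CSPLayer then extracts features to give P5_upsample, whose shape is (40,40,512).
  2. P5_upsample=(40,40,512) goes through one 1x1 convolution that adjusts its channels, giving P4=(40,40,256). P4 is upsampled by a factor of 2 and concatenated with the feat1=(80,80,256) feature layer, and a CSPLayer then extracts features to give P3_out, whose shape is (80,80,256).
  3. P3_out=(80,80,256) is downsampled by one 3x3 convolution with stride 2, concatenated with P4, and passed through a CSPLayer to give P4_out, whose shape is (40,40,512).
  4. P4_out=(40,40,512) is downsampled by one 3x3 convolution with stride 2, concatenated with P5, and passed through a CSPLayer to give P5_out, whose shape is (20,20,1024).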
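The four steps can be traced with plain shape arithmetic. This sketch assumes width=1.0 and a 640x640 input (so feat1, feat2 and feat3 are (80,80,256), (40,40,512) and (20,20,1024)) and uses (H, W, C) tuples; all helper names are hypothetical stand-ins for the layers in `YOLOPAFPN`.

```python
def conv1x1(shape, out_c):        # 1x1 BaseConv: only the channels change
    h, w, _ = shape
    return (h, w, out_c)


def upsample(shape):              # nearest-neighbour x2 upsampling
    h, w, c = shape
    return (2 * h, 2 * w, c)


def down3x3(shape):               # 3x3 conv with stride 2, channels kept
    h, w, c = shape
    return (h // 2, w // 2, c)


def concat(a, b):                 # torch.cat on the channel axis
    assert a[:2] == b[:2], "spatial sizes must match"
    return (a[0], a[1], a[2] + b[2])


def csp(shape, out_c):            # CSPLayer: spatial size kept, channels -> out_c
    return (shape[0], shape[1], out_c)


feat1, feat2, feat3 = (80, 80, 256), (40, 40, 512), (20, 20, 1024)

P5 = conv1x1(feat3, 512)                              # step 1
P5_upsample = csp(concat(upsample(P5), feat2), 512)
P4 = conv1x1(P5_upsample, 256)                        # step 2
P3_out = csp(concat(upsample(P4), feat1), 256)
P4_out = csp(concat(down3x3(P3_out), P4), 512)        # step 3
P5_out = csp(concat(down3x3(P4_out), P5), 1024)       # step 4
print(P3_out, P4_out, P5_out)
```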


```python
import torch
from torch import nn


# Uses BaseConv, DWConv, CSPLayer and CSPDarknet from the backbone code above.
class YOLOPAFPN(nn.Module):
    def __init__(self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"), in_channels=[256, 512, 1024], depthwise=False, act="silu"):
        super().__init__()
        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
        self.in_features = in_features
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act)
        self.C3_p4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )
        self.reduce_conv1 = BaseConv(int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act)
        self.C3_p3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[0] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )
        self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act)
        self.C3_n3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )
        self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act)
        self.C3_n4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[2] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

    def forward(self, input):
        out_features = self.backbone.forward(input)
        features = [out_features[f] for f in self.in_features]
        # Shapes below are (H, W, C) for width=1.0 and a 640x640 input
        # feat1: 80,80,256   feat2: 40,40,512   feat3: 20,20,1024
        [feat1, feat2, feat3] = features
        # 20,20,1024 -> 20,20,512
        P5 = self.lateral_conv0(feat3)
        # 20,20,512 -> 40,40,512
        P5_upsample = self.upsample(P5)
        # 40,40,512 + 40,40,512 -> 40,40,1024
        P5_upsample = torch.cat([P5_upsample, feat2], 1)
        # 40,40,1024 -> 40,40,512
        P5_upsample = self.C3_p4(P5_upsample)
        # 40,40,512 -> 40,40,256
        P4 = self.reduce_conv1(P5_upsample)
        # 40,40,256 -> 80,80,256
        P4_upsample = self.upsample(P4)
        # 80,80,256 + 80,80,256 -> 80,80,512
        P4_upsample = torch.cat([P4_upsample, feat1], 1)
        # 80,80,512 -> 80,80,256
        P3_out = self.C3_p3(P4_upsample)
        # 80,80,256 -> 40,40,256
        P3_downsample = self.bu_conv2(P3_out)
        # 40,40,256 + 40,40,256 -> 40,40,512
        P3_downsample = torch.cat([P3_downsample, P4], 1)
        # 40,40,512 -> 40,40,512
        P4_out = self.C3_n3(P3_downsample)
        # 40,40,512 -> 20,20,512
        P4_downsample = self.bu_conv1(P4_out)
        # 20,20,512 + 20,20,512 -> 20,20,1024
        P4_downsample = torch.cat([P4_downsample, P5], 1)
        # 20,20,1024 -> 20,20,1024
        P5_out = self.C3_n4(P4_downsample)
        return (P3_out, P4_out, P5_out)
```

3. Obtaining predictions with the Yolo Head

The FPN feature pyramid gives us three enhanced feature layers with shapes (20,20,1024), (40,40,512) and (80,80,256). These three feature layers are then passed to the Yolo Head to obtain the predictions.

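The Yolo Head in YoloX differs from the Yolo Head of earlier versions. Earlier versions of Yolo used a coupled head, in which classification and regression were performed together in a single 1x1 convolution; YoloX argues that this hurts detection. In YoloX, the Yolo Head is split into two branches that are computed separately and only combined at prediction time.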


```python
import torch
from torch import nn


# Uses BaseConv and DWConv from the backbone code above.
class YOLOXHead(nn.Module):
    def __init__(self, num_classes, width=1.0, strides=[8, 16, 32], in_channels=[256, 512, 1024], act="silu", depthwise=False,):
        super().__init__()
        # Anchor free: one prediction per feature point
        self.n_anchors = 1
        self.num_classes = num_classes
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        Conv = DWConv if depthwise else BaseConv
        for i in range(len(in_channels)):
            # 1x1 conv that brings every level to 256 * width channels
            self.stems.append(BaseConv(in_channels=int(in_channels[i] * width), out_channels=int(256 * width), ksize=1, stride=1, act=act))
            # Classification branch: two 3x3 convs
            self.cls_convs.append(nn.Sequential(*[
                Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
                Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
            ]))
            # Regression branch: two 3x3 convs
            self.reg_convs.append(nn.Sequential(*[
                Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
                Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act)
            ]))
            # num_classes class scores per feature point
            self.cls_preds.append(
                nn.Conv2d(in_channels=int(256 * width), out_channels=self.n_anchors * self.num_classes, kernel_size=1, stride=1, padding=0)
            )
            # 4 box-regression values per feature point
            self.reg_preds.append(
                nn.Conv2d(in_channels=int(256 * width), out_channels=4, kernel_size=1, stride=1, padding=0)
            )
            # 1 objectness score per feature point
            self.obj_preds.append(
                nn.Conv2d(in_channels=int(256 * width), out_channels=self.n_anchors * 1, kernel_size=1, stride=1, padding=0)
            )

    def forward(self, inputs):
        outputs = []
        for k, x in enumerate(inputs):
            x = self.stems[k](x)
            cls_feat = self.cls_convs[k](x)
            cls_output = self.cls_preds[k](cls_feat)
            reg_feat = self.reg_convs[k](x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)
            # The two branches are merged only here: [reg(4), obj(1), cls(num_classes)]
            output = torch.cat([reg_output, obj_output, cls_output], 1)
            outputs.append(output)
        return outputs
```
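The decoupled idea can be seen in isolation with a stripped-down, single-level module: one branch produces class scores, the other produces the 4 box values and the objectness score, and the results are concatenated only at the end in the same [reg, obj, cls] order as `YOLOXHead`. This is a simplified sketch, not the article's head: it assumes 80 classes, uses plain `nn.Conv2d` + `nn.SiLU` instead of `BaseConv` (no BatchNorm), has one 3x3 conv per branch, and the class name `TinyDecoupledHead` is hypothetical.

```python
import torch
from torch import nn


class TinyDecoupledHead(nn.Module):
    """Single-level decoupled head sketch (hypothetical, simplified)."""

    def __init__(self, in_channels: int, num_classes: int = 80, hidden: int = 256):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(in_channels, hidden, 1), nn.SiLU())
        # Two separate branches instead of one shared 1x1 prediction conv
        self.cls_branch = nn.Sequential(nn.Conv2d(hidden, hidden, 3, padding=1), nn.SiLU())
        self.reg_branch = nn.Sequential(nn.Conv2d(hidden, hidden, 3, padding=1), nn.SiLU())
        self.cls_pred = nn.Conv2d(hidden, num_classes, 1)
        self.reg_pred = nn.Conv2d(hidden, 4, 1)   # box regression
        self.obj_pred = nn.Conv2d(hidden, 1, 1)   # objectness, from the regression branch

    def forward(self, x):
        x = self.stem(x)
        cls_feat = self.cls_branch(x)
        reg_feat = self.reg_branch(x)
        return torch.cat([self.reg_pred(reg_feat), self.obj_pred(reg_feat), self.cls_pred(cls_feat)], dim=1)


head = TinyDecoupledHead(in_channels=1024)
out = head(torch.randn(1, 1024, 20, 20))
print(out.shape)  # 4 + 1 + 80 = 85 channels per feature point
```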









```python
import torch


def decode_outputs(outputs, input_shape):
    grids = []
    strides = []
    hw = [x.shape[-2:] for x in outputs]
    #---------------------------------------------------------------#
    #   Flatten each level and put all feature points together:
    #   [batch_size, 4 + 1 + num_classes, num_points]
    #   -> [batch_size, num_points, 4 + 1 + num_classes]
    #---------------------------------------------------------------#
    outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
    # Objectness and class scores go through a sigmoid; box values do not
    outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
    for h, w in hw:
        #---------------------------#
        #   Generate the grid points of this feature layer
        #---------------------------#
        grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
        grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
        shape = grid.shape[:2]
        grids.append(grid)
        strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
    #---------------------------#
    #   Stack the grid points of all layers together
    #---------------------------#
    grids = torch.cat(grids, dim=1).type(outputs.type())
    strides = torch.cat(strides, dim=1).type(outputs.type())
    #------------------------#
    #   Decode relative to the grid points
    #------------------------#
    outputs[..., :2] = (outputs[..., :2] + grids) * strides
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
    #-----------------#
    #   Normalize by the input size
    #-----------------#
    outputs[..., [0, 2]] = outputs[..., [0, 2]] / input_shape[1]
    outputs[..., [1, 3]] = outputs[..., [1, 3]] / input_shape[0]
    return outputs
```
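`decode_outputs` computes center = (raw offset + grid index) × stride and size = exp(raw) × stride, then divides x and w by the input width and y and h by the input height. A worked example for a single feature point on the stride-32 layer of a 640x640 input, with made-up raw values; `decode_point` is a hypothetical name.

```python
import math


def decode_point(raw, grid_xy, stride, input_hw):
    """Decode one (tx, ty, tw, th) prediction the same way decode_outputs does."""
    tx, ty, tw, th = raw
    gx, gy = grid_xy
    cx = (tx + gx) * stride
    cy = (ty + gy) * stride
    w = math.exp(tw) * stride
    h = math.exp(th) * stride
    in_h, in_w = input_hw
    # Normalise: x and w by the input width, y and h by the input height
    return cx / in_w, cy / in_h, w / in_w, h / in_h


box = decode_point(raw=(0.5, 0.25, 0.0, math.log(2.0)), grid_xy=(3, 5), stride=32, input_hw=(640, 640))
print(box)
```

Here the center is (3.5 × 32, 5.25 × 32) = (112, 168) pixels and the size is 32 × 64 pixels, which normalise to roughly (0.175, 0.2625, 0.05, 0.1).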







```python
import torch
from torchvision.ops import boxes


def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
    #----------------------------------------------------------#
    #   Convert the predictions from (cx, cy, w, h) to
    #   top-left / bottom-right corner format.
    #   prediction  [batch_size, num_anchors, 85]
    #----------------------------------------------------------#
    box_corner = torch.empty_like(prediction)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]
    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        #----------------------------------------------------------#
        #   Take the max over the class predictions.
        #   class_conf  [num_anchors, 1]    class confidence
        #   class_pred  [num_anchors, 1]    class index
        #----------------------------------------------------------#
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
        #----------------------------------------------------------#
        #   First-round filtering by confidence (obj_conf * class_conf)
        #----------------------------------------------------------#
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thres
        #-------------------------------------------------------------------------#
        #   detections  [num_anchors, 7]
        #   The 7 values are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
        #-------------------------------------------------------------------------#
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        # Skip images with no box above the threshold (checked after filtering)
        if not detections.size(0):
            continue
        # Per-class NMS on the combined score obj_conf * class_conf
        nms_out_index = boxes.batched_nms(
            detections[:, :4],
            detections[:, 4] * detections[:, 5],
            detections[:, 6],
            nms_thres,
        )
        output[i] = detections[nms_out_index]
        if output[i] is not None:
            output[i] = output[i].cpu().numpy()
            box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
            # yolo_correct_boxes (defined elsewhere in the project, not shown here)
            # maps the boxes back to the original image coordinates
            output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
    return output
```
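`boxes.batched_nms` from torchvision runs non-maximum suppression separately for each class. To make the steps visible, the plain-Python sketch below reimplements the idea: filter by obj_conf × class_conf, then greedily keep the highest-scoring box and drop later boxes of the same class whose IoU with a kept box exceeds the threshold. It is an illustration only; the names `iou` and `per_class_nms` are hypothetical and this is not how torchvision implements it internally.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def per_class_nms(dets, conf_thres=0.5, nms_thres=0.4):
    """dets: list of (x1, y1, x2, y2, obj_conf, class_conf, class_pred).

    Returns the indices of kept detections, highest score first.
    """
    scores = [d[4] * d[5] for d in dets]
    order = sorted((i for i, s in enumerate(scores) if s >= conf_thres),
                   key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Only a kept box of the same class can suppress box i
        if all(dets[k][6] != dets[i][6] or iou(dets[k][:4], dets[i][:4]) <= nms_thres for k in keep):
            keep.append(i)
    return keep


dets = [
    (0, 0, 10, 10, 0.95, 0.95, 0),    # score 0.9025
    (1, 1, 11, 11, 0.90, 0.90, 0),    # IoU with box 0 is 81/119, same class: suppressed
    (0, 0, 10, 10, 0.90, 0.85, 1),    # same place but a different class: kept
    (50, 50, 60, 60, 0.60, 0.50, 0),  # score 0.30 < conf_thres: filtered out
]
print(per_class_nms(dets))
```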









```python
import torch


def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius=2.5):
    #-------------------------------------------------------#
    #   expanded_strides_per_image  [n_anchors_all]
    #   x_centers_per_image         [num_gt, n_anchors_all]
    #   y_centers_per_image         [num_gt, n_anchors_all]
    #-------------------------------------------------------#
    expanded_strides_per_image = expanded_strides[0]
    x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
    y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
    #-------------------------------------------------------#
    #   gt_bboxes_per_image_l/r/t/b  [num_gt, n_anchors_all]
    #-------------------------------------------------------#
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
    #-------------------------------------------------------#
    #   bbox_deltas     [num_gt, n_anchors_all, 4]
    #-------------------------------------------------------#
    b_l = x_centers_per_image - gt_bboxes_per_image_l
    b_r = gt_bboxes_per_image_r - x_centers_per_image
    b_t = y_centers_per_image - gt_bboxes_per_image_t
    b_b = gt_bboxes_per_image_b - y_centers_per_image
    bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
    #-------------------------------------------------------#
    #   is_in_boxes     [num_gt, n_anchors_all]
    #   is_in_boxes_all [n_anchors_all]
    #-------------------------------------------------------#
    is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
    is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
    # Square region of center_radius strides around each gt center
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
    #-------------------------------------------------------#
    #   center_deltas   [num_gt, n_anchors_all, 4]
    #-------------------------------------------------------#
    c_l = x_centers_per_image - gt_bboxes_per_image_l
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
    #-------------------------------------------------------#
    #   is_in_centers       [num_gt, n_anchors_all]
    #   is_in_centers_all   [n_anchors_all]
    #-------------------------------------------------------#
    is_in_centers = center_deltas.min(dim=-1).values > 0.0
    is_in_centers_all = is_in_centers.sum(dim=0) > 0
    #-------------------------------------------------------#
    #   is_in_boxes_anchor      [n_anchors_all]
    #   is_in_boxes_and_center  [num_gt, num_in_boxes_anchor]
    #-------------------------------------------------------#
    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
    is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
    return is_in_boxes_anchor, is_in_boxes_and_center
```
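In words: a feature point becomes a candidate if its center lies inside a ground-truth box or within `center_radius × stride` (2.5 strides by default) of that box's center, and points satisfying both conditions are favoured later through the cost. The pure-Python worked example below checks one ground-truth box (cx=100, cy=100, w=h=20, in input pixels) against three feature points; `candidate_flags` is a hypothetical name.

```python
def candidate_flags(gt, grid_xy, stride, center_radius=2.5):
    """Return (in_box, in_center) for one gt (cx, cy, w, h) and one feature point."""
    gx, gy = grid_xy
    # Feature-point center in input pixels, as in get_in_boxes_info
    px, py = (gx + 0.5) * stride, (gy + 0.5) * stride
    cx, cy, w, h = gt
    in_box = min(px - (cx - w / 2), (cx + w / 2) - px,
                 py - (cy - h / 2), (cy + h / 2) - py) > 0
    r = center_radius * stride
    in_center = min(px - (cx - r), (cx + r) - px,
                    py - (cy - r), (cy + r) - py) > 0
    return in_box, in_center


gt = (100.0, 100.0, 20.0, 20.0)
for grid_xy, stride in (((12, 12), 8), ((3, 3), 32), ((0, 0), 8)):
    in_box, in_center = candidate_flags(gt, grid_xy, stride)
    print(grid_xy, stride, "candidate:", in_box or in_center, "box and center:", in_box and in_center)
```

The stride-8 point at (100, 100) satisfies both conditions; the stride-32 point at (112, 112) lies outside the small box but within 80 pixels of its center, so it is only a candidate; the point at (4, 4) is rejected.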









  1. @torch.no_grad()
  2. def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
  3. #-------------------------------------------------------#
  4. # fg_mask [n_anchors_all]
  5. # is_in_boxes_and_center [num_gt, len(fg_mask)]
  6. #-------------------------------------------------------#
  7. fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)
  8. #-------------------------------------------------------#
  9. # fg_mask [n_anchors_all]
  10. # bboxes_preds_per_image [fg_mask, 4]
  11. # cls_preds_ [fg_mask, num_classes]
  12. # obj_preds_ [fg_mask, 1]
  13. #-------------------------------------------------------#
  14. bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
  15. cls_preds_ = cls_preds_per_image[fg_mask]
  16. obj_preds_ = obj_preds_per_image[fg_mask]
  17. num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
  18. #-------------------------------------------------------#
  19. # pair_wise_ious [num_gt, fg_mask]
  20. #-------------------------------------------------------#
  21. pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
  22. pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
  23. #-------------------------------------------------------#
  24. # cls_preds_ [num_gt, fg_mask, num_classes]
  25. # gt_cls_per_image [num_gt, fg_mask, num_classes]
  26. #-------------------------------------------------------#
  27. cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
  28. gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
  29. pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
  30. del cls_preds_
  31. cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
  32. num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
  33. del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
  34. return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
  35. def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
  36. if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
  37. raise IndexError
  38. if xyxy:
  39. tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
  40. br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
  41. area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
  42. area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
  43. else:
  44. tl = torch.max(
  45. (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
  46. (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
  47. )
  48. br = torch.min(
  49. (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
  50. (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
  51. )
  52. area_a = torch.prod(bboxes_a[:, 2:], 1)
  53. area_b = torch.prod(bboxes_b[:, 2:], 1)
  54. en = (tl < br).type(tl.type()).prod(dim=2)
  55. area_i = torch.prod(br - tl, 2) * en
  56. return area_i / (area_a[:, None] + area_b - area_i)
  57. def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
  58. #-------------------------------------------------------#
  59. # expanded_strides_per_image [n_anchors_all]
  60. # x_centers_per_image [num_gt, n_anchors_all]
  61. # x_centers_per_image [num_gt, n_anchors_all]
  62. #-------------------------------------------------------#
  63. expanded_strides_per_image = expanded_strides[0]
  64. x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
  65. y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
  66. #-------------------------------------------------------#
  67. # gt_bboxes_per_image_x [num_gt, n_anchors_all]
  68. #-------------------------------------------------------#
  69. gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
  70. gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
  71. gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
  72. gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
  73. #-------------------------------------------------------#
  74. # bbox_deltas [num_gt, n_anchors_all, 4]
  75. #-------------------------------------------------------#
  76. b_l = x_centers_per_image - gt_bboxes_per_image_l
  77. b_r = gt_bboxes_per_image_r - x_centers_per_image
  78. b_t = y_centers_per_image - gt_bboxes_per_image_t
  79. b_b = gt_bboxes_per_image_b - y_centers_per_image
  80. bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
  81. #-------------------------------------------------------#
  82. # is_in_boxes [num_gt, n_anchors_all]
  83. # is_in_boxes_all [n_anchors_all]
  84. #-------------------------------------------------------#
  85. is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
  86. is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
  87. gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
  88. gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
  89. gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
  90. gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
  91. #-------------------------------------------------------#
  92. # center_deltas [num_gt, n_anchors_all, 4]
  93. #-------------------------------------------------------#
  94. c_l = x_centers_per_image - gt_bboxes_per_image_l
  95. c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
        #-------------------------------------------------------#
        # is_in_centers [num_gt, n_anchors_all]
        # is_in_centers_all [n_anchors_all]
        #-------------------------------------------------------#
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0
        #-------------------------------------------------------#
        # is_in_boxes_anchor [n_anchors_all]
        # is_in_boxes_and_center [num_gt, is_in_boxes_anchor]
        #-------------------------------------------------------#
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        return is_in_boxes_anchor, is_in_boxes_and_center

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        #-------------------------------------------------------#
        # cost [num_gt, fg_mask]
        # pair_wise_ious [num_gt, fg_mask]
        # gt_classes [num_gt]
        # fg_mask [n_anchors_all]
        # matching_matrix [num_gt, fg_mask]
        #-------------------------------------------------------#
        matching_matrix = torch.zeros_like(cost)
        #------------------------------------------------------------#
        # Take the n_candidate_k points with the highest IoU, then sum
        # their IoUs to decide how many points should predict this box
        # topk_ious [num_gt, n_candidate_k]
        # dynamic_ks [num_gt]
        # matching_matrix [num_gt, fg_mask]
        #------------------------------------------------------------#
        n_candidate_k = min(10, pair_wise_ious.size(1))
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        for gt_idx in range(num_gt):
            #------------------------------------------------------------#
            # For each ground-truth box, pick the dynamic_k points with the lowest cost
            #------------------------------------------------------------#
            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[gt_idx][pos_idx] = 1.0
        del topk_ious, dynamic_ks, pos_idx
        #------------------------------------------------------------#
        # anchor_matching_gt [fg_mask]
        #------------------------------------------------------------#
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            #------------------------------------------------------------#
            # When one feature point is matched to several ground-truth boxes,
            # keep only the ground-truth box with the lowest cost.
            #------------------------------------------------------------#
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0.0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
        #------------------------------------------------------------#
        # fg_mask_inboxes [fg_mask]
        # num_fg is the number of feature points assigned as positive samples
        #------------------------------------------------------------#
        fg_mask_inboxes = matching_matrix.sum(0) > 0.0
        num_fg = fg_mask_inboxes.sum().item()
        #------------------------------------------------------------#
        # Update fg_mask
        #------------------------------------------------------------#
        fg_mask[fg_mask.clone()] = fg_mask_inboxes
        #------------------------------------------------------------#
        # Get the object class assigned to each positive feature point
        #------------------------------------------------------------#
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        gt_matched_classes = gt_classes[matched_gt_inds]
        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
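The dynamic-k rule can be tried out without PyTorch. Below is a minimal pure-Python sketch on toy numbers, not the author's implementation. `cost` and `ious` are plain lists with one row per ground-truth box and one column per candidate feature point. The function name and the toy values are illustrative assumptions. The steps mirror the listing: k is the truncated sum of the top-10 IoUs (at least 1), each box takes its k lowest-cost points, and a point claimed by several boxes keeps the box with the lowest cost.

```python
def dynamic_k_matching(cost, ious, max_candidates=10):
    """Sketch of SimOTA-style dynamic-k assignment (illustrative only).

    cost, ious: lists of rows, one row per ground-truth box and one
    column per candidate feature point. Returns {point_index: gt_index}.
    """
    num_gt, num_points = len(cost), len(cost[0])
    n_candidate_k = min(max_candidates, num_points)
    claimed_by = [[] for _ in range(num_points)]
    for g in range(num_gt):
        # dynamic k: sum of this box's top-n IoUs, truncated to int, at least 1
        top_ious = sorted(ious[g], reverse=True)[:n_candidate_k]
        k = max(int(sum(top_ious)), 1)
        # the k points with the lowest cost are claimed by this box
        for p in sorted(range(num_points), key=lambda i: cost[g][i])[:k]:
            claimed_by[p].append(g)
    assignment = {}
    for p, gts in enumerate(claimed_by):
        if len(gts) == 1:
            assignment[p] = gts[0]
        elif len(gts) > 1:
            # conflict: keep the box with the lowest cost for this point
            assignment[p] = min(range(num_gt), key=lambda g: cost[g][p])
    return assignment


cost = [[0.6, 0.2, 3.0, 9.0],
        [9.0, 0.4, 0.3, 0.8]]
ious = [[0.9, 0.8, 0.1, 0.0],
        [0.0, 0.9, 0.8, 0.5]]
print(dynamic_k_matching(cost, ious))  # → {1: 0, 2: 1}
```

In this toy case, box 0 gets k = 1 and box 1 gets k = 2. Both boxes claim point 1, and box 0 keeps it because its cost there is lower.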



import torch
import torch.nn as nn
import torch.nn.functional as F


class IOUloss(nn.Module):
    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        # pred and target are boxes in (cx, cy, w, h) format
        assert pred.shape[0] == target.shape[0]
        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        # top-left and bottom-right corners of the intersection
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )
        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)
        # en is 1 where the boxes overlap and 0 otherwise
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)
        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            # smallest box enclosing both pred and target
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            # GIoU subtracts the part of the enclosing box not covered by the union
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss


class YOLOLoss(nn.Module):
    def __init__(self, num_classes, strides=[8, 16, 32]):
        super().__init__()
        self.num_classes = num_classes
        self.strides = strides
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = IOUloss(reduction="none")
        self.grids = [torch.zeros(1)] * len(strides)

    def forward(self, inputs, labels=None):
        outputs = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []
        #-----------------------------------------------#
        # One entry per stride in self.strides (8, 16, 32),
        # shapes shown for a 640x640 input:
        # inputs [[batch_size, num_classes + 5, 80, 80]
        # [batch_size, num_classes + 5, 40, 40]
        # [batch_size, num_classes + 5, 20, 20]]
        # outputs [[batch_size, 6400, num_classes + 5]
        # [batch_size, 1600, num_classes + 5]
        # [batch_size, 400, num_classes + 5]]
        # x_shifts [[1, 6400]
        # [1, 1600]
        # [1, 400]]
        #-----------------------------------------------#
        for k, (stride, output) in enumerate(zip(self.strides, inputs)):
            output, grid = self.get_output_and_grid(output, k, stride)
            x_shifts.append(grid[:, :, 0])
            y_shifts.append(grid[:, :, 1])
            expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
            outputs.append(output)
        return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))

    def get_output_and_grid(self, output, k, stride):
        grid = self.grids[k]
        hsize, wsize = output.shape[-2:]
        # The grid is cached as [1, hsize, wsize, 2]; rebuild it only when the
        # feature-map size or the device changes
        if grid.shape[1:3] != output.shape[2:4] or grid.device != output.device:
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
            self.grids[k] = grid
        grid = grid.view(1, -1, 2)
        # [batch, c, h, w] -> [batch, h*w, c]
        output = output.flatten(start_dim=2).permute(0, 2, 1)
        # decode centers: (predicted offset + grid cell index) * stride
        output[..., :2] = (output[..., :2] + grid) * stride
        # decode width and height: exp(prediction) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

    def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
        #-----------------------------------------------#
        # [batch, n_anchors_all, 4]
        #-----------------------------------------------#
        bbox_preds = outputs[:, :, :4]
        #-----------------------------------------------#
        # [batch, n_anchors_all, 1]
        #-----------------------------------------------#
        obj_preds = outputs[:, :, 4:5]
        #-----------------------------------------------#
        # [batch, n_anchors_all, n_cls]
        #-----------------------------------------------#
        cls_preds = outputs[:, :, 5:]
        total_num_anchors = outputs.shape[1]
        #-----------------------------------------------#
        # x_shifts [1, n_anchors_all]
        # y_shifts [1, n_anchors_all]
        # expanded_strides [1, n_anchors_all]
        #-----------------------------------------------#
        x_shifts = torch.cat(x_shifts, 1)
        y_shifts = torch.cat(y_shifts, 1)
        expanded_strides = torch.cat(expanded_strides, 1)
        cls_targets = []
        reg_targets = []
        obj_targets = []
        fg_masks = []
        num_fg = 0.0
        for batch_idx in range(outputs.shape[0]):
            num_gt = len(labels[batch_idx])
            if num_gt == 0:
                # no ground truth in this image: every feature point is a negative sample
                cls_target = outputs.new_zeros((0, self.num_classes))
                reg_target = outputs.new_zeros((0, 4))
                obj_target = outputs.new_zeros((total_num_anchors, 1))
                fg_mask = outputs.new_zeros(total_num_anchors).bool()
            else:
                #-----------------------------------------------#
                # gt_bboxes_per_image [num_gt, 4]
                # gt_classes [num_gt]
                # bboxes_preds_per_image [n_anchors_all, 4]
                # cls_preds_per_image [n_anchors_all, num_classes]
                # obj_preds_per_image [n_anchors_all, 1]
                #-----------------------------------------------#
                gt_bboxes_per_image = labels[batch_idx][..., :4]
                gt_classes = labels[batch_idx][..., 4]
                bboxes_preds_per_image = bbox_preds[batch_idx]
                cls_preds_per_image = cls_preds[batch_idx]
                obj_preds_per_image = obj_preds[batch_idx]
                gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
                    num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
                    expanded_strides, x_shifts, y_shifts,
                )
                torch.cuda.empty_cache()
                num_fg += num_fg_img
                # classification target: one-hot class scaled by the IoU of the matched prediction
                cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
                obj_target = fg_mask.unsqueeze(-1)
                reg_target = gt_bboxes_per_image[matched_gt_inds]
            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.type(cls_target.type()))
            fg_masks.append(fg_mask)
        cls_targets = torch.cat(cls_targets, 0)
        reg_targets = torch.cat(reg_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        fg_masks = torch.cat(fg_masks, 0)
        num_fg = max(num_fg, 1)
        loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
        loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
        loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls
        return loss / num_fg

    @torch.no_grad()
    def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
        #-------------------------------------------------------#
        # fg_mask [n_anchors_all]
        # is_in_boxes_and_center [num_gt, len(fg_mask)]
        #-------------------------------------------------------#
        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)
        #-------------------------------------------------------#
        # fg_mask [n_anchors_all]
        # bboxes_preds_per_image [fg_mask, 4]
        # cls_preds_ [fg_mask, num_classes]
        # obj_preds_ [fg_mask, 1]
        #-------------------------------------------------------#
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
        cls_preds_ = cls_preds_per_image[fg_mask]
        obj_preds_ = obj_preds_per_image[fg_mask]
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
        #-------------------------------------------------------#
        # pair_wise_ious [num_gt, fg_mask]
        #-------------------------------------------------------#
        pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
        #-------------------------------------------------------#
        # cls_preds_ [num_gt, fg_mask, num_classes]
        # gt_cls_per_image [num_gt, fg_mask, num_classes]
        #-------------------------------------------------------#
        cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
        pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
        del cls_preds_
        # cost = classification cost + 3 * IoU cost; points that are not inside
        # both the box and its center region get a very large penalty
        cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
        num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
        return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg

    def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
        if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
            raise IndexError
        if xyxy:
            tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
            br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
            area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
            area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
        else:
            tl = torch.max(
                (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
            )
            br = torch.min(
                (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
            )
            area_a = torch.prod(bboxes_a[:, 2:], 1)
            area_b = torch.prod(bboxes_b[:, 2:], 1)
        en = (tl < br).type(tl.type()).prod(dim=2)
        area_i = torch.prod(br - tl, 2) * en
        return area_i / (area_a[:, None] + area_b - area_i)

    def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
        #-------------------------------------------------------#
        # expanded_strides_per_image [n_anchors_all]
        # x_centers_per_image [num_gt, n_anchors_all]
        # y_centers_per_image [num_gt, n_anchors_all]
        #-------------------------------------------------------#
        expanded_strides_per_image = expanded_strides[0]
        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
        #-------------------------------------------------------#
        # gt_bboxes_per_image_l/r/t/b [num_gt, n_anchors_all]
        #-------------------------------------------------------#
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
        #-------------------------------------------------------#
        # bbox_deltas [num_gt, n_anchors_all, 4]
        #-------------------------------------------------------#
        b_l = x_centers_per_image - gt_bboxes_per_image_l
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
        #-------------------------------------------------------#
        # is_in_boxes [num_gt, n_anchors_all]
        # is_in_boxes_all [n_anchors_all]
        #-------------------------------------------------------#
        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
        # center region: a square of half-width center_radius * stride around each gt center
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
        #-------------------------------------------------------#
        # center_deltas [num_gt, n_anchors_all, 4]
        #-------------------------------------------------------#
        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
        #-------------------------------------------------------#
        # is_in_centers [num_gt, n_anchors_all]
        # is_in_centers_all [n_anchors_all]
        #-------------------------------------------------------#
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0
        #-------------------------------------------------------#
        # is_in_boxes_anchor [n_anchors_all]
        # is_in_boxes_and_center [num_gt, is_in_boxes_anchor]
        #-------------------------------------------------------#
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        return is_in_boxes_anchor, is_in_boxes_and_center

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        #-------------------------------------------------------#
        # cost [num_gt, fg_mask]
        # pair_wise_ious [num_gt, fg_mask]
        # gt_classes [num_gt]
        # fg_mask [n_anchors_all]
        # matching_matrix [num_gt, fg_mask]
        #-------------------------------------------------------#
        matching_matrix = torch.zeros_like(cost)
        #------------------------------------------------------------#
        # Take the n_candidate_k points with the highest IoU, then sum
        # their IoUs to decide how many points should predict this box
        # topk_ious [num_gt, n_candidate_k]
        # dynamic_ks [num_gt]
        # matching_matrix [num_gt, fg_mask]
        #------------------------------------------------------------#
        n_candidate_k = min(10, pair_wise_ious.size(1))
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        for gt_idx in range(num_gt):
            #------------------------------------------------------------#
            # For each ground-truth box, pick the dynamic_k points with the lowest cost
            #------------------------------------------------------------#
            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[gt_idx][pos_idx] = 1.0
        del topk_ious, dynamic_ks, pos_idx
        #------------------------------------------------------------#
        # anchor_matching_gt [fg_mask]
        #------------------------------------------------------------#
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            #------------------------------------------------------------#
            # When one feature point is matched to several ground-truth boxes,
            # keep only the ground-truth box with the lowest cost.
            #------------------------------------------------------------#
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0.0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
        #------------------------------------------------------------#
        # fg_mask_inboxes [fg_mask]
        # num_fg is the number of feature points assigned as positive samples
        #------------------------------------------------------------#
        fg_mask_inboxes = matching_matrix.sum(0) > 0.0
        num_fg = fg_mask_inboxes.sum().item()
        #------------------------------------------------------------#
        # Update fg_mask
        #------------------------------------------------------------#
        fg_mask[fg_mask.clone()] = fg_mask_inboxes
        #------------------------------------------------------------#
        # Get the object class assigned to each positive feature point
        #------------------------------------------------------------#
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        gt_matched_classes = gt_classes[matched_gt_inds]
        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
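The IoU and GIoU terms used by `IOUloss` can be checked by hand on single boxes. The sketch below is plain Python, not the author's code. The function name `iou_and_giou` is hypothetical, and boxes are assumed to be in (cx, cy, w, h) format as in the listing. GIoU is IoU minus the fraction of the smallest enclosing box that the union of the two boxes does not cover. Unlike plain IoU, it still gives a useful signal when the boxes do not overlap.

```python
def iou_and_giou(pred, target):
    """Return (IoU, GIoU) for two boxes given as (cx, cy, w, h)."""
    def corners(box):
        cx, cy, w, h = box
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    px1, py1, px2, py2 = corners(pred)
    tx1, ty1, tx2, ty2 = corners(target)
    # intersection area (zero when the boxes do not overlap)
    inter_w = max(0.0, min(px2, tx2) - max(px1, tx1))
    inter_h = max(0.0, min(py2, ty2) - max(py1, ty1))
    area_i = inter_w * inter_h
    area_u = pred[2] * pred[3] + target[2] * target[3] - area_i
    iou = area_i / (area_u + 1e-16)
    # area of the smallest box enclosing both boxes
    area_c = (max(px2, tx2) - min(px1, tx1)) * (max(py2, ty2) - min(py1, ty1))
    giou = iou - (area_c - area_u) / max(area_c, 1e-16)
    return iou, giou


iou, giou = iou_and_giou((1, 1, 2, 2), (2, 2, 2, 2))
print(round(iou, 4), round(giou, 4))  # → 0.1429 -0.0794
# the two losses in the listing would be 1 - iou ** 2 and 1 - clamp(giou, -1, 1)
```

Here intersection = 1, union = 7 and enclosing area = 9, so GIoU = 1/7 - 2/9.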









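'''
annotation_mode specifies what this script computes when it runs:
annotation_mode = 0 runs the whole label-processing pipeline, generating the txt files in VOCdevkit/VOC2007/ImageSets as well as 2007_train.txt and 2007_val.txt for training
annotation_mode = 1 only generates the txt files in VOCdevkit/VOC2007/ImageSets
annotation_mode = 2 only generates 2007_train.txt and 2007_val.txt for training
'''
annotation_mode = 0
'''
Must be changed: used to write the object information into 2007_train.txt and 2007_val.txt.
It only needs to match the classes_path used for training and prediction.
If the generated 2007_train.txt contains no object information,
the cause is that the classes were not set correctly.
Only used when annotation_mode is 0 or 2.
'''
classes_path = 'model_data/voc_classes.txt'
'''
trainval_percent sets the ratio of (training set + validation set) to the test set; by default (train + val) : test = 9 : 1
train_percent sets the ratio of training set to validation set within (train + val); by default train : val = 9 : 1
Only used when annotation_mode is 0 or 1.
'''
trainval_percent = 0.9
train_percent = 0.9
'''
Points to the folder containing the VOC dataset.
By default it points to the VOC dataset in the root directory.
'''
VOCdevkit_path = 'VOCdevkit'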
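The two ratios nest: `trainval_percent` carves (train + val) out of all images, and `train_percent` then splits that part into train and val. Below is a minimal sketch of such a split, assuming unique image IDs and random sampling. `split_voc_ids` is a hypothetical helper for illustration, not the script's actual code, whose details may differ.

```python
import random


def split_voc_ids(image_ids, trainval_percent=0.9, train_percent=0.9, seed=0):
    """Split unique image IDs into trainval/train/val/test (hypothetical helper)."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    ids = list(image_ids)
    num_trainval = int(len(ids) * trainval_percent)
    num_train = int(num_trainval * train_percent)
    trainval = set(rng.sample(ids, num_trainval))
    train = set(rng.sample(sorted(trainval), num_train))
    return {
        "trainval": sorted(trainval),
        "train": sorted(train),
        "val": sorted(trainval - train),        # rest of trainval
        "test": sorted(set(ids) - trainval),    # everything outside trainval
    }


splits = split_voc_ids(f"{i:06d}" for i in range(100))
print({name: len(v) for name, v in splits.items()})
# → {'trainval': 90, 'train': 81, 'val': 9, 'test': 10}
```

With the default 0.9 and 0.9, 100 images become 90 train+val (81 train, 9 val) and 10 test.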





#-------------------------------#
# Whether to use CUDA
# Set to False if there is no GPU
#-------------------------------#
Cuda = True
#--------------------------------------------------------#
# Before training, be sure to change classes_path to match your own dataset
#--------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
#------------------------------------------------------------------------------------------------------#
# For the weight files, see the README (downloadable from Baidu Netdisk). Pretrained weights can be reused across datasets because the learned features are general.
# Pretrained weights are needed in 99% of cases: without them the weights are too random, feature extraction is weak, and training results will be poor.
# A dimension-mismatch message when training on your own dataset is normal: the network predicts different things, so some dimensions naturally differ.
# To resume training from a checkpoint, set model_path to a weight file already trained in the logs folder.
#------------------------------------------------------------------------------------------------------#
model_path = 'model_data/yolox_s.pth'
#---------------------------------------------------------------------#
# YoloX version to use: s, m, l, x
#---------------------------------------------------------------------#
phi = 's'
#------------------------------------------------------#
# Input shape; must be a multiple of 32
#------------------------------------------------------#
input_shape = [640, 640]
#------------------------------------------------------------------------------------------------------------#
# YoloX tricks
# mosaic: Mosaic data augmentation, True or False
# The YOLOX authors stress turning Mosaic off for the last N epochs of training, because Mosaic images are far from the real distribution of natural images.
# Mosaic's heavy cropping also produces many inaccurate boxes. This code automatically uses mosaic for the first 90% of epochs and disables it afterwards.
# Cosine_scheduler: cosine-annealing learning-rate schedule, True or False
#------------------------------------------------------------------------------------------------------------#
mosaic = False
Cosine_scheduler = False
#----------------------------------------------------#
# Training has two stages: a frozen stage and an unfrozen stage.
# Running out of GPU memory is unrelated to dataset size; if you get an out-of-memory error, reduce batch_size.
# Because of the BatchNorm layers, batch_size must be at least 2; it cannot be 1.
#----------------------------------------------------#
#----------------------------------------------------#
# Frozen-stage training parameters
# The backbone is frozen, so the feature-extraction network does not change
# Uses less GPU memory; only fine-tunes the rest of the network
#----------------------------------------------------#
Init_Epoch = 0
Freeze_Epoch = 50
Freeze_batch_size = 8
Freeze_lr = 1e-3
#----------------------------------------------------#
# Unfrozen-stage training parameters
# The backbone is no longer frozen, so the feature-extraction network changes
# Uses more GPU memory; all network parameters are updated
#----------------------------------------------------#
UnFreeze_Epoch = 100
Unfreeze_batch_size = 4
Unfreeze_lr = 1e-4
#------------------------------------------------------#
# Whether to use frozen training; by default the backbone is trained frozen first, then unfrozen.
#------------------------------------------------------#
Freeze_Train = True
#------------------------------------------------------#
# Number of worker processes used to load data
# More workers load data faster but use more memory
# On machines with little RAM, set this to 2 or 0
#------------------------------------------------------#
num_workers = 4
#----------------------------------------------------#
# Paths to the image and label annotation files
#----------------------------------------------------#
train_annotation_path = '2007_train.txt'
val_annotation_path = '2007_val.txt'
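The two-stage schedule and the Mosaic cut-off described in these comments can be sketched as a small helper. `stage_settings` and `check_input_shape` are hypothetical names, not functions from the original train.py. The sketch assumes the frozen stage covers epochs [Init_Epoch, Freeze_Epoch) and the unfrozen stage covers [Freeze_Epoch, UnFreeze_Epoch). It also assumes "the first 90% of epochs" is counted against UnFreeze_Epoch.

```python
def stage_settings(epoch, init_epoch=0, freeze_epoch=50, unfreeze_epoch=100,
                   freeze_batch_size=8, unfreeze_batch_size=4,
                   freeze_lr=1e-3, unfreeze_lr=1e-4,
                   freeze_train=True, mosaic=False, mosaic_ratio=0.9):
    """Settings a two-stage loop would use at `epoch` (hypothetical helper)."""
    if not init_epoch <= epoch < unfreeze_epoch:
        raise ValueError(f"epoch {epoch} is outside [{init_epoch}, {unfreeze_epoch})")
    in_freeze_stage = epoch < freeze_epoch
    batch_size = freeze_batch_size if in_freeze_stage else unfreeze_batch_size
    if batch_size < 2:
        # BatchNorm needs more than one sample per batch
        raise ValueError("batch_size must be at least 2")
    return {
        # assumption: with Freeze_Train=False the first stage trains the backbone too
        "backbone_frozen": freeze_train and in_freeze_stage,
        "batch_size": batch_size,
        "lr": freeze_lr if in_freeze_stage else unfreeze_lr,
        # assumption: Mosaic only for the first 90% of all epochs
        "mosaic": mosaic and epoch < int(unfreeze_epoch * mosaic_ratio),
    }


def check_input_shape(input_shape):
    """The deepest feature map has stride 32, so each side must divide by 32."""
    return all(side % 32 == 0 for side in input_shape)


print(stage_settings(0, mosaic=True))
# → {'backbone_frozen': True, 'batch_size': 8, 'lr': 0.001, 'mosaic': True}
```

With the defaults above, epochs 0 to 49 train with a frozen backbone, batch size 8 and lr 1e-3. Epochs 50 to 99 train the whole network with batch size 4 and lr 1e-4. When enabled, Mosaic stops at epoch 90.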







