CornerNet Source Code Reading Notes (1): The Decoding Process


In CornerNet, the loss- and decoding-related functions live in kp.py and kp_utils.py.

The decoding function is shown below:

```python
import torch

# _nms, _topk, _gather_feat and _tranpose_and_gather_feat are helpers
# defined in the same file (kp_utils.py); see the sketches after this block.
def _decode(
    tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
    K=100, kernel=1, ae_threshold=1, num_dets=1000
):
    batch, cat, height, width = tl_heat.size()

    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)

    # perform nms on heatmaps: this is really just max pooling on the
    # probability maps, keeping only the local maxima
    tl_heat = _nms(tl_heat, kernel=kernel)
    br_heat = _nms(br_heat, kernel=kernel)

    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)

    # tl_ys and tl_xs originally have shape [batch, K]
    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)

    if tl_regr is not None and br_regr is not None:
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        tl_regr = tl_regr.view(batch, K, 1, 2)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)

        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]

    # all possible boxes based on top k corners (ignoring class)
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)

    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
    br_tag = br_tag.view(batch, 1, K)
    # [K,1] - [1,K] implicitly broadcasts both operands to [K,K]
    # before subtracting, so dists is also [K,K]
    dists = torch.abs(tl_tag - br_tag)

    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
    scores = (tl_scores + br_scores) / 2

    # reject boxes based on classes
    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
    cls_inds = (tl_clses != br_clses)

    # reject boxes based on distances
    dist_inds = (dists > ae_threshold)

    # reject boxes based on widths and heights
    width_inds = (br_xs < tl_xs)
    height_inds = (br_ys < tl_ys)

    scores[cls_inds] = -1
    scores[dist_inds] = -1
    scores[width_inds] = -1
    scores[height_inds] = -1

    scores = scores.view(batch, -1)
    scores, inds = torch.topk(scores, num_dets)
    scores = scores.unsqueeze(2)

    # the 100x100 corner pairs yield 10000 candidate boxes; the class,
    # embedding-distance and relative-position checks above reject most of
    # them by setting their score to -1, which propagates into inds through
    # torch.topk; the actual filtering of the boxes then happens in
    # bboxes = _gather_feat(bboxes, inds)
    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds)

    clses = tl_clses.contiguous().view(batch, -1, 1)
    clses = _gather_feat(clses, inds).float()

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float()
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float()

    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
    return detections
```
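For reference, here are minimal sketches of the four helpers used above, reconstructed from the versions in kp_utils.py (treat them as illustrative and check the repo for the exact code; the `//` floor divisions stand in for the older `/ ... .int()` idiom):

```python
import torch
import torch.nn as nn

def _nms(heat, kernel=1):
    # "nms" here is just max pooling: a point survives only if it is
    # the maximum within its kernel x kernel neighbourhood
    pad = (kernel - 1) // 2
    hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep

def _topk(scores, K=20):
    # pick the K highest-scoring points across all class channels, then
    # recover the class index and (x, y) position from the flat index
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_clses = topk_inds // (height * width)
    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds // width).float()
    topk_xs = (topk_inds % width).float()
    return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs

def _gather_feat(feat, ind):
    # feat: [batch, N, dim], ind: [batch, K]  ->  [batch, K, dim]
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    return feat.gather(1, ind)

def _tranpose_and_gather_feat(feat, ind):
    # [batch, C, H, W] -> [batch, H*W, C], then gather the rows at ind
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    return _gather_feat(feat, ind)
```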

I. Theory Recap

CornerNet has two branches in total, one for top-left corners and one for bottom-right corners. Each branch produces three outputs: a heatmap giving the probability that each point is a corner, an offset map, and an embedding whose scores indicate whether two corners belong to the same object.

Top-left branch:

heat map: tl_heat [batch, C, H, W]

offset: tl_regr [batch, 2, H, W]

embedding: tl_tag [batch, 1, H, W]

Bottom-right branch:

heat map: br_heat [batch, C, H, W]

offset: br_regr [batch, 2, H, W]

embedding: br_tag [batch, 1, H, W]

In total there are six outputs: tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr.
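As a quick sanity check on these shapes, here is a small snippet that builds dummy tensors of the right sizes and runs them through `_decode` (it assumes the definitions above are in scope; the batch size, class count and resolution are only illustrative, e.g. CornerNet on COCO uses C=80 and a 128×128 output map):

```python
import torch

batch, C, H, W = 2, 80, 128, 128        # illustrative: COCO classes, 128x128 output
tl_heat = torch.randn(batch, C, H, W)   # per-class top-left corner scores
br_heat = torch.randn(batch, C, H, W)   # per-class bottom-right corner scores
tl_regr = torch.rand(batch, 2, H, W)    # (x, y) offsets for top-left corners
br_regr = torch.rand(batch, 2, H, W)
tl_tag  = torch.randn(batch, 1, H, W)   # 1-d embedding per location
br_tag  = torch.randn(batch, 1, H, W)

dets = _decode(tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr)
print(dets.shape)  # [2, 1000, 8]: 4 box coords + score + tl/br scores + class
```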

The decoding pipeline:

1. Pick the top 100 highest-scoring points from each of the top-left and bottom-right heatmaps.

2. Match these points exhaustively against each other, producing 100×100 candidate boxes.

3. Filter out most of the boxes with three checks: the two corners must belong to the same class, their embedding distance must be within ae_threshold, and their spatial layout must be valid (the top-left corner must have smaller coordinates than the bottom-right corner). The top 1000 surviving candidates are returned; a minimal sketch of this matching step follows below.

Note: the final 1000 candidate boxes should still be processed with NMS, but the author did not put the NMS step inside the decode function.
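To make the matching step concrete, here is a minimal, self-contained sketch of the broadcasting trick `_decode` relies on, shrunk to K=3; the tag values and the threshold are made up for illustration:

```python
import torch

K = 3
tl_tag = torch.tensor([0.10, 0.90, 0.50]).view(K, 1)   # [K, 1]
br_tag = torch.tensor([0.12, 0.48, 0.95]).view(1, K)   # [1, K]

# [K,1] - [1,K] broadcasts both sides to [K,K]:
# dists[i][j] = |tl_tag[i] - br_tag[j]|
dists = torch.abs(tl_tag - br_tag)
print(dists.shape)                 # torch.Size([3, 3])

scores = torch.rand(K, K)          # stand-in for the averaged corner scores
scores[dists > 0.1] = -1           # made-up ae_threshold; rejected pairs get -1
scores, inds = torch.topk(scores.view(-1), 4)
print(scores)                      # rejected pairs sink to the bottom of the ranking
```

The same mechanism scales to K=100: rejection never removes entries, it only forces their scores to -1 so that `torch.topk` ranks them last, and `_gather_feat` then keeps only the top num_dets rows.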
