深度学习系列(七):CenterNet目标检测算法

简介

CenterNet 是无锚框(Anchor Free)的目标检测算法,文章链接如下:Objects as Points。作者宣称 CenterNet 是一个端到端的、比当时基于锚框目标检测算法更快、更高准确率的算法。CenterNet 在速度和准确率上都要超过 YOLOv3 算法,是一个能很好的实现速度和精度平衡的算法。

CenterNet 算法思路非常的简单,也是我个人比较喜欢的目标检测算法之一,它把目标检测问题看作一个简单的回归问题,直接预测目标的中心点位置和目标框的大小。实际上这个算法的思路和 YOLO v1 的思路也比较像,因此 CenterNet 算法非常简洁,这也是我喜欢这个算法的原因。

CenterNet 目标检测算法结构

图片输入后经过一个主干网络(Backbone)进行特征提取,这个主干网络一般是先下采样再上采样,得到一张高分辨率的特征图。在原论文中,输入图片大小为 512 ×\times 512 大小,最终得到特征图为 4 倍下采样也就是 128 ×\times 128 大小。

然后在这张特征图上接上三个不同的预测分支,第一个分支预测目标中心点热力图(Keypoint Heatmap),大小为:C×128×128C \times 128 \times 128,其中 CC 为预测目标类别数,也就是为每个类别预测一个 128×128128 \times 128 大小的热力图,这张图中元素大小均为 010-1 之间,接近 1 的点表示该点为目标中心。

由于这个热力图是输入图片经过 4 倍下采样得到的,中心点可能不够精确,因此需要第二个预测分支来预测目标中心点的位置偏移(local offset),预测结果图大小为 2×128×1282 \times 128 \times 128,其中每个像素点的两个通道值分别预测 x 和 y 的偏移。

第三个分支预测目标框的高和宽,也就是预测目标框大小(object size),预测结果图大小为 2×128×1282 \times 128 \times 128,其中每个像素点的两个通道值分别预测目标框的宽 w 和高 h。

具体来说,论文中提出了多种不同的 Backbone,包括:Hourglass-104, DLA-34, ResNet-101和ResNet-18。这些 Backbone 都只是为了得到原输入图片 4 倍下采样的特征图而已,拿最简单的 ResNet-18 来说,这个 Backbone 会对输入图片进行 32 倍下采样得到相应的特征图,为了得到高分辨率的特征图,还需要在后面增加 3 次反卷积操作或者可变形卷积(DCNv2)获得相应特征图。

部分实现如下,经过主干网络和上采样得到 64×128×12864 \times 128 \times 128的特征图,然后经过 head 预测出 3 个分支。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class CenterNet(nn.Module):
def __init__(self, num_classes, backbone):
super(CenterNet, self).__init__()
# 512,512,3 -> 16,16,N
self.backbone = backbone
# 16,16,N -> 128,128,64
self.decoder = resnet50_Decoder(self.backbone.out_channels)
# 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
# -> 128, 128, 64 -> 128, 128, 2
# -> 128, 128, 64 -> 128, 128, 2
self.head = resnet50_Head(in_channel=64, num_classes=num_classes)
self._init_weights()

def _init_weights(self):
self.head.cls_head[-1].weight.data.fill_(0)
self.head.cls_head[-1].bias.data.fill_(-2.19)

def forward(self, x):
_1, _2, _3, feat = self.backbone(x)
return self.head(self.decoder(feat))

head部分的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class resnet50_Head(nn.Module):
def __init__(self, num_classes=20, in_channel=256, bn_momentum=0.1):
super(resnet50_Head, self).__init__()
# 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
# -> 128, 128, 64 -> 128, 128, 2
# -> 128, 128, 64 -> 128, 128, 2
# 热力图预测部分
self.cls_head = nn.Sequential(
nn.Conv2d(in_channel, 64,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(64, num_classes,
kernel_size=1, stride=1, padding=0))
# 宽高预测的部分
self.wh_head = nn.Sequential(
nn.Conv2d(in_channel, 64,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(64, 2,
kernel_size=1, stride=1, padding=0))

# 中心点预测的部分
self.reg_head = nn.Sequential(
nn.Conv2d(in_channel, 64,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64, momentum=bn_momentum),
nn.ReLU(inplace=True),
nn.Conv2d(64, 2,
kernel_size=1, stride=1, padding=0))

def forward(self, x):
hm = self.cls_head(x).sigmoid_()
wh = self.wh_head(x)
offset = self.reg_head(x)
return hm, wh, offset

上述实现来源:https://github.com/bubbliiiing/centernet-pytorch.

算法训练

Heatmap 标签生成

为了对网络进行训练(也就是让网络能够输出预期三个分支),需要生成相应的标签,其中热力图 Heatmap 生成方法如下,给定一张打好标签的图片:

  1. 首先计算出真实目标框在 4 倍下采样(128×128128 \times 128下)的坐标,取目标框中心点的坐标 point。
  2. 根据真实目标框的大小计算一个高斯圆的半径 R。
  3. 在对应类别的特征图(每个类别一个 128×128128 \times 128 大小的特征图,也就是预测的 C×128×128C\times128\times128同样大小)上以点 point 为中心,R 为半径填充一个高斯衰减的圆数值。

注:此处参考:https://www.cnblogs.com/silence-cho/p/13955766.html

这样处理的原因是,如果只将目标框中心点对应特征图上的那一个像素点值设置成 1,会导致网络预测中心点时存在一点偏移的时候都会惩罚严重,实际上还有目标中心点偏移分支进行中心点调整,使用高斯核形式生成标签能够更温和一些。这个思想其实是来源于 CornerNet:CornerNet: Detecting Objects as Paired Keypoints

损失函数设计

由于有 3 个预测分支,因此损失函数分为 3 个部分:

Ldet=Lk+λsizeLsize+λoffLoffL_{det} = L_k + λ_{size} L_{size} + λ_{off} L_{off}

第一个部分是目标中心点热力图损失LkL_k

Lk=1Nxyc{(1Y^xyc)αlog(Y^xyc) if Yxyc=1(1Yxyc)β(Y^xyc)αlog(1Y^xyc) otherwise L_k=\frac{-1}{N} \sum_{x y c}\left\{\begin{array}{cl} \left(1-\hat{Y}_{x y c}\right)^\alpha \log \left(\hat{Y}_{x y c}\right) & \text { if } Y_{x y c}=1 \\ \left(1-Y_{x y c}\right)^\beta\left(\hat{Y}_{x y c}\right)^\alpha \log \left(1-\hat{Y}_{x y c}\right) & \text { otherwise } \end{array}\right.

由于 CenterNet 算法的简洁性,模型达到一个比较好的效果需要对损失函数进行较好的设计,上面目标中心点热力图损失函数设计其实是对focal loss:Focal Loss for Dense Object Detection 进行改写,其中 α\alphaβ\beta 都是超参数,论文中取值为 α=2,β=4\alpha=2, \beta=4NN 为图像中目标中心点数量,Y^xyc\hat{Y}_{x y c} 为网络预测值,YxycY_{x y c} 为上面生成的标注值。

上述损失解决难易分两本不均衡问题,对于正样本(Yxyc=1Y_{xyc}=1),如果当前预测值为 Y^xyc=0.1\hat{Y}_{x y c}=0.1,也就是说这是难分样本,计算得到的损失值为:(10.1)2log(0.1))=1.8651-(1-0.1)^2log(0.1)) = 1.8651;而如果当前预测值为 Y^xyc=0.9\hat{Y}_{x y c}=0.9,也就是说这是难分样本,计算得到的损失值为:(10.9)2log(0.9))=0.00105-(1-0.9)^2log(0.9)) = 0.00105。可以看到对于正样本而言,这个损失对难分样本的损失更大,对于易分样本的损失则有惩罚。对于负样本而言也有类似的效果。

第二个部分为中心点偏移预测损失LoffL_{off},是一个L1 loss:

Loff=1NpO^p~(pRp~)L_{o f f}=\frac{1}{N} \sum_p\left|\hat{O}_{\tilde{p}}-\left(\frac{p}{R}-\tilde{p}\right)\right|

其中 O^p~\hat{O}_{\tilde{p}}为预测偏移值,pRp~\frac{p}{R}-\tilde{p}为标签真实值,pp 是目标真实框中心坐标,除以 RR 也就是除以 44 缩放到了 128×128128 \times 128尺度上,减去p~\tilde{p}就得到了偏移部分。

第三个部分为目标框大小预测损失LsizeL_{size},也是简单的L1 loss:

Lsize =1Nk=1NS^pkskL_{\text {size }}=\frac{1}{N} \sum_{k=1}^N\left|\hat{S}_{p_k}-s_k\right|

其中 S^pk\hat{S}_{p_k} 为预测的尺寸,sks_k 为真实尺寸。

注意,第二和第三部分损失只计算正样本的损失

损失部分的一个实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

c_loss = focal_loss(hm, targets_hm)
wh_loss = 0.1 * reg_l1_loss(wh, targets_wh, targets_reg_mask)
off_loss = reg_l1_loss(offset, targets_reg, targets_reg_mask)

loss = c_loss + wh_loss + off_loss

def focal_loss(pred, target):
pred = pred.permute(0, 2, 3, 1)

pos_inds = target.eq(1).float()
neg_inds = target.lt(1).float()

neg_weights = torch.pow(1 - target, 4)
pred = torch.clamp(pred, 1e-6, 1 - 1e-6)

pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()

if num_pos == 0:
loss = -neg_loss
else:
loss = -(pos_loss + neg_loss) / num_pos
return loss

def reg_l1_loss(pred, target, mask):
pred = pred.permute(0, 2, 3, 1)
expand_mask = torch.unsqueeze(mask,-1).repeat(1,1,1,2)

loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction='sum')
loss = loss / (mask.sum() + 1e-4)
return loss

总结

整个算法就介绍完了,算法很简单,就像论文题目所说的将目标检测看作点检测。我也参照网上大家的实现写了一份 Pytorch 版本的 CenterNet 实现,同样为了算法部署方便,没有使用变形卷积(DCN),实现的 Backbone 也是 ResNet系列:https://github.com/xiaoqieF/my-centernet

参考文献

  1. Objects as Points
  2. https://www.cnblogs.com/silence-cho/p/13955766.html
  3. https://www.cnblogs.com/silence-cho/p/12987476.html
  4. https://github.com/bubbliiiing/centernet-pytorch
  5. https://github.com/yjh0410/CenterNet-plus