简介
CenterNet 是无锚框(Anchor Free)的目标检测算法,文章链接如下:Objects as Points。作者宣称 CenterNet 是一个端到端的、比当时基于锚框目标检测算法更快、更高准确率的算法。CenterNet 在速度和准确率上都要超过 YOLOv3 算法,是一个能很好的实现速度和精度平衡的算法。
CenterNet 算法思路非常的简单,也是我个人比较喜欢的目标检测算法之一,它把目标检测问题看作一个简单的回归问题,直接预测目标的中心点位置和目标框的大小。实际上这个算法的思路和 YOLO v1 的思路也比较像,因此 CenterNet 算法非常简洁,这也是我喜欢这个算法的原因。
CenterNet 目标检测算法结构
图片输入后经过一个主干网络(Backbone)进行特征提取,这个主干网络一般是先下采样再上采样,得到一张高分辨率的特征图。在原论文中,输入图片大小为 512 × 512 大小,最终得到特征图为 4 倍下采样也就是 128 × 128 大小。
然后在这张特征图上接上三个不同的预测分支,第一个分支预测目标中心点热力图(Keypoint Heatmap),大小为:C×128×128,其中 C 为预测目标类别数,也就是为每个类别预测一个 128×128 大小的热力图,这张图中元素大小均为 0−1 之间,接近 1 的点表示该点为目标中心。
由于这个热力图是输入图片经过 4 倍下采样得到的,中心点可能不够精确,因此需要第二个预测分支来预测目标中心点的位置偏移(local offset),预测结果图大小为 2×128×128,其中每个像素点的两个通道值分别预测 x 和 y 的偏移。
第三个分支预测目标框的高和宽,也就是预测目标框大小(object size),预测结果图大小为 2×128×128,其中每个像素点的两个通道值分别预测目标框的宽 w 和高 h。
具体来说,论文中提出了多种不同的 Backbone,包括:Hourglass-104, DLA-34, ResNet-101和ResNet-18。这些 Backbone 都只是为了得到原输入图片 4 倍下采样的特征图而已,拿最简单的 ResNet-18 来说,这个 Backbone 会对输入图片进行 32 倍下采样得到相应的特征图,为了得到高分辨率的特征图,还需要在后面增加 3 次反卷积操作或者可变形卷积(DCNv2)获得相应特征图。
部分实现如下,经过主干网络和上采样得到 64×128×128的特征图,然后经过 head 预测出 3 个分支。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| class CenterNet(nn.Module): def __init__(self, num_classes, backbone): super(CenterNet, self).__init__() self.backbone = backbone self.decoder = resnet50_Decoder(self.backbone.out_channels) self.head = resnet50_Head(in_channel=64, num_classes=num_classes) self._init_weights()
def _init_weights(self): self.head.cls_head[-1].weight.data.fill_(0) self.head.cls_head[-1].bias.data.fill_(-2.19) def forward(self, x): _1, _2, _3, feat = self.backbone(x) return self.head(self.decoder(feat))
|
head部分的代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| class resnet50_Head(nn.Module): def __init__(self, num_classes=20, in_channel=256, bn_momentum=0.1): super(resnet50_Head, self).__init__() self.cls_head = nn.Sequential( nn.Conv2d(in_channel, 64, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64, momentum=bn_momentum), nn.ReLU(inplace=True), nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)) self.wh_head = nn.Sequential( nn.Conv2d(in_channel, 64, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64, momentum=bn_momentum), nn.ReLU(inplace=True), nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0))
self.reg_head = nn.Sequential( nn.Conv2d(in_channel, 64, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64, momentum=bn_momentum), nn.ReLU(inplace=True), nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0))
def forward(self, x): hm = self.cls_head(x).sigmoid_() wh = self.wh_head(x) offset = self.reg_head(x) return hm, wh, offset
|
上述实现来源:https://github.com/bubbliiiing/centernet-pytorch.
算法训练
Heatmap 标签生成
为了对网络进行训练(也就是让网络能够输出预期三个分支),需要生成相应的标签,其中热力图 Heatmap 生成方法如下,给定一张打好标签的图片:
- 首先计算出真实目标框在 4 倍下采样(128×128下)的坐标,取目标框中心点的坐标 point。
- 根据真实目标框的大小计算一个高斯圆的半径 R。
- 在对应类别的特征图(每个类别一个 128×128 大小的特征图,也就是预测的 C×128×128同样大小)上以点 point 为中心,R 为半径填充一个高斯衰减的圆数值。
注:此处参考:https://www.cnblogs.com/silence-cho/p/13955766.html
这样处理的原因是,如果只将目标框中心点对应特征图上的那一个像素点值设置成 1,会导致网络预测中心点时存在一点偏移的时候都会惩罚严重,实际上还有目标中心点偏移分支进行中心点调整,使用高斯核形式生成标签能够更温和一些。这个思想其实是来源于 CornerNet:CornerNet: Detecting Objects as Paired Keypoints。
损失函数设计
由于有 3 个预测分支,因此损失函数分为 3 个部分:
Ldet=Lk+λsizeLsize+λoffLoff
第一个部分是目标中心点热力图损失Lk:
Lk=N−1xyc∑⎩⎪⎨⎪⎧(1−Y^xyc)αlog(Y^xyc)(1−Yxyc)β(Y^xyc)αlog(1−Y^xyc) if Yxyc=1 otherwise
由于 CenterNet 算法的简洁性,模型达到一个比较好的效果需要对损失函数进行较好的设计,上面目标中心点热力图损失函数设计其实是对focal loss:Focal Loss for Dense Object Detection 进行改写,其中 α和β 都是超参数,论文中取值为 α=2,β=4,N 为图像中目标中心点数量,Y^xyc 为网络预测值,Yxyc 为上面生成的标注值。
上述损失解决难易分两本不均衡问题,对于正样本(Yxyc=1),如果当前预测值为 Y^xyc=0.1,也就是说这是难分样本,计算得到的损失值为:−(1−0.1)2log(0.1))=1.8651;而如果当前预测值为 Y^xyc=0.9,也就是说这是难分样本,计算得到的损失值为:−(1−0.9)2log(0.9))=0.00105。可以看到对于正样本而言,这个损失对难分样本的损失更大,对于易分样本的损失则有惩罚。对于负样本而言也有类似的效果。
第二个部分为中心点偏移预测损失Loff,是一个L1 loss:
Loff=N1p∑∣∣∣∣O^p~−(Rp−p~)∣∣∣∣
其中 O^p~为预测偏移值,Rp−p~为标签真实值,p 是目标真实框中心坐标,除以 R 也就是除以 4 缩放到了 128×128尺度上,减去p~就得到了偏移部分。
第三个部分为目标框大小预测损失Lsize,也是简单的L1 loss:
Lsize =N1k=1∑N∣∣∣∣S^pk−sk∣∣∣∣
其中 S^pk 为预测的尺寸,sk 为真实尺寸。
注意,第二和第三部分损失只计算正样本的损失。
损失部分的一个实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| c_loss = focal_loss(hm, targets_hm) wh_loss = 0.1 * reg_l1_loss(wh, targets_wh, targets_reg_mask) off_loss = reg_l1_loss(offset, targets_reg, targets_reg_mask)
loss = c_loss + wh_loss + off_loss
def focal_loss(pred, target): pred = pred.permute(0, 2, 3, 1)
pos_inds = target.eq(1).float() neg_inds = target.lt(1).float()
neg_weights = torch.pow(1 - target, 4) pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum() pos_loss = pos_loss.sum() neg_loss = neg_loss.sum()
if num_pos == 0: loss = -neg_loss else: loss = -(pos_loss + neg_loss) / num_pos return loss
def reg_l1_loss(pred, target, mask): pred = pred.permute(0, 2, 3, 1) expand_mask = torch.unsqueeze(mask,-1).repeat(1,1,1,2)
loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction='sum') loss = loss / (mask.sum() + 1e-4) return loss
|
总结
整个算法就介绍完了,算法很简单,就像论文题目所说的将目标检测看作点检测。我也参照网上大家的实现写了一份 Pytorch 版本的 CenterNet 实现,同样为了算法部署方便,没有使用变形卷积(DCN),实现的 Backbone 也是 ResNet系列:https://github.com/xiaoqieF/my-centernet
参考文献
- Objects as Points
- https://www.cnblogs.com/silence-cho/p/13955766.html
- https://www.cnblogs.com/silence-cho/p/12987476.html
- https://github.com/bubbliiiing/centernet-pytorch
- https://github.com/yjh0410/CenterNet-plus