前言

模型压缩方法主要4种:

  • 网络剪枝(Network pruning)
  • 稀疏表示(Sparse representation)
  • 模型量化(Model quantification)
  • 知识蒸馏(Konwledge distillation)

本文主要来研究知识蒸馏的相关知识,并尝试用知识蒸馏的方法对YOLOv5进行改进。

知识蒸馏理论简介

概述

知识蒸馏(Knowledge Distillation)由深度学习三巨头Hinton在2015年提出。

论文标题:Distilling the knowledge in a neural network
论文地址:https://arxiv.org/pdf/1503.02531.pdf

“蒸馏”是个化工学科中的术语,本身指的是将液体混合物加热沸腾,使其中沸点较低的组分首先变成蒸气,再冷凝成液体,用来分离混合物。而知识蒸馏的含义和蒸馏本身相似但并不完全相同,知识蒸馏指的是同时训练两个网络,一个较复杂的网络作为教师网络,另一个较简单的网络作为学生网络,将教师网络训练得到的结果提炼出来,用来引导学生网络的结果,从而让学生网络学习得更好。

一个公认前提是小模型相比于大模型更容易陷入局部最优,下图[1]中,中间绿色的椭圆表示小网络模型的收敛空间,红色的椭圆表示大网络模型的收敛空间;如果不用知识蒸馏,直接训练小网络,它只会在绿色椭圆区域收敛,而使用知识蒸馏之后,小网络可以收敛到橙色椭圆区域,收敛到更小的最优点。

软标签

有了上面的概念,自然而然想到的一个问题就是,教师模型如何引导学生模型进行学习。这就涉及到论文中提及的一个概念——软标签(Soft target)

如上图[1]所示,以手写数字识别为例,这是一个10分类任务,左边这幅图是采用硬标签(Hard target),输出独热向量,概率最高的类别为1,其它类别为0;右边这幅图采用的是软标签(Soft target),通过softmax层输出的各类别概率,这样的输出具有更高的信息熵,即包含更多信息量。
教师模型输出软标签,从而指导学生模型学习。

softmax的原始公式是这样:

q i= exp⁡ ( zi) ∑ jexp⁡ ( zj) q_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)} qi=jexp(zj)exp(zi)

在论文中,作者对这个公式又加以改进,引入了一个新的温度变量T,公式如下:

q i= exp⁡ ( zi/ T ) ∑ jexp⁡ ( zj/ T ) q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)} qi=jexp(zj/T)exp(zi/T)

加入这个变量,能使各类别之间的输出更均衡,如下图[2]所示,T=1为softmax,但是当T过大时,会发现输出向量会趋于一条直线,因此,T通常取中间较小值。

蒸馏温度

上面引入了一个新的变量温度T,这个T也可以称为蒸馏温度,原论文中给出了关于T的进一步讨论,随着T的增加,信息熵会越来越大,如下图[1]所示:


实际上,温度的高低改变的是Student模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student模型会相对更多地关注到负标签[1]。

因此,T的取值可以遵循如下策略:

  • 当想从负标签中学到一些信息量的时候,温度T应调高一些
  • 当想减少负标签的干扰的时候,温度T应调低一些

需要注意的是,这个T只作用于教师网络和学生网络的蒸馏过程,学生网络正常输出仍使用softmax,即T取值为1,就像蒸馏过程一样,需要先进行升温,将知识蒸馏出来,然后输出的时候要冷却降温(T=1)

知识蒸馏过程

从原理上来讲,知识蒸馏没有想象中那么复杂,其流程如下图[1]所示:

  1. 在T下,训练教师网络得到 soft targets1
  2. 在T下,训练学生网络得到 soft targets2
  3. 通过 soft targets1soft targets2 得到 distillation loss
  4. 在温度1下,训练学生网络得到 soft targets3
  5. 通过 soft targets3ground truth 得到 student loss

通过这五个步骤,就得到了两个损失值 distillation lossstudent loss,那么训练的整体损失,就是这两个损失值的加权和,公式[2]如下:


注:

  • 这里的蒸馏损失系数乘了一个 T 2T^2 T2
    这是由于soft targets产生的梯度大小按照1/ T 21/T^2 1/T2进行了缩放,这里需要补充回来
  • α\alpha α应远小于β\beta β
    即需要让知识蒸馏损失权重大一些,否则没有蒸馏效果

后面,论文作者分别做了手写数字识别和声音识别实验,这里主要来看作者在MNIST数据集上的实验结果,结果如下表所示:

10xEnsemble是10个教师模型的平均值,Distilled Single model是Baseline模型经过蒸馏之后的结果,可以看到蒸馏出来的准确率提升了1.9%.

YOLOv5加上知识蒸馏

下面就将知识蒸馏融入到YOLOv5目标检测任务中,使用的是YOLOv5-6.0版本。
相关代码参考自:https://github.com/Adlik/yolov5

代码修改

其实知识蒸馏的想法很简单,在仓库作者的代码版本中,修改的内容也并不多,主要是模型加载和损失计算部分。

下面按照顺序来解读一下修改内容。

首先是train_distillation.py这个文件,通过修改train.py得到。

新增四个参数:

parser.add_argument('--t_weights', type=str, default='./weights/yolov5s.pt',help='initial teacher model weights path')parser.add_argument('--t_cfg', type=str, default='models/yolov5s.yaml', help='teacher model.yaml path')parser.add_argument('--d_output', action='store_true', default=False,help='if true, only distill outputs')parser.add_argument('--d_feature', action='store_true', default=False,help='if true, distill both feature and output layers')
  • t_weights
    教师模型权重,和学生模型加载类似

  • t_cfg
    教师模型配置,和学生模型配置类似

  • d_output
    这个参数写在这里但不起作用,应该是作者调试时用到的参数,默认是只蒸馏结果

  • d_feature
    这个参数默认是关闭,如果开启,蒸馏损失计算将不仅仅是计算两个模型输出的结果,并且中间特征层也会参与计算(不过这个作者没写完整,可能写到一半弃坑了)

模型加载:
这部分需要多加载一个教师模型,相关代码如下:

# Modelcheck_suffix(weights, '.pt')# check weightspretrained = weights.endswith('.pt')if pretrained:with torch_distributed_zero_first(LOCAL_RANK):weights = attempt_download(weights)# download if not found locallyckpt = torch.load(weights, map_location=device)# load checkpointmodel = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)# createexclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []# exclude keyscsd = ckpt['model'].float().state_dict()# checkpoint state_dict as FP32csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)# intersectmodel.load_state_dict(csd, strict=False)# loadLOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')# report# 这里添加加载教师模型# Teacher modelLOGGER.info(f'Loaded teacher model {t_cfg}')# reportt_ckpt = torch.load(t_weights, map_location=device)# load checkpointt_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else []# exclude keyscsd = t_ckpt['model'].float().state_dict()# checkpoint state_dict as FP32csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude)# intersectt_model.load_state_dict(csd, strict=False)# load

损失计算:
这里多了一个d_outputs_loss,也就是计算蒸馏损失

s_loss, loss_items = compute_loss(pred, targets.to(device))# loss scaled by batch_sized_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)loss = d_outputs_loss + s_loss

蒸馏损失在loss.py中进行定义:

def compute_distillation_output_loss(p, t_p, model, d_weight=1):t_ft = torch.cuda.FloatTensor if t_p[0].is_cuda else torch.Tensort_lcls, t_lbox, t_lobj = t_ft([0]), t_ft([0]), t_ft([0])h = model.hyp# hyperparametersred = 'mean'# Loss reduction (sum or mean)if red != "mean":raise NotImplementedError("reduction must be mean in distillation mode!")DboxLoss = nn.MSELoss(reduction="none")DclsLoss = nn.MSELoss(reduction="none")DobjLoss = nn.MSELoss(reduction="none")# per outputfor i, pi in enumerate(p):# layer index, layer predictionst_pi = t_p[i]t_obj_scale = t_pi[..., 4].sigmoid()# BBoxb_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 4)t_lbox += torch.mean(DboxLoss(pi[..., :4], t_pi[..., :4]) * b_obj_scale)# Classif model.nc > 1:# cls loss (only if multiple classes)c_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, model.nc)# t_lcls += torch.mean(c_obj_scale * (pi[..., 5:] - t_pi[..., 5:]) ** 2)t_lcls += torch.mean(DclsLoss(pi[..., 5:], t_pi[..., 5:]) * c_obj_scale)# t_lobj += torch.mean(t_obj_scale * (pi[..., 4] - t_pi[..., 4]) ** 2)t_lobj += torch.mean(DobjLoss(pi[..., 4], t_pi[..., 4]) * t_obj_scale)t_lbox *= h['box']t_lobj *= h['obj']t_lcls *= h['cls']# bs = p[0].shape[0]# batch sizeloss = (t_lobj + t_lbox + t_lcls) * d_weightreturn loss

因为目标检测和原论文中的分类问题有所区别,并不能直接简单套用原论文提出的soft-target,那么这里的处理方式就是将三个损失(位置损失、目标损失、类别损失)简单粗暴地用MSELoss进行计算,然后蒸馏损失就是这三部分之和。

值得注意的是,理论部分我们提到过,蒸馏损失需要比学生损失的权重更大,因此,这里在计算蒸馏损失中,加入了一个权重d_weight,权重计算时取10.

下面是代码作者给出的一个实验结果:

ModelCompression
strategy
Input size
[h, w]
mAPval
0.5:0.95
Pretrain weight
yolov5sbaseline[640, 640]37.2pth | onnx
yolov5sdistillation[640, 640]39.3pth | onnx
yolov5squantization[640, 640]36.5xml | bin
yolov5sdistillation + quantization[640, 640]38.6xml | bin

他采用的是coco数据集,用yolov5m作为教师模型,yolov5s作为学生模型,表格第二行展示了蒸馏之后的效果,mAP提升了2.1.

实验验证

为了验证蒸馏是否有效,我在VisDrone数据集上进行了实验,训练了100epoch,实验结果如下表所示:

Student ModelTeacher ModelInput size
[h, w]
mAPtest
0.5
mAPtest
0.5:0.95
yolov5m[640, 640]0.320.181
yolov5myolov5m[640, 640]0.3050.163
yolov5myolov5x[640, 640]0.3020.161
yolov5m[1280, 1280]0.4480.261
yolov5myolov5x[1280, 1280]0.4010.23

结果挺意外的,使用蒸馏训练之后,mAP反而下降了,严重怀疑蒸馏出来的是糟粕

结论

知识蒸馏理论上并不复杂,但经过实验,基本判断这玩意理论价值大于应用价值,用来讲故事可以,实际上提升效果非常有限。当然这是我做了有限实验得出的初步结论,如果读者有更好的思路,可以在评论区留言和我讨论。

参考

[1]【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network:https://www.bilibili.com/read/cv16841475
[2]【论文精讲|无废话版】知识蒸馏:https://www.bilibili.com/video/BV1h8411t7SA