知识蒸馏是做什么的?

知识蒸馏的概念由Hinton在Distilling the Knowledge in a Neural Network中提出,目的是把 一个大模型或者多个模型集成 学到的知识迁移到另一个轻量级模型上。
Knowledge Distillation,简称KD,顾名思义,就是将已经训练好的模型包含的知识(Knowledge),蒸馏(Distill)提取到另一个模型里面去。
简而言之,就是模型压缩的一种方法,是一种基于“教师-学生网络思想”的训练方法

做模型压缩的原因:一般情况下,我们在训练模型的时候使用了大量训练数据和计算资源来提取知识,但是大模型不方便部署到服务中去,一是因为大模型的推理速度慢,二是对设备的资源要求高,因此我们希望对训练好的模型进行压缩,在保证推理效果的前提下减小模型的体量。
····················································································································································································
插一句别的:我们可以从模型参数量和训练数据量之间的相对关系来理解underfitting和overfitting。在网上看到一个很形象的解释:模型就像一个容器,训练数据中蕴含的知识就像是要装进容器里的水。当数据知识量(水量)超过模型所能建模的范围时(容器的容积),加再多的数据也不能提升效果(水再多也装不进容器),因为模型的表达空间有限(容器容积有限),就会造成underfitting;而当模型的参数量大于已有知识所需要的表达空间时(容积大于水量,水装不满容器),就会造成overfitting,即模型的variance会增大(想象一下摇晃半满的容器,里面水的形状是不稳定的)。
一个模型的参数量基本决定了其所能捕获到的数据内蕴含的“知识”的量。这个想法基本正确,但是要注意:
(1)模型的参数量和其所能捕获的“知识“量之间并非线性关系(下图中的1),而是接近边际收益逐渐减少的一种增长曲线(下图中的2和3)
(2)完全相同的模型架构和模型参数量,使用完全相同的训练数据,能捕获的“知识”量并不一定完全相同,另一个关键因素是训练的方法。合适的训练方法可以使得在模型参数总量比较小时,尽可能地获取到更多的“知识”(下图中的3与2曲线的对比)。
图源:【经典简读】知识蒸馏(Knowledge Distillation) 经典之作


知识蒸馏的理论依据

下面介绍一下知识蒸馏所用到的理论依据。

名词解释

  • Teacher:大而笨重的模型
  • Student:小而紧凑的模型
  • transfer set:用于小模型训练的数据,也是获得Teacher模型soft target输出的输入数据集
  • hard target:样本原始标签
  • soft target:Teacher模型输出的预测结果
  • temperature:softmax函数中的超参数
  • knowledge:可以理解为从输入向量到输出向量学习到的映射

符号定义

  • zz z:Logits,模型去除输出层的输出。
    对于一般的分类问题,比如图片分类,输入一张图片后,经过深度神经网络各种非线性变换,在网络最后的Softmax层之前,会得到这张图片属于各个类别的大小数值 z iz_i zi,某个类别的 z iz_i zi数值越大,模型认为输入图片属于这个类别的可能性就越大。那么什么是Logits?这些汇总了网络内部各种信息后,得出的属于各个类别的汇总分值 z iz_i zi就是Logits,ii i代表第ii i个类别, z iz_i zi代表属于第ii i类的可能性。因为Logits不是概率值,所以一般在Logits数值上会用Softmax函数进行变换,得出的概率值作为最终分类结果概率。Softmax一方面把Logits数值在各类别之间进行概率归一,使得各个类别归属数值满足概率分布;另外一方面,它会放大Logits数值之间的差异,使得Logits得分两极分化,Logits得分高的得到的概率值更偏大一些,而较低的Logits数值,得到的概率值则更小。
  • pp p:probality,每个类的概率

Teacher Model和Student Model

知识蒸馏采取Teacher-Student模式:将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

需要注意的是,这里蒸馏的目的是小网络的概率分布趋近于大网络,而非单纯的正确率趋近于大网络。 Hinton注意到,虽然我们最终分类依靠的是softmax后的最大概率结果,但其实那些概率很小类之间的差别蕴含着网络进行特征提取的很多信息。例如,猫、狗、狮子三个类的输出分别是0.99,0.001,0.009,则很明显:狮子与猫的相似程度比狮子与狗高。
因此,从理论上来说:只有小网络的所有大小输出与大网络都非常相近,才可以视为大小网络间概率分布非常接近。


知识蒸馏分类

知识蒸馏是对模型的能力进行迁移,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏)基于特征蒸馏的算法两个大的方向,下面我们对其进行介绍。
···················································································································································································

目标蒸馏-Logits方法

目标蒸馏方法中最经典的论文就是来自于2015年Hinton发表的Distilling the Knowledge in a Neural Network。下面以这篇论文为例,讲一下目标蒸馏方法的原理。
在这篇论文中,Hinton将问题限定在分类问题下,分类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的Teacher模型,在利用Teacher模型来蒸馏训练Student模型时,可以直接让Student模型去学习Teacher模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“Soft-target” 。

Hard-target 和 Soft-target

【KD的训练过程和传统的训练过程的对比】

  • 传统的神经网络training过程:定义一个损失函数,目标是使预测值尽可能接近于真实值(Hard-target),损失函数就是使神经网络的损失值和尽可能小。这种训练过程是对ground truth求极大似然。
  • KD的training过程:是使用大模型的类别概率作为Soft-target的训练过程。


从这张图可以看出:
Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。
Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。

知识蒸馏用Teacher模型预测的 Soft-target 来辅助 Hard-target 训练 Student模型的方式为什么有效呢?

softmax层的输出,除了正例之外,负标签也带有Teacher模型归纳推理的大量信息,比如某些负标签对应的概率远远大于其他负标签,则代表 Teacher模型在推理时认为该样本与该负标签有一定的相似性。而在传统的训练过程(Hard-target)中,所有负标签都被统一对待。也就是说,知识蒸馏的训练方式使得每个样本给Student模型带来的信息量大于传统的训练方式。

如在MNIST数据集中做手写体数字识别任务,假设某个输入的“2”更加形似”3″,softmax的输出值中”3″对应的概率会比其他负标签类别高;而另一个”2″更加形似”7″,则这个样本分配给”7″对应的概率会比其他负标签类别高。这两个”2″对应的Hard-target的值是相同的,但是它们的Soft-target却是不同的,由此我们可见Soft-target蕴含着比Hard-target更多的信息。

在使用 Soft-target 训练时,Student模型可以很快学习到 Teacher模型的推理过程;而传统的 Hard-target 的训练方式,所有的负标签都会被平等对待因此,Soft-target 给 Student模型带来的信息量要大于 Hard-target,并且Soft-target分布的熵相对高时,其Soft-target蕴含的知识就更丰富。同时,使用 Soft-target 训练时,梯度的方差会更小,训练时可以使用更大的学习率,所需要的样本也更少。这也解释了为什么通过蒸馏的方法训练出的Student模型相比使用完全相同的模型结构和训练数据只使用Hard-target的训练方法得到的模型,拥有更好的泛化能力。

好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好。另外,对于分类任务,如果soft targets的熵比hard targets高,那显然student会学习到更多的信息。采用软标签的知识蒸馏方法,一方面压缩了模型,另一方面,增强了模型的泛化能力。

知识蒸馏的具体方法

神经网络使用 softmax 层来实现 logits 向类别概率(class probabilities)的转换。

原始的softmax函数: qi=e x p ( zi) ∑je x p ( zj) q_i= \frac{exp(z_i)}{\sum_jexp(z_j)}qi=jexp(zj)exp(zi)
但是直接使用softmax层的输出值作为soft label,会有一个问题:当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此“温度”这个变量就派上了用场。

加上temperature变量之后的softmax函数: qi=e x p ( zi/ T ) ∑je x p ( zj/ T ) q_i= \frac{exp(z_i/T)}{\sum_jexp(z_j/T)}qi=jexp(zj/T)exp(zi/T)。其中, qi q_iqi是每个类别输出的概率, zi z_izi是每个类别输出的logits, TTT是温度。当 T = 1T=1T=1时,就是标准的softmax公式。 TTT越高,softmax 的 output probability distribution 越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

知识蒸馏训练的具体方法如下图所示,主要包括以下几个步骤:
step1:训练好 Teacher 模型
step2:利用高温 T h i g h T_{high}Thigh产生 S o f t − t a r g e tSoft-targetSofttarget
step3:利用 { S o f t − t a r g e tSoft-targetSofttarget, T h i g h T_{high}Thigh} 和 { H a r d − t a r g e tHard-targetHardtarget, T = 1T=1T=1} 同时训练 Student 模型
step4:设置 T = 1T=1T=1, Student 模型线上做推理(inference)


训练Teacher的过程很简单,我们把step2和step3统一称为:高温蒸馏的过程。高温蒸馏过程的目标函数由distill loss(对应Soft-target)和Student loss(对应Hard-target)加权得到,即 L = α L s o f t+ β L h a r d L=αL_{soft}+βL_{hard}L=αLsoft+βLhard 。下面介绍一下具体的损失函数的两个部分:

(1) Teacher模型和Student模型同时输入 transfer set (这里可以直接复用训练Teacher模型用到的training set),用Teacher模型在高温 T h i g h T_{high}Thigh下产生的softmax distribution来作为Soft-target,Student模型在相同温度 T h i g h T_{high}Thigh条件(保证Student Model和Teacher Model的结果尽可能一致)下的softmax输出和Soft-target的cross entropy就是Loss函数的第一部分 L s o f t L_{soft}Lsoft
具体形式为: L s o f t= − ∑iNpiTl o g ( qiT)L_{soft}=-\sum_{i}^Np_i^Tlog(q_i^T)Lsoft=iNpiTlog(qiT) ,其中, piT p_i^TpiT指 Teacher 模型在温度等于 TTT的条件下 softmax 输出在第 iii类上的值。 qiT q_i^TqiT指Student的在温度等于 TTT的条件下softmax输出在第 iii类上的值。
其中, piT=e x p ( vi/ T ) ∑kNe x p ( vk/ T ) p_i^T=\frac{exp(v_i/T)}{\sum_k^Nexp(v_k/T)}piT=kNexp(vk/T)exp(vi/T) qiT=e x p ( zi/ T ) ∑kNe x p ( zj/ T ) q_i^T=\frac{exp(z_i/T)}{\sum_k^Nexp(z_j/T)}qiT=kNexp(zj/T)exp(zi/T)。其中, vi v_ivi指 Teacher 模型的logits, zi z_izi指 Student 模型的logits, NNN指总标签数量。

(2) Student模型在 T = 1T=1T=1的条件下(保证Student Model的结果和实际的类别标签尽可能一致)的softmax输出和ground truth的cross entropy就是Loss函数的第二部分 L h a r d L_{hard}Lhard
具体形式为: L h a r d= − ∑iNcil o g ( qi1)L_{hard}=-\sum_{i}^Nc_ilog(q_i^1)Lhard=iNcilog(qi1) ,其中, ci c_ici指在第 iii类上的ground truth值, ci∈ { 0,1}c_i∈\lbrace{0,1}\rbraceci{0,1},正标签取1,负标签取0。
其中, qi1=e x p ( zi) ∑kNe x p ( zk) q_i^1=\frac{exp(z_i)}{\sum_k^Nexp(z_k)}qi1=kNexp(zk)exp(zi)

第二部分Loss 的必要性其实很好理解:Teacher模型也有一定的错误率,使用 ground truth 可以有效降低错误被传播给Student模型的可能性。打个比喻,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师传授的知识之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。

最后, ααα βββ是关于 L s o f t L_{soft}Lsoft L h a r d L_{hard}Lhard的权重,实验发现,当 L h a r d L_{hard}Lhard权重较小时,能产生最好的效果,这是一个经验性的结论。理论的推导不再给了,这里直接给出结论:由于 L s o f t L_{soft}Lsoft贡献的梯度大约为 L h a r d L_{hard}Lhard 1 T 2 \frac{1}{T^2}T21,因此在同时使用Soft-target和Hard-target的时候,需要在 L s o f t L_{soft}Lsoft的权重上乘以 T2 T^2T2的系数,这样才能保证Soft-target和Hard-target贡献的梯度量基本一致。

关于温度

在知识蒸馏中,需要使用高温将知识“蒸馏”出来,但是如何调节温度呢,温度的变化会产生怎样的影响呢?

温度 TTT有这样几个特点:

  • 原始的softmax函数是T=1T=1 T=1时的特例;T<1T<1 T<1时,概率分布比原始更“陡峭”,也就是说,当T→0T→0 T0时,Softmax 的输出值会接近于 Hard-target; T<1T<1 T<1时,概率分布比原始更“平缓”。
  • 随着TT T的增加,Softmax 的输出分布越来越平缓,信息熵会越来越大。温度越高,softmax上各个值的分布就越平均,思考极端情况,当T=∞T=∞ T= ,此时softmax的值是平均分布的。
  • 不管温度TT T怎么取值,Soft-target都有忽略相对较小的 p ip_i pi(Teacher模型softmax输出在第ii i类上的值)携带的信息的倾向。

温度的高低改变的是Student模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student模型会相对更多地关注到负标签。
实际上,负标签中包含一定的信息,尤其是那些负标签概率值显著高于平均值的负标签。但由于Teacher模型的训练过程决定了负标签部分概率值都比较小,并且负标签的值越低,其信息就越不可靠。因此温度的选取需要进行实际实验的比较,本质上就是在下面两种情况之中取舍:

  • 当想从负标签中学到一些信息量的时候,温度TT T应调高一些
  • 当想减少负标签的干扰的时候,温度TT T应调低一些

总的来说,TT T的选择和Student模型的大小有关,Student模型参数量比较小的时候,相对比较低的温度就可以了因为参数量小的模型不能学到所有Teacher模型的知识,所以可以适当忽略掉一些负标签的信息。

最后,在整个知识蒸馏过程中,我们先让温度 TTT升高,然后在测试阶段恢复“低温”( T = 1T=1T=1 ),从而将原模型中的知识提取出来,因此将其称为是蒸馏。
···················································································································································································

特征蒸馏

另外一种知识蒸馏思路是特征蒸馏方法,如下图所示。它不像Logits方法那样,Student只学习Teacher的Logits这种结果知识,而是学习Teacher网络结构中的中间层特征。最早采用这种模式的工作来自于论文FITNETS:Hints for Thin Deep Nets,它强迫Student某些中间层的网络响应,要去逼近Teacher对应的中间层的网络响应。这种情况下,Teacher中间特征层的响应,就是传递给Student的知识。在此之后,出了各种新方法,但是本质上还是是Teacher将特征级知识迁移给Student。因此,接下来以这篇论文为主,详细介绍特征蒸馏方法的原理。

主要解决的问题

这篇论文指出,既宽又深的模型通常需要大量的乘法运算,从而导致对内存和计算的高需求。因此,即使网络在准确性方面即使是性能最高的模型,其在现实世界中的应用也受到限制。为了解决这类问题,我们需要通过知识蒸馏将知识从复杂的模型转移到参数较少的简单模型。

到目前为止,知识蒸馏技术已经考虑了Student网络与Teacher网络有相同或更小的参数。这里有一个洞察点是,深度是特征学习的基本层面,到目前为止尚未考虑到Student网络的深度。一个具有比Teacher网络更多的层但每层具有较少神经元数量的Student网络称为“thin deep network”。

因此,该篇论文主要针对Hinton提出的知识蒸馏法进行扩展,允许Student网络可以比Teacher网络更深更窄,使用Teacher网络的输出和中间层的特征作为提示,改进训练过程和student网络的性能。

模型结构

Student网络不仅仅拟合Teacher网络的Soft-target,而且拟合隐藏层的输出(Teacher网络抽取的特征)

  • 第一阶段让Student网络去学习Teacher网络的隐藏层输出(特征蒸馏)
  • 第二阶段使用Soft-target来训练Student网络(目标蒸馏)


把“宽”且“深”的网络蒸馏成“瘦”且“更深”的网络,需要进行两阶段的训练:

第一阶段:首先选择待蒸馏的中间层(即Teacher的Hint layer和Student的Guided layer),如图中绿框和红框所示。由于两者的输出尺寸可能不同,因此,在Guided layer后另外接一层卷积层,使得输出尺寸与Teacher的Hint layer匹配。接着通过知识蒸馏的方式训练Student网络的Guided layer,使得Student网络的中间层学习到Teacher的Hint layer的输出。

第二阶段: 在训练好Guided layer之后,将当前的参数作为网络的初始参数,利用知识蒸馏的方式训练Student网络的所有层参数,使Student学习Teacher的输出。