6.1 知识蒸馏概念和理论
知识蒸馏的概念介绍¶
学习目标¶
- 理解什么是模型的知识蒸馏.
- 掌握对模型进行知识蒸馏的代码操作.
- 掌握知识蒸馏后模型的性能测试.
什么是模型的知识蒸馏¶
-
在工业级的应用中, 除了要求模型要有好的预测效果之外, 往往还希望它的"消耗"足够小. 也就是说一般希望部署在线上的应用模型消耗较小的资源. 这些资源包括存储空间, 包括算力.
-
在传统的深度学习背景下, 如果希望模型的效果足够好, 通常会有两种方案:
- 使用更大规模的参数.
- 使用集成模型, 将多个弱模型集成起来.
- 注意: 上面两种方案往往需要较大的计算资源, 对部署非常不利. 由此产生了模型压缩的动机: 我们希望有一个小模型, 但又能达到大模型一样或相当的效果.
- 我们可以先训练一个大模型, 然后将其中的经验知识转移给小的模型, 这就是知识蒸馏.
- 知识蒸馏的概念最早由Hinton在2015年提出, 在2019年后火热起来.
- 知识蒸馏在目前(2020-2021)已经成为一种既前沿又常用的提高模型泛化能力和部署优势的方法.
知识蒸馏的原理和算法¶
- 根据知识蒸馏的定义, 很明显我们手里有两个模型:
- 一个大模型, 拥有更多的"知识"和"经验", 有更优秀的模型效果.
- 一个小模型, 知识和经验欠缺, 但是参数量少, 有更优秀的部署效果.
- 那么很显然的问题就是: 如何将大模型包含的"知识"和"经验"传递给小模型呢?
- 小模型对应的label是真实标签(ground truth), 将大模型的预测输出的softmax分布作为"知识", 让小模型去学习+匹配这个知识. 经过训练后的大模型, 其预测出来的softmax分布一定包含有很多知识.
- 真实标签(ground truth)只能告诉我们: 一个样本是一辆奔驰车, 不是一辆宝马车, 也不是一个苹果.
- 而大模型输出的softmax分布可以告诉我们: 一个样本最可能是一辆奔驰车, 很小的可能是一辆宝马车, 绝不可能是一个苹果.
- 下图非常直观, 又经典的展示了知识蒸馏的架构图, 相当于有两部分的分支训练:
- 一部分是大模型的softmax分布作为"知识标签", 让小模型去学习.
- 一部分是真实label(ground truth)作为"真实标签", 让小样本去匹配.
- 我们对知识蒸馏进行公式化处理: 先训练好一个精度较高的Teacher网络(一般是复杂度较高的大规模预训练模型), 然后将Teacher网络的预测结果q作为Student网络的"学习目标", 来训练Student网络(一般是速度较快的小规模模型), 最终使得Student网络的结果p接近于q. 损失函数如下:
- 上图中CE是交叉熵(Cross Entropy), y是真实标签, q是Teacher网络的输出结果, p是Student网络的输出结果.
- 原始论文中提出了softmax-T公式来计算上图中的q:
- 上图中qi是Student网络学习的对象, 也就是所谓的软目标(soft targets), zi是神经网络softmax前的输出logits.
- 不同的温度系数T值, 对softmax-T算法有不同的影响, 总结如下:
- 如果将T值取1, softmax-T公式就成为softmax公式, 根据logits输出各个类别的概率.
- 如果T越接近于0, 则最大值会越接近1, 其他值会接近0, 类似于退化成one-hot编码.
- 如果T越大, 则输出的结果分布越平缓, 相当于标签平滑的原理, 起到保留相似信息的作用.
- 如果T趋于无穷大, 则演变成均匀分布.
- 归根结底总结上面的4条规律, 就是当软目标携带的信息量太少(比如在某些类别上的概率非常小, 只有1e-6), 而我们又想放大这些信息, 就可以尝试引入较大的温度参数T, 从而蒸馏出这些小概率值所携带的信息.
常见的知识蒸馏方式介绍¶
- 常见的知识蒸馏方法可以按照不同的目标分为如下几类:
- 模型压缩
- 同构蒸馏
- 集成蒸馏
- 大规模蒸馏
模型压缩¶
- 知识蒸馏最常用的目的就是模型压缩, 也就是将复杂网络学习的知识传递给简单网络, 同时保留接近于复杂网络的性能, 从而达到速度和精度的平衡.
同构蒸馏¶
- 模型压缩中的Teacher和Student网络模型的结构一般不同, 比如Teacher模型通常使用较复杂的大规模预训练模型, 而Student模型通常使用简单的RNN/CNN模型. 事实上, Tommaso在2018年提出, 如果Teacher和Student的模型结构完全相同, 蒸馏后会对模型的表现有一定程度的提升, 这种方式被称为同构蒸馏.
- 注意: 既然同构了, 那模型大小的压缩和推理速度的提升肯定不会很大, 因此同构蒸馏要看具体业务场景而应用.
集成蒸馏¶
-
集成蒸馏模型又被称为ensemble models.
-
Teacher模型并不局限于单一的复杂的模型, 可以将多个网络学到的知识转移到同一个小模型中, 使得Student模型的性能接近于ensemble的结果. 具体方法时先训练多个Teacher模型, 将结果取平均值或带权重的平均值, 作为软目标(soft targets)供Student模型学习. 此时的集成模型一般会比单一模型效果好.
大规模蒸馏¶
- 这里的大规模指大样本, 需要蒸馏出的Student模型往往需要大量的样本训练才能逼近大模型的结果.
- 如果我们手中有大量与训练数据同分布的日志数据(无标签数据, no-label): 可以先利用Teacher模型给这些无标签数据打上软标签(soft label), 再让Student通过标签平滑技术去学习.
- 如果我们手中没有同分布的日志数据(无标签数据, no-label): 可以使用EDA等常见的数据增强方法来增加样本的规模, 缓解数据不足的情况下Student模型精度损失较大的问题.