跳转至

6.1 知识蒸馏概念和理论

知识蒸馏的概念介绍


学习目标

  • 理解什么是模型的知识蒸馏.
  • 掌握对模型进行知识蒸馏的代码操作.
  • 掌握知识蒸馏后模型的性能测试.

什么是模型的知识蒸馏

  • 在工业级的应用中, 除了要求模型要有好的预测效果之外, 往往还希望它的"消耗"足够小. 也就是说一般希望部署在线上的应用模型消耗较小的资源. 这些资源包括存储空间, 包括算力.

  • 在传统的深度学习背景下, 如果希望模型的效果足够好, 通常会有两种方案:

    • 使用更大规模的参数.
    • 使用集成模型, 将多个弱模型集成起来.

  • 注意: 上面两种方案往往需要较大的计算资源, 对部署非常不利. 由此产生了模型压缩的动机: 我们希望有一个小模型, 但又能达到大模型一样或相当的效果.

  • 我们可以先训练一个大模型, 然后将其中的经验知识转移给小的模型, 这就是知识蒸馏.
    • 知识蒸馏的概念最早由Hinton在2015年提出, 在2019年后火热起来.
    • 知识蒸馏在目前(2020-2021)已经成为一种既前沿又常用的提高模型泛化能力和部署优势的方法.


知识蒸馏的原理和算法

  • 根据知识蒸馏的定义, 很明显我们手里有两个模型:
    • 一个大模型, 拥有更多的"知识"和"经验", 有更优秀的模型效果.
    • 一个小模型, 知识和经验欠缺, 但是参数量少, 有更优秀的部署效果.

  • 那么很显然的问题就是: 如何将大模型包含的"知识"和"经验"传递给小模型呢?

  • 小模型对应的label是真实标签(ground truth), 将大模型的预测输出的softmax分布作为"知识", 让小模型去学习+匹配这个知识. 经过训练后的大模型, 其预测出来的softmax分布一定包含有很多知识.
    • 真实标签(ground truth)只能告诉我们: 一个样本是一辆奔驰车, 不是一辆宝马车, 也不是一个苹果.
    • 而大模型输出的softmax分布可以告诉我们: 一个样本最可能是一辆奔驰车, 很小的可能是一辆宝马车, 绝不可能是一个苹果.

  • 下图非常直观, 又经典的展示了知识蒸馏的架构图, 相当于有两部分的分支训练:
    • 一部分是大模型的softmax分布作为"知识标签", 让小模型去学习.
    • 一部分是真实label(ground truth)作为"真实标签", 让小样本去匹配.


  • 我们对知识蒸馏进行公式化处理: 先训练好一个精度较高的Teacher网络(一般是复杂度较高的大规模预训练模型), 然后将Teacher网络的预测结果q作为Student网络的"学习目标", 来训练Student网络(一般是速度较快的小规模模型), 最终使得Student网络的结果p接近于q. 损失函数如下:


  • 上图中CE是交叉熵(Cross Entropy), y是真实标签, q是Teacher网络的输出结果, p是Student网络的输出结果.

  • 原始论文中提出了softmax-T公式来计算上图中的q:


  • 上图中qi是Student网络学习的对象, 也就是所谓的软目标(soft targets), zi是神经网络softmax前的输出logits.

  • 不同的温度系数T值, 对softmax-T算法有不同的影响, 总结如下:
    • 如果将T值取1, softmax-T公式就成为softmax公式, 根据logits输出各个类别的概率.
    • 如果T越接近于0, 则最大值会越接近1, 其他值会接近0, 类似于退化成one-hot编码.
    • 如果T越大, 则输出的结果分布越平缓, 相当于标签平滑的原理, 起到保留相似信息的作用.
    • 如果T趋于无穷大, 则演变成均匀分布.

  • 归根结底总结上面的4条规律, 就是当软目标携带的信息量太少(比如在某些类别上的概率非常小, 只有1e-6), 而我们又想放大这些信息, 就可以尝试引入较大的温度参数T, 从而蒸馏出这些小概率值所携带的信息.



常见的知识蒸馏方式介绍

  • 常见的知识蒸馏方法可以按照不同的目标分为如下几类:
    • 模型压缩
    • 同构蒸馏
    • 集成蒸馏
    • 大规模蒸馏

模型压缩

  • 知识蒸馏最常用的目的就是模型压缩, 也就是将复杂网络学习的知识传递给简单网络, 同时保留接近于复杂网络的性能, 从而达到速度和精度的平衡.

同构蒸馏

  • 模型压缩中的Teacher和Student网络模型的结构一般不同, 比如Teacher模型通常使用较复杂的大规模预训练模型, 而Student模型通常使用简单的RNN/CNN模型. 事实上, Tommaso在2018年提出, 如果Teacher和Student的模型结构完全相同, 蒸馏后会对模型的表现有一定程度的提升, 这种方式被称为同构蒸馏.

  • 注意: 既然同构了, 那模型大小的压缩和推理速度的提升肯定不会很大, 因此同构蒸馏要看具体业务场景而应用.

集成蒸馏

  • 集成蒸馏模型又被称为ensemble models.

  • Teacher模型并不局限于单一的复杂的模型, 可以将多个网络学到的知识转移到同一个小模型中, 使得Student模型的性能接近于ensemble的结果. 具体方法时先训练多个Teacher模型, 将结果取平均值或带权重的平均值, 作为软目标(soft targets)供Student模型学习. 此时的集成模型一般会比单一模型效果好.


大规模蒸馏

  • 这里的大规模指大样本, 需要蒸馏出的Student模型往往需要大量的样本训练才能逼近大模型的结果.
    • 如果我们手中有大量与训练数据同分布的日志数据(无标签数据, no-label): 可以先利用Teacher模型给这些无标签数据打上软标签(soft label), 再让Student通过标签平滑技术去学习.
    • 如果我们手中没有同分布的日志数据(无标签数据, no-label): 可以使用EDA等常见的数据增强方法来增加样本的规模, 缓解数据不足的情况下Student模型精度损失较大的问题.