
6.2 Optimization with Knowledge Distillation

Applying Knowledge Distillation in the Project


Learning Objectives

  • Master the code workflow for knowledge distillation.
  • Master performance testing of the distilled model.

Master the code workflow for knowledge distillation

  • Applying knowledge distillation in this project takes the following steps:
    • Step 1: Inspect the project dataset
    • Step 2: Inspect the pre-trained model files
    • Step 3: Write the utility functions
    • Step 4: Write the two model classes
    • Step 5: Write the training, test, and evaluation functions
    • Step 6: Write the main run script

Step 1: Inspect the project dataset

  • The dataset lives at /home/ec2-user/toutiao/bert_distil/data/data/

  • It consists of 5 files; let's look at them in turn:

  • Label file /home/ec2-user/toutiao/bert_distil/data/data/class.txt


finance
realty
stocks
education
science
society
politics
sports
game
entertainment

  • class.txt contains 10 category labels, one per line, written as English words.

  • Training set /home/ec2-user/toutiao/bert_distil/data/data/train.txt

中华女子学院:本科层次仅1专业招男生     3
两天价网站背后重重迷雾:做个网站究竟要多少钱    4
东5环海棠公社230-290平2居准现房98折优惠 1
卡佩罗:告诉你德国脚生猛的原因 不希望英德战踢点球       7
82岁老太为学生做饭扫地44年获授港大荣誉院士      5
记者回访地震中可乐男孩:将受邀赴美国参观        5
冯德伦徐若�隔空传情 默认其是女友        9
传郭晶晶欲落户香港战伦敦奥运 装修别墅当婚房     1
《赤壁OL》攻城战诸侯战硝烟又起  8
“手机钱包”亮相科博会    4

  • train.txt contains 180000 samples, one per line; each line has two columns separated by \t: the Chinese text to classify and its numeric label.

  • Validation set /home/ec2-user/toutiao/bert_distil/data/data/dev.txt

体验2D巅峰 倚天屠龙记十大创新概览       8
60年铁树开花形状似玉米芯(组图)  5
同步A股首秀:港股缩量回调       2
中青宝sg现场抓拍 兔子舞热辣表演 8
锌价难续去年辉煌        0
2岁男童爬窗台不慎7楼坠下获救(图)        5
布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击 7
金科西府 名墅天成       1
状元心经:考前一周重点是回顾和整理      3
发改委治理涉企收费每年为企业减负超百亿  6

  • dev.txt contains 10000 samples, one per line, in the same two-column \t-separated format.

  • Test set /home/ec2-user/toutiao/bert_distil/data/data/test.txt

词汇阅读是关键 08年考研暑期英语复习全指南       3
中国人民公安大学2012年硕士研究生目录及书目      3
日本地震:金吉列关注在日学子系列报道    3
名师辅导:2012考研英语虚拟语气三种用法  3
自考经验谈:自考生毕业论文选题技巧      3
本科未录取还有这些路可以走      3
2009年成人高考招生统一考试时间表        3
去新西兰体验舌尖上的饕餮之旅(组图)      3
四级阅读与考研阅读比较分析与应试策略    3
备考2012高考作文必读美文50篇(一)        3

  • test.txt contains 10000 samples, one per line, in the same two-column \t-separated format.

  • The vocabulary file /home/ec2-user/toutiao/bert_distil/data/data/vocab.pkl is a binary (not human-readable) file used during training.
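  • A quick way to sanity-check the splits and peek into vocab.pkl is sketched below (paths as above; this snippet is illustrative and not part of the project code):

import pickle as pkl

data_dir = "/home/ec2-user/toutiao/bert_distil/data/data/"

# count samples per split and collect the set of numeric labels
for name in ("train.txt", "dev.txt", "test.txt"):
    with open(data_dir + name, encoding="utf-8") as f:
        rows = [line.rstrip("\n").split("\t") for line in f if line.strip()]
    print(name, len(rows), "samples, labels:", sorted({int(row[-1]) for row in rows}))

# vocab.pkl is a pickled dict mapping characters to integer ids
vocab = pkl.load(open(data_dir + "vocab.pkl", "rb"))
print("vocab size:", len(vocab))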


Step 2: Inspect the pre-trained model files

  • The pre-trained model files live in /home/ec2-user/toutiao/bert_distil/data/bert_pretrain/

  • There are 3 files in total:

  • BERT hyperparameter configuration file /home/ec2-user/toutiao/bert_distil/data/bert_pretrain/bert_config.json


{
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
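  • The fields above can be read back programmatically; a minimal sketch (standard json module, path as above, purely illustrative) that checks a few of them:

import json

with open("/home/ec2-user/toutiao/bert_distil/data/bert_pretrain/bert_config.json", encoding="utf-8") as f:
    bert_cfg = json.load(f)

# hidden_size, num_hidden_layers and vocab_size are the values the Config class below picks up via BertConfig
print(bert_cfg["hidden_size"], bert_cfg["num_hidden_layers"], bert_cfg["vocab_size"])  # 768 12 21128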

  • BERT pre-trained model weights /home/ec2-user/toutiao/bert_distil/data/bert_pretrain/pytorch_model.bin

-rw-r--r-- 1 root root 411578458 1月   9 11:50 pytorch_model.bin

  • BERT pre-trained model vocabulary /home/ec2-user/toutiao/bert_distil/data/bert_pretrain/vocab.txt

[PAD]
[unused1]
[unused2]
[unused3]
[unused4]
[unused5]
[unused6]
[unused7]
[unused8]
[unused9]
[unused10]

......
......
......

[unused98]
[unused99]
[UNK]
[CLS]
[SEP]
[MASK]
<S>
<T>
!
......
......
......


Step 3: Write the utility functions

  • The utility functions live in /home/ec2-user/toutiao/bert_distil/src/utils.py

  • The first utility function is build_vocab(), a standalone function in utils.py. The module-level imports and constants it relies on (UNK, PAD, MAX_VOCAB_SIZE) are shown at the top of the excerpt.

# module-level imports and constants used throughout utils.py (shown so the excerpts are self-contained)
import os
import time
import pickle as pkl
from datetime import timedelta

import torch
from tqdm import tqdm

MAX_VOCAB_SIZE = 10000  # vocabulary size cap (value assumed; not shown in the original excerpt)
UNK, PAD = "<UNK>", "<PAD>"  # unknown token and padding token


def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, "r", encoding="UTF-8") as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content = lin.split("\t")[0]
            for word in tokenizer(content):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
        vocab_list = sorted(
                [_ for _ in vocab_dic.items() if _[1]>=min_freq],key=lambda x:x[1],reverse=True)[:max_size]
        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic
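  • A hypothetical usage sketch of build_vocab(); the tokenizer and arguments mirror how build_dataset_CNN() below calls it:

char_tokenizer = lambda x: [y for y in x]  # split text into characters
vocab = build_vocab("/home/ec2-user/toutiao/bert_distil/data/data/train.txt",
                    tokenizer=char_tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
print(len(vocab), vocab[UNK], vocab[PAD])  # vocabulary size plus the ids of the UNK and PAD tokens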

  • The second utility function, build_dataset_CNN(), is also a standalone function in utils.py.

def build_dataset_CNN(config):
    tokenizer = lambda x: [y for y in x]  # char-level
    if os.path.exists(config.vocab_path):
        vocab = pkl.load(open(config.vocab_path, "rb"))
    else:
        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(vocab, open(config.vocab_path, "wb"))
    print(f"Vocab size: {len(vocab)}")

    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, "r", encoding="UTF-8") as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                content, label = lin.split("\t")
                words_line = []
                token = tokenizer(content)
                seq_len = len(token)
                if pad_size:
                    if len(token) < pad_size:
                        token.extend([PAD] * (pad_size - len(token)))
                    else:
                        token = token[:pad_size]
                        seq_len = pad_size
                # word to id
                for word in token:
                    words_line.append(vocab.get(word, vocab.get(UNK)))
                contents.append((words_line, int(label), seq_len))
        return contents  # [([...token ids...], label, seq_len), ...]
    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return vocab, train, dev, test

  • The third utility function, build_dataset(), is also a standalone function in utils.py (this one prepares data for the BERT model).
def build_dataset(config):
    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, "r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                content, label = line.split("\t")
                token = config.tokenizer.tokenize(content)
                token = [CLS] + token
                seq_len = len(token)
                mask = []
                token_ids = config.tokenizer.convert_tokens_to_ids(token)

                if pad_size:
                    if len(token) < pad_size:
                        mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                        token_ids += [0] * (pad_size - len(token))
                    else:
                        mask = [1] * pad_size
                        token_ids = token_ids[:pad_size]
                        seq_len = pad_size
                contents.append((token_ids, int(label), seq_len, mask))
        return contents

    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return train, dev, test

  • The fourth utility is build_iterator(), together with the data iterator class DatasetIterater(); both live in utils.py.
class DatasetIterater(object):
    def __init__(self, batches, batch_size, device, model_name):
        self.batch_size = batch_size
        self.batches = batches
        self.model_name = model_name
        self.n_batches = len(batches) // batch_size
        self.residue = False  # whether a final, smaller batch is left over
        if len(batches) % self.batch_size != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)

        # length before padding (clipped to pad_size if longer)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        if self.model_name == "bert" or self.model_name == "multi_task_bert":
            mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
            return (x, seq_len, mask), y
        if self.model_name == "textCNN":
            return (x, seq_len), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size : len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size : (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


def build_iterator(dataset, config):
    iter = DatasetIterater(dataset, config.batch_size, config.device, config.model_name)
    return iter
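  • A minimal, self-contained sketch of how DatasetIterater batches textCNN-style samples; the tiny fake dataset below stands in for load_dataset() output and is for illustration only:

import torch

fake_data = [([1, 2, 3, 0], 0, 3), ([4, 5, 0, 0], 1, 2), ([6, 7, 8, 9], 2, 4)]  # (token_ids, label, seq_len)
it = DatasetIterater(fake_data, batch_size=2, device=torch.device("cpu"), model_name="textCNN")
(x, seq_len), y = next(iter(it))
print(x.shape, seq_len, y)  # torch.Size([2, 4]) tensor([3, 2]) tensor([0, 1])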

  • The fifth utility function, get_time_dif(), is also a standalone function in utils.py.
def get_time_dif(start_time):
    # elapsed time since start_time, rounded to whole seconds
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


Step 4: Write the two model classes

  • The Teacher model is BERT.
    • Code location: /home/ec2-user/toutiao/bert_distil/src/models/bert.py

  • The Teacher model's Config class, in bert.py:
# coding: UTF-8
import torch
import torch.nn as nn
import os
from transformers import BertModel, BertTokenizer, BertConfig

class Config(object):
    def __init__(self, dataset):
        self.model_name = "bert"
        self.data_path = "/home/ec2-user/toutiao/bert_distil/data/data/"
        self.train_path = self.data_path + "train.txt"  # training set
        self.dev_path = self.data_path + "dev.txt"  # validation set
        self.test_path = self.data_path + "test.txt"  # test set
        self.class_list = [
            x.strip() for x in open(self.data_path + "class.txt").readlines()
        ]  # list of class names
        self.save_path = '/home/ec2-user/toutiao/bert_distil/src/saved_dict'
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)
        self.save_path += "/" + self.model_name + ".pt"  # path of the trained model checkpoint
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device

        self.require_improvement = 1000  # early stop if no improvement after this many batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded/truncated to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = "/home/ec2-user/toutiao/bert_distil/data/bert_pretrain"
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.bert_config = BertConfig.from_pretrained(self.bert_path + '/bert_config.json')
        self.hidden_size = 768

  • The Teacher model's Model class, in bert.py:
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path, config=config.bert_config)

        for param in self.bert.parameters():
            param.requires_grad = True

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        # the input sentence (token ids)
        context = x[0]
        # attention mask over the padding, same size as the sentence, with 0 marking padding, e.g. [1, 1, 1, 1, 0, 0]
        mask = x[2]
        _, pooled = self.bert(context, attention_mask=mask)
        out = self.fc(pooled)
        return out

  • The Student model is a textCNN.
    • Code location: /home/ec2-user/toutiao/bert_distil/src/models/textCNN.py

  • The textCNN architecture (diagram omitted here): character embeddings are convolved with several kernel sizes, max-pooled over time, concatenated, passed through dropout, and fed to a linear classifier.

  • The Student model's Config class, in textCNN.py:
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import os


class Config(object):
    def __init__(self, dataset):
        self.model_name = "textCNN"
        self.data_path = "/home/ec2-user/toutiao/bert_distil/data/data/"
        self.train_path = self.data_path + "train.txt"  # training set
        self.dev_path = self.data_path + "dev.txt"  # validation set
        self.test_path = self.data_path + "test.txt"  # test set
        self.class_list = [x.strip() for x in open(self.data_path+"class.txt", encoding="utf-8").readlines()]
        self.vocab_path = self.data_path + "vocab.pkl"  # vocabulary
        self.save_path = "/home/ec2-user/toutiao/bert_distil/src/saved_dict"
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)
        self.save_path += "/" + self.model_name + ".pt"  # path of the trained model checkpoint
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device

        self.dropout = 0.5  # dropout rate
        self.require_improvement = 1000  # early stop if no improvement after this many batches
        self.num_classes = len(self.class_list)  # number of classes
        self.n_vocab = 0  # vocabulary size, set at runtime
        self.num_epochs = 50  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded/truncated to this length
        self.learning_rate = 1e-3  # learning rate
        self.embed = 300  # character embedding dimension
        self.filter_sizes = (2, 3, 4)  # convolution kernel sizes
        self.num_filters = 512  # number of kernels (output channels) per size

  • The Student model's Model class, in textCNN.py (a small shape-check sketch follows the class):
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes]
        )
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        out = self.embedding(x[0])
        out = out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
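  • A quick shape check of the Student model, using a toy config via SimpleNamespace (purely illustrative; the real values come from textCNN.Config):

from types import SimpleNamespace
import torch

toy = SimpleNamespace(n_vocab=100, embed=8, num_filters=4, filter_sizes=(2, 3, 4),
                      dropout=0.5, num_classes=10)
toy_model = Model(toy)
batch = (torch.randint(0, 99, (2, 32)), torch.tensor([32, 32]))  # (token_ids, seq_len), batch of 2
print(toy_model(batch).shape)  # torch.Size([2, 10]) -> one logit per class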


Step 5: Write the training, test, and evaluation functions

  • Before writing the training functions, the key points are to fix how the loss is computed (distillation architecture diagram omitted here) and to obtain the Teacher network's outputs as soft targets. The work breaks down as follows:
    • Import the required packages.
    • Write the function that collects the Teacher network's outputs.
    • Write the loss functions.
    • Write the training function for the Teacher model.
    • Write the knowledge-distillation training function.
    • Write the test function.
    • Write the evaluation function.

  • Code location: /home/ec2-user/toutiao/bert_distil/src/train_eval.py

  • Import the required packages:

# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from transformers.optimization import AdamW
from tqdm import tqdm
import math
import logging

  • Write the function that collects the Teacher network's outputs:

  • Since BERT is the Teacher model, we run it over the entire training set and store its predictions in a list up front. These predictions are the soft targets that later serve as "learning labels" for the Student model.

  • At the code level, note that the Teacher and Student models use different data iterators; their batch_size and sample order must stay identical, otherwise the training samples and the cached soft targets will not line up! (A small alignment check is sketched after the function below.)
def fetch_teacher_outputs(teacher_model, train_iter):
    teacher_model.eval()
    teacher_outputs = []

    with torch.no_grad():
        for i, (data_batch, labels_batch) in enumerate(train_iter):
            outputs = teacher_model(data_batch)
            teacher_outputs.append(outputs)

    return teacher_outputs
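  • An optional alignment check (a sketch, not part of the project code): the BERT iterator and the textCNN iterator must yield the same number of batches with identical labels in the same order, otherwise teacher_outputs[i] would be paired with the wrong student batch.

def check_alignment(bert_train_iter, cnn_train_iter):
    assert len(bert_train_iter) == len(cnn_train_iter)
    for (_, bert_labels), (_, cnn_labels) in zip(bert_train_iter, cnn_train_iter):
        # both iterators were built from the same train.txt with the same batch_size,
        # so the label tensors of corresponding batches must be identical
        assert torch.equal(bert_labels, cnn_labels)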

  • Write the loss functions for knowledge distillation:

  • The hard-label part is the usual cross-entropy loss. One thing to note: F.cross_entropy() expects hard labels given as class indices, whereas the Teacher network's soft targets are probability distributions, so it cannot be applied to the soft targets directly.

  • We therefore use KL divergence as the loss on the soft targets; PyTorch's KL-divergence loss accepts a probability distribution as the target.

  • Regarding the temperature T, the original paper states:

Since the magnitudes of the gradients produced by the soft targets scale as 1/(T*T), it is important to multiply them by T*T when using both hard and soft targets. This ensures that the relative contributions of the hard and soft targets remain roughly unchanged if the temperature used for distillation is changed while experimenting with meta-parameters.


  • In other words, with temperature T the gradients produced by the soft targets are only 1/T² the size of those produced by the hard targets, so the soft-target loss must be multiplied by T² after it is computed. (Readers interested in the full derivation can consult the original paper.)
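  • A small numerical check of the 1/T² claim (illustrative values only): the gradient of the soft-target loss with respect to the student logits shrinks by roughly T² as T grows, which is exactly what the T*T factor in the code below compensates for.

import torch
import torch.nn.functional as F

student_logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
teacher_logits = torch.tensor([[1.0, 1.5, -0.5]])

def soft_grad_norm(T):
    # KL divergence between the teacher's and student's temperature-softened distributions
    loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean")
    (grad,) = torch.autograd.grad(loss, student_logits)
    return grad.abs().sum().item()

for T in (1, 2, 4, 8):
    print(T, soft_grad_norm(T))  # the gradient norm drops by roughly 4x each time T doubles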

  • The KL-divergence-based loss function with temperature T:
criterion = nn.KLDivLoss()


def loss_fn(outputs, labels):
    # Note: instantiate the loss class first and then call it with the arguments;
    # writing nn.CrossEntropyLoss(outputs, labels) without the instantiating () is a common mistake.
    return nn.CrossEntropyLoss()(outputs, labels)


# distillation loss based on KL divergence
def loss_fn_kd(outputs, labels, teacher_outputs):
    # Note: PyTorch's nn.KLDivLoss expects the student input as log-probabilities and the soft target
    # as plain probabilities, i.e. the first argument must be log-transformed and the second must not be.
    alpha = 0.8
    T = 2

    # soft-target loss
    # first compute the student network's log_softmax output distribution with temperature T
    output_student = F.log_softmax(outputs / T, dim=1)

    # then compute the teacher network's softmax output distribution with temperature T
    output_teacher = F.softmax(teacher_outputs / T, dim=1)

    # soft-target loss via KLDivLoss(): first argument is the student output, second the teacher output
    soft_loss = criterion(output_student, output_teacher)

    # hard-target loss
    # the loss between the student's outputs and the true labels; the labels are class indices,
    # so plain cross entropy can be used directly
    hard_loss = F.cross_entropy(outputs, labels)

    # total loss
    # as shown in the original paper, temperature T scales the soft-target gradients to 1/(T*T)
    # of the hard-target gradients, so the soft loss is multiplied by T^2
    KD_loss = soft_loss * alpha * T * T + hard_loss * (1.0 - alpha)

    return KD_loss
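  • A hypothetical call of loss_fn_kd() with dummy tensors, just to make the argument order and shapes explicit: student outputs and teacher outputs are [batch_size, num_classes] logits, labels are class indices.

dummy_student = torch.randn(4, 10)
dummy_teacher = torch.randn(4, 10)
dummy_labels = torch.randint(0, 10, (4,))
print(loss_fn_kd(dummy_student, dummy_labels, dummy_teacher))  # a scalar loss tensor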

  • The training function for the Teacher model:
def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float("inf")
    last_improve = 0  # batch count at the last improvement of the validation loss
    flag = False  # whether training has gone a long time without improvement

    model.train()
    for epoch in range(config.num_epochs):
        print("Epoch [{}/{}]".format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(tqdm(train_iter)):
            model.zero_grad()
            outputs = model(trains)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            if total_batch % 100 == 0:
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = "*"
                    last_improve = total_batch
                else:
                    improve = ""
                time_dif = get_time_dif(start_time)
                msg = "Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}"
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # stop training if the validation loss has not improved for more than 1000 batches
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(config, model, test_iter)

  • The knowledge-distillation training function:
def train_kd(bert_config, cnn_config, bert_model, cnn_model,
             bert_train_iter, cnn_train_iter, cnn_dev_iter, cnn_test_iter):
    start_time = time.time()
    param_optimizer = list(cnn_model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                    "weight_decay": 0.01
                },
                {
                    "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0
                }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=cnn_config.learning_rate)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float("inf")
    last_improve = 0  # batch count at the last improvement of the validation loss
    flag = False  # whether training has gone a long time without improvement

    cnn_model.train()
    loading_start = time.time()
    bert_model.eval()
    teacher_outputs = fetch_teacher_outputs(bert_model, bert_train_iter)
    elapsed_time = math.ceil(time.time() - loading_start)
    logging.info("- Finished computing teacher outputs after {} secs..".format(elapsed_time))

    for epoch in range(cnn_config.num_epochs):
        print("Epoch [{}/{}]".format(epoch + 1, cnn_config.num_epochs))
        for i, (trains, labels) in enumerate(tqdm(cnn_train_iter)):
            cnn_model.zero_grad()
            outputs = cnn_model(trains)
            loss = loss_fn_kd(outputs, labels, teacher_outputs[i])
            loss.backward()
            optimizer.step()

            if total_batch % 100 == 0:
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(cnn_config, cnn_model, cnn_dev_iter)

                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(cnn_model.state_dict(), cnn_config.save_path)
                    improve = "*"
                    last_improve = total_batch
                else:
                    improve = ""
                time_dif = get_time_dif(start_time)
                msg = "Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}"
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                cnn_model.train()
            total_batch += 1
            if total_batch - last_improve > cnn_config.require_improvement:
                # stop training if the validation loss has not improved for more than 1000 batches
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(cnn_config, cnn_model, cnn_test_iter)

  • The test function:
def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = "Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}"
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

  • The evaluation function:
def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)

    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all,predict_all,target_names=config.class_list,digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)


Step 6: Write the main run script

  • The main script lives at /home/ec2-user/toutiao/bert_distil/src/run.py
# coding: UTF-8
import numpy as np
import torch
import time
from train_eval import train_kd, train
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_dif, build_dataset_CNN

parser = argparse.ArgumentParser(description="Chinese Text Classification")
parser.add_argument("--task", type=str, required=True, help="choose a task: trainbert, or train_kd")
args = parser.parse_args()

if __name__ == "__main__":
    dataset = "toutiao"

    if args.task == "trainbert":
        model_name = "bert"
        x = import_module("models." + model_name)
        config = x.Config(dataset)
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)
        torch.backends.cudnn.deterministic = True  # make results reproducible across runs

        print("Loading data for Bert Model...")
        train_data, dev_data, test_data = build_dataset(config)
        train_iter = build_iterator(train_data, config)
        dev_iter = build_iterator(dev_data, config)
        test_iter = build_iterator(test_data, config)

        model = x.Model(config).to(config.device)
        train(config, model, train_iter, dev_iter, test_iter)

    if args.task == "train_kd":
        model_name = "bert"
        bert_module = import_module("models." + model_name)
        bert_config = bert_module.Config(dataset)

        model_name = "textCNN"
        cnn_module = import_module("models." + model_name)
        cnn_config = cnn_module.Config(dataset)

        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)
        torch.backends.cudnn.deterministic = True  # make results reproducible across runs

        # build the BERT dataset; only the training-set predictions are needed as soft targets,
        # so dev_iter and test_iter are not required here
        bert_train_data, _, _ = build_dataset(bert_config)
        bert_train_iter = build_iterator(bert_train_data, bert_config)

        # build the CNN dataset
        vocab, cnn_train_data, cnn_dev_data, cnn_test_data = build_dataset_CNN(cnn_config)
        cnn_train_iter = build_iterator(cnn_train_data, cnn_config)
        cnn_dev_iter = build_iterator(cnn_dev_data, cnn_config)
        cnn_test_iter = build_iterator(cnn_test_data, cnn_config)
        cnn_config.n_vocab = len(vocab)

        print("Data loaded, now load teacher model")
        # 加载训练好的teacher模型
        bert_model = bert_module.Model(bert_config).to(bert_config.device)

        # 加载student模型
        cnn_model = cnn_module.Model(cnn_config).to(cnn_config.device)

        print("Teacher and student models loaded, start training")
        train_kd(bert_config, cnn_config, bert_model, cnn_model,
                 bert_train_iter, cnn_train_iter, cnn_dev_iter, cnn_test_iter)

  • Step 1: Train the Teacher model (BERT)
# change to the directory containing the main training script
cd /home/ec2-user/toutiao/bert_distil/src/

# run the Teacher model training from the command line
python run.py --task trainbert

  • Output:
Loading data for Bert Model...
180000it [00:37, 4820.80it/s]
10000it [00:02, 4954.00it/s]
10000it [00:02, 4952.50it/s]
Epoch [1/3]
 14%|█████████▉                                                            | 200/1407 [02:06<13:26,  1.50it/s]Iter:    200,  Train Loss:   0.3,  Train Acc: 91.41%,  Val Loss:  0.29,  Val Acc: 90.86%,  Time: 0:02:26 *
 28%|███████████████████▉                                                  | 400/1407 [04:44<11:46,  1.43it/s]Iter:    400,  Train Loss:  0.34,  Train Acc: 90.62%,  Val Loss:  0.26,  Val Acc: 92.10%,  Time: 0:05:07 *
 43%|█████████████████████████████▊                                        | 600/1407 [07:26<09:25,  1.43it/s]Iter:    600,  Train Loss:  0.29,  Train Acc: 91.41%,  Val Loss:  0.25,  Val Acc: 92.10%,  Time: 0:07:49 *
 57%|███████████████████████████████████████▊                              | 800/1407 [10:08<07:06,  1.42it/s]Iter:    800,  Train Loss:  0.15,  Train Acc: 94.53%,  Val Loss:  0.22,  Val Acc: 92.85%,  Time: 0:10:31 *
 71%|█████████████████████████████████████████████████                    | 1000/1407 [12:50<04:43,  1.44it/s]Iter:   1000,  Train Loss:  0.17,  Train Acc: 94.53%,  Val Loss:  0.22,  Val Acc: 93.00%,  Time: 0:13:10 
No optimization for a long time, auto-stopping...
Test Loss:   0.2,  Test Acc: 93.64%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9246    0.9320    0.9283      1000
       realty     0.9484    0.9370    0.9427      1000
       stocks     0.8787    0.8980    0.8882      1000
    education     0.9511    0.9730    0.9619      1000
      science     0.9236    0.8950    0.9091      1000
      society     0.9430    0.9270    0.9349      1000
     politics     0.9267    0.9100    0.9183      1000
       sports     0.9780    0.9780    0.9780      1000
         game     0.9514    0.9600    0.9557      1000
entertainment     0.9390    0.9540    0.9464      1000

     accuracy                         0.9364     10000
    macro avg     0.9365    0.9364    0.9364     10000
 weighted avg     0.9365    0.9364    0.9364     10000

Confusion Matrix...
[[932  10  37   2   5   5   7   1   1   0]
 [ 13 937  11   2   4  10   5   5   5   8]
 [ 49  12 898   1  19   1  15   0   2   3]
 [  1   1   0 973   0   8   7   0   1   9]
 [  4   4  28   7 895  10  12   2  27  11]
 [  2   8   4  16   5 927  18   1   5  14]
 [  3   8  34  12   9  19 910   0   0   5]
 [  2   3   2   1   1   1   4 978   1   7]
 [  0   2   4   0  24   1   3   1 960   5]
 [  2   3   4   9   7   1   1  12   7 954]]
Time usage: 0:00:19
 71%|█████████████████████████████████████████████████                    | 1000/1407 [13:29<05:29,  1.24it/s]

  • Conclusion: the Teacher model reaches Test Acc: 93.64% on the test set.

  • Step 2: Train the Student model (with knowledge distillation)

  • Set the key parameters in Config as follows:

# train for 3 epochs
self.num_epochs = 3

# kernel sizes 2, 3 and 4
self.filter_sizes = (2, 3, 4)

# 512 kernels per size
self.num_filters = 512

  • Run:
# change to the directory containing the main training script
cd /home/ec2-user/toutiao/bert_distil/src/

# run the Student model training from the command line
python run.py --task train_kd

  • Output:
180000it [00:37, 4862.22it/s]
10000it [00:02, 4988.47it/s]
10000it [00:02, 4981.50it/s]
Vocab size: 4762
180000it [00:02, 69598.12it/s]
10000it [00:00, 82889.25it/s]
10000it [00:00, 82326.33it/s]
Data loaded, now load teacher model
Teacher and student models loaded, start training
Epoch [1/20]
 14%|█████████▉                                                            | 199/1407 [00:08<00:50, 23.87it/s]Iter:    200,  Train Loss:  0.29,  Train Acc: 69.53%,  Val Loss:  0.85,  Val Acc: 82.36%,  Time: 0:05:32 *
 28%|███████████████████▉                                                  | 400/1407 [00:17<00:42, 23.95it/s]Iter:    400,  Train Loss:  0.27,  Train Acc: 73.44%,  Val Loss:  0.81,  Val Acc: 84.00%,  Time: 0:05:40 *
 43%|█████████████████████████████▊                                        | 598/1407 [00:25<00:33, 23.86it/s]Iter:    600,  Train Loss:  0.24,  Train Acc: 83.59%,  Val Loss:  0.76,  Val Acc: 85.97%,  Time: 0:05:49 *
 57%|███████████████████████████████████████▊                              | 799/1407 [00:34<00:25, 23.91it/s]Iter:    800,  Train Loss:  0.23,  Train Acc: 83.59%,  Val Loss:  0.76,  Val Acc: 85.49%,  Time: 0:05:58 
 71%|█████████████████████████████████████████████████                    | 1000/1407 [00:43<00:17, 23.89it/s]Iter:   1000,  Train Loss:  0.21,  Train Acc: 84.38%,  Val Loss:  0.74,  Val Acc: 85.94%,  Time: 0:06:07 *
 85%|██████████████████████████████████████████████████████████▊          | 1198/1407 [00:52<00:08, 23.80it/s]Iter:   1200,  Train Loss:  0.22,  Train Acc: 85.94%,  Val Loss:  0.72,  Val Acc: 86.92%,  Time: 0:06:16 *
 99%|████████████████████████████████████████████████████████████████████▌| 1399/1407 [01:01<00:00, 23.85it/s]Iter:   1400,  Train Loss:  0.24,  Train Acc: 79.69%,  Val Loss:  0.72,  Val Acc: 86.87%,  Time: 0:06:24 *
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [01:01<00:00, 22.73it/s]
Epoch [2/20]
 14%|█████████▊                                                            | 198/1407 [00:08<00:50, 23.95it/s]Iter:    200,  Train Loss:  0.23,  Train Acc: 85.16%,  Val Loss:   0.7,  Val Acc: 88.34%,  Time: 0:06:33 *
 28%|███████████████████▊                                                  | 399/1407 [00:17<00:42, 23.92it/s]Iter:    400,  Train Loss:  0.23,  Train Acc: 82.81%,  Val Loss:  0.68,  Val Acc: 88.36%,  Time: 0:06:42 *
 43%|█████████████████████████████▊                                        | 600/1407 [00:25<00:33, 24.06it/s]Iter:    600,  Train Loss:   0.2,  Train Acc: 91.41%,  Val Loss:  0.68,  Val Acc: 88.26%,  Time: 0:06:51 *
 57%|███████████████████████████████████████▋                              | 798/1407 [00:34<00:25, 23.98it/s]Iter:    800,  Train Loss:  0.21,  Train Acc: 87.50%,  Val Loss:  0.67,  Val Acc: 88.83%,  Time: 0:07:00 *
 71%|█████████████████████████████████████████████████▋                    | 999/1407 [00:43<00:17, 23.94it/s]Iter:   1000,  Train Loss:  0.19,  Train Acc: 91.41%,  Val Loss:  0.68,  Val Acc: 88.52%,  Time: 0:07:09 
 85%|██████████████████████████████████████████████████████████▊          | 1200/1407 [00:52<00:08, 24.00it/s]Iter:   1200,  Train Loss:   0.2,  Train Acc: 88.28%,  Val Loss:  0.67,  Val Acc: 89.07%,  Time: 0:07:17 *
 99%|████████████████████████████████████████████████████████████████████▌| 1398/1407 [01:00<00:00, 23.81it/s]Iter:   1400,  Train Loss:  0.21,  Train Acc: 86.72%,  Val Loss:  0.67,  Val Acc: 88.87%,  Time: 0:07:26 *
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [01:01<00:00, 22.79it/s]
Epoch [3/20]
 14%|█████████▊                                                            | 198/1407 [00:08<00:50, 23.90it/s]Iter:    200,  Train Loss:  0.22,  Train Acc: 85.16%,  Val Loss:  0.64,  Val Acc: 89.15%,  Time: 0:07:35 *
 28%|███████████████████▊                                                  | 399/1407 [00:17<00:42, 23.98it/s]Iter:    400,  Train Loss:  0.21,  Train Acc: 84.38%,  Val Loss:  0.64,  Val Acc: 89.43%,  Time: 0:07:44 *
 43%|█████████████████████████████▊                                        | 600/1407 [00:25<00:33, 24.07it/s]Iter:    600,  Train Loss:   0.2,  Train Acc: 91.41%,  Val Loss:  0.65,  Val Acc: 89.54%,  Time: 0:07:53 
 57%|███████████████████████████████████████▋                              | 798/1407 [00:34<00:25, 23.95it/s]Iter:    800,  Train Loss:   0.2,  Train Acc: 88.28%,  Val Loss:  0.64,  Val Acc: 89.50%,  Time: 0:08:01 
 71%|█████████████████████████████████████████████████▋                    | 999/1407 [00:43<00:17, 23.93it/s]Iter:   1000,  Train Loss:  0.18,  Train Acc: 90.62%,  Val Loss:  0.66,  Val Acc: 89.14%,  Time: 0:08:10 
 85%|██████████████████████████████████████████████████████████▊          | 1200/1407 [00:52<00:08, 24.03it/s]Iter:   1200,  Train Loss:  0.19,  Train Acc: 92.97%,  Val Loss:  0.65,  Val Acc: 89.36%,  Time: 0:08:19 
 99%|████████████████████████████████████████████████████████████████████▌| 1398/1407 [01:00<00:00, 24.01it/s]Iter:   1400,  Train Loss:   0.2,  Train Acc: 86.72%,  Val Loss:  0.65,  Val Acc: 89.24%,  Time: 0:08:28 
No optimization for a long time, auto-stopping...
Test Loss:  0.62,  Test Acc: 89.89%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9297    0.8730    0.9005      1000
       realty     0.9341    0.9070    0.9203      1000
       stocks     0.8183    0.8780    0.8471      1000
    education     0.9564    0.9430    0.9496      1000
      science     0.8964    0.8220    0.8576      1000
      society     0.8359    0.9220    0.8768      1000
     politics     0.8920    0.8590    0.8752      1000
       sports     0.9436    0.9540    0.9488      1000
         game     0.9263    0.9050    0.9155      1000
entertainment     0.8736    0.9260    0.8990      1000

     accuracy                         0.8989     10000
    macro avg     0.9006    0.8989    0.8991     10000
 weighted avg     0.9006    0.8989    0.8991     10000

Confusion Matrix...
[[873   9  68   1   6  19  14   3   3   4]
 [ 15 907  13   2   4  18  13   6   2  20]
 [ 38  24 878   1  18  10  23   2   4   2]
 [  1   2   4 943   4  19   6   5   3  13]
 [  2   3  54   5 822  30  26   2  36  20]
 [  1  14   4  19   6 922  14   1   2  17]
 [  8   8  35   9  11  47 859   5   2  16]
 [  1   1   3   1   3  12   3 954   1  21]
 [  0   0   9   2  37   6   5  15 905  21]
 [  0   3   5   3   6  20   0  18  19 926]]
Time usage: 0:00:00
 99%|████████████████████████████████████████████████████████████████████▌| 1398/1407 [01:01<00:00, 22.65it/s]

  • Conclusion: the Student model reaches Test Acc: 89.89% on the test set.

  • Step 3: Retrain the Student model with tuned hyperparameters (still with knowledge distillation)

  • Make the following key changes to the Config class:

# train for 30 epochs
self.num_epochs = 30

# kernel sizes 2, 3, 4 and 5
self.filter_sizes = (2, 3, 4, 5)

# 1024 kernels per size
self.num_filters = 1024

  • Train the Student model again after tuning:
# change to the directory containing the main training script
cd /home/ec2-user/toutiao/bert_distil/src/

# run the Student model training from the command line
python run.py --task train_kd

  • Output:
180000it [00:37, 4830.81it/s]
10000it [00:02, 4935.57it/s]
10000it [00:02, 4955.57it/s]
Vocab size: 4762
180000it [00:02, 69735.78it/s]
10000it [00:00, 82937.77it/s]
10000it [00:00, 82402.02it/s]
Data loaded, now load teacher model
Teacher and student models loaded, start training
Epoch [1/30]
 28%|███████████████████▊                                                  | 399/1407 [00:39<01:40, 10.06it/s]Iter:    400,  Train Loss:  0.29,  Train Acc: 75.00%,  Val Loss:  0.76,  Val Acc: 84.65%,  Time: 0:06:00 *
 57%|███████████████████████████████████████▊                              | 800/1407 [01:20<01:00, 10.05it/s]Iter:    800,  Train Loss:  0.24,  Train Acc: 82.81%,  Val Loss:  0.71,  Val Acc: 86.89%,  Time: 0:06:41 *
 85%|██████████████████████████████████████████████████████████▊          | 1200/1407 [02:01<00:20, 10.06it/s]Iter:   1200,  Train Loss:  0.23,  Train Acc: 82.81%,  Val Loss:  0.72,  Val Acc: 85.35%,  Time: 0:07:22 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:23<00:00,  9.80it/s]
Epoch [2/30]
 28%|███████████████████▉                                                  | 400/1407 [00:39<01:40, 10.06it/s]Iter:    400,  Train Loss:  0.23,  Train Acc: 79.69%,  Val Loss:  0.67,  Val Acc: 88.46%,  Time: 0:08:24 *
 57%|███████████████████████████████████████▊                              | 800/1407 [01:20<01:00, 10.08it/s]Iter:    800,  Train Loss:  0.21,  Train Acc: 86.72%,  Val Loss:  0.66,  Val Acc: 88.74%,  Time: 0:09:04 *
 85%|██████████████████████████████████████████████████████████▊          | 1199/1407 [02:01<00:20, 10.09it/s]Iter:   1200,  Train Loss:  0.21,  Train Acc: 92.19%,  Val Loss:  0.67,  Val Acc: 88.86%,  Time: 0:09:45 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:22<00:00,  9.85it/s]
Epoch [3/30]
 28%|███████████████████▉                                                  | 400/1407 [00:39<01:39, 10.13it/s]Iter:    400,  Train Loss:  0.22,  Train Acc: 82.81%,  Val Loss:  0.63,  Val Acc: 89.46%,  Time: 0:10:46 *
 57%|███████████████████████████████████████▊                              | 799/1407 [01:20<00:59, 10.15it/s]Iter:    800,  Train Loss:   0.2,  Train Acc: 92.97%,  Val Loss:  0.64,  Val Acc: 89.56%,  Time: 0:11:27 
 85%|██████████████████████████████████████████████████████████▊          | 1199/1407 [02:00<00:20, 10.15it/s]Iter:   1200,  Train Loss:  0.19,  Train Acc: 92.19%,  Val Loss:  0.64,  Val Acc: 89.66%,  Time: 0:12:08 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:22<00:00,  9.89it/s]
Epoch [4/30]
 28%|███████████████████▉                                                  | 400/1407 [00:39<01:39, 10.17it/s]Iter:    400,  Train Loss:  0.19,  Train Acc: 90.62%,  Val Loss:  0.64,  Val Acc: 89.44%,  Time: 0:13:08 
 57%|███████████████████████████████████████▊                              | 799/1407 [01:19<00:59, 10.18it/s]Iter:    800,  Train Loss:  0.18,  Train Acc: 95.31%,  Val Loss:  0.62,  Val Acc: 89.96%,  Time: 0:13:49 *
 85%|██████████████████████████████████████████████████████████▊          | 1199/1407 [02:00<00:20, 10.18it/s]Iter:   1200,  Train Loss:  0.18,  Train Acc: 92.19%,  Val Loss:  0.64,  Val Acc: 89.69%,  Time: 0:14:29 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:21<00:00,  9.92it/s]
Epoch [5/30]
 28%|███████████████████▊                                                  | 399/1407 [00:39<01:38, 10.25it/s]Iter:    400,  Train Loss:   0.2,  Train Acc: 88.28%,  Val Loss:  0.63,  Val Acc: 89.28%,  Time: 0:15:30 
 57%|███████████████████████████████████████▊                              | 799/1407 [01:19<00:59, 10.29it/s]Iter:    800,  Train Loss:  0.19,  Train Acc: 90.62%,  Val Loss:  0.64,  Val Acc: 89.60%,  Time: 0:16:10 
 85%|██████████████████████████████████████████████████████████▊          | 1199/1407 [01:59<00:20, 10.29it/s]Iter:   1200,  Train Loss:  0.17,  Train Acc: 96.88%,  Val Loss:  0.64,  Val Acc: 89.51%,  Time: 0:16:50 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:20<00:00, 10.02it/s]

......
......
......

Epoch [28/30]
 28%|███████████████████▉                                                  | 400/1407 [00:38<01:36, 10.40it/s]Iter:    400,  Train Loss:  0.15,  Train Acc: 98.44%,  Val Loss:  0.64,  Val Acc: 90.58%,  Time: 1:08:43 
 57%|███████████████████████████████████████▊                              | 800/1407 [01:17<00:58, 10.43it/s]Iter:    800,  Train Loss:  0.16,  Train Acc: 98.44%,  Val Loss:  0.63,  Val Acc: 91.09%,  Time: 1:09:22 
 85%|██████████████████████████████████████████████████████████▊          | 1200/1407 [01:57<00:19, 10.43it/s]Iter:   1200,  Train Loss:  0.15,  Train Acc: 96.88%,  Val Loss:  0.64,  Val Acc: 90.55%,  Time: 1:10:02 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:18<00:00, 10.17it/s]
Epoch [29/30]
 28%|███████████████████▉                                                  | 400/1407 [00:38<01:36, 10.41it/s]Iter:    400,  Train Loss:  0.15,  Train Acc: 97.66%,  Val Loss:  0.64,  Val Acc: 90.78%,  Time: 1:11:01 
 57%|███████████████████████████████████████▊                              | 800/1407 [01:17<00:58, 10.42it/s]Iter:    800,  Train Loss:  0.16,  Train Acc: 98.44%,  Val Loss:  0.63,  Val Acc: 90.58%,  Time: 1:11:40 
 85%|██████████████████████████████████████████████████████████▊          | 1200/1407 [01:57<00:19, 10.41it/s]Iter:   1200,  Train Loss:  0.15,  Train Acc: 97.66%,  Val Loss:  0.62,  Val Acc: 90.72%,  Time: 1:12:20 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:18<00:00, 10.17it/s]
Epoch [30/30]
 28%|███████████████████▉                                                  | 400/1407 [00:38<01:36, 10.40it/s]Iter:    400,  Train Loss:  0.16,  Train Acc: 98.44%,  Val Loss:  0.65,  Val Acc: 90.66%,  Time: 1:13:19 
 57%|███████████████████████████████████████▊                              | 800/1407 [01:17<00:58, 10.43it/s]Iter:    800,  Train Loss:  0.15,  Train Acc: 98.44%,  Val Loss:  0.63,  Val Acc: 90.79%,  Time: 1:13:59 
 85%|██████████████████████████████████████████████████████████▊          | 1200/1407 [01:57<00:19, 10.40it/s]Iter:   1200,  Train Loss:  0.15,  Train Acc: 99.22%,  Val Loss:  0.64,  Val Acc: 90.65%,  Time: 1:14:38 
100%|█████████████████████████████████████████████████████████████████████| 1407/1407 [02:18<00:00, 10.17it/s]
Test Loss:   0.6,  Test Acc: 91.25%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9105    0.9050    0.9077      1000
       realty     0.9311    0.9320    0.9315      1000
       stocks     0.8912    0.8440    0.8670      1000
    education     0.9532    0.9570    0.9551      1000
      science     0.8836    0.8730    0.8783      1000
      society     0.8306    0.9270    0.8762      1000
     politics     0.9041    0.8770    0.8904      1000
       sports     0.9733    0.9470    0.9600      1000
         game     0.9467    0.9240    0.9352      1000
entertainment     0.9108    0.9390    0.9247      1000

     accuracy                         0.9125     10000
    macro avg     0.9135    0.9125    0.9126     10000
 weighted avg     0.9135    0.9125    0.9126     10000

Confusion Matrix...
[[905  10  38   4   5  19  11   3   0   5]
 [ 13 932  13   2   3  17   6   3   4   7]
 [ 54  23 844   1  32   6  33   1   4   2]
 [  2   2   1 957   4  15   6   1   3   9]
 [  3   5  24   5 873  32  17   3  25  13]
 [  2  15   3  18   5 927  12   0   1  17]
 [ 12  10  16  10  14  48 877   2   3   8]
 [  2   0   3   1   2  21   4 947   1  19]
 [  0   0   3   3  43   8   2   5 924  12]
 [  1   4   2   3   7  23   2   8  11 939]]
Time usage: 0:00:01

  • Conclusion: after tuning, the Student model reaches Test Acc: 91.25% on the test set.



Master performance testing of the distilled model

  • After knowledge distillation we have two models, the Teacher and the Student; testing the performance of each is an important part of project development.
    • Step 1: Compare model size, accuracy and related metrics
    • Step 2: Compare the inference speed of the Teacher and Student models

Step 1: Compare model size, accuracy and related metrics

  • First, the Teacher model's size and accuracy metrics:
# size of the Teacher model
-rw-rw-r-- 1 root root 409190601 bert.pt

# accuracy metrics of the Teacher model
Test Loss:   0.2,  Test Acc: 93.64%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9246    0.9320    0.9283      1000
       realty     0.9484    0.9370    0.9427      1000
       stocks     0.8787    0.8980    0.8882      1000
    education     0.9511    0.9730    0.9619      1000
      science     0.9236    0.8950    0.9091      1000
      society     0.9430    0.9270    0.9349      1000
     politics     0.9267    0.9100    0.9183      1000
       sports     0.9780    0.9780    0.9780      1000
         game     0.9514    0.9600    0.9557      1000
entertainment     0.9390    0.9540    0.9464      1000

     accuracy                         0.9364     10000
    macro avg     0.9365    0.9364    0.9364     10000
 weighted avg     0.9365    0.9364    0.9364     10000

  • Next, the Student model's size and accuracy metrics:

  • Key hyperparameters of the first run:

# train for 3 epochs
self.num_epochs = 3

# kernel sizes 2, 3 and 4
self.filter_sizes = (2, 3, 4)

# 512 kernels per size
self.num_filters = 512

  • Output:
# size of the Student model
-rw-rw-r-- 1 root root  11315008  textCNN.pt

# accuracy metrics of the Student model
Test Loss:  0.62,  Test Acc: 89.89%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9297    0.8730    0.9005      1000
       realty     0.9341    0.9070    0.9203      1000
       stocks     0.8183    0.8780    0.8471      1000
    education     0.9564    0.9430    0.9496      1000
      science     0.8964    0.8220    0.8576      1000
      society     0.8359    0.9220    0.8768      1000
     politics     0.8920    0.8590    0.8752      1000
       sports     0.9436    0.9540    0.9488      1000
         game     0.9263    0.9050    0.9155      1000
entertainment     0.8736    0.9260    0.8990      1000

     accuracy                         0.8989     10000
    macro avg     0.9006    0.8989    0.8991     10000
 weighted avg     0.9006    0.8989    0.8991     10000

  • Key hyperparameters of the second run:
# train for 30 epochs
self.num_epochs = 30

# kernel sizes 2, 3, 4 and 5
self.filter_sizes = (2, 3, 4, 5)

# 1024 kernels per size
self.num_filters = 1024

  • Output:
# size of the Student model
-rw-rw-r-- 1 root root  23101842  textCNN.pt

# accuracy metrics of the Student model
Test Loss:   0.6,  Test Acc: 91.25%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9105    0.9050    0.9077      1000
       realty     0.9311    0.9320    0.9315      1000
       stocks     0.8912    0.8440    0.8670      1000
    education     0.9532    0.9570    0.9551      1000
      science     0.8836    0.8730    0.8783      1000
      society     0.8306    0.9270    0.8762      1000
     politics     0.9041    0.8770    0.8904      1000
       sports     0.9733    0.9470    0.9600      1000
         game     0.9467    0.9240    0.9352      1000
entertainment     0.9108    0.9390    0.9247      1000

     accuracy                         0.9125     10000
    macro avg     0.9135    0.9125    0.9126     10000
 weighted avg     0.9135    0.9125    0.9126     10000

Confusion Matrix...
[[905  10  38   4   5  19  11   3   0   5]
 [ 13 932  13   2   3  17   6   3   4   7]
 [ 54  23 844   1  32   6  33   1   4   2]
 [  2   2   1 957   4  15   6   1   3   9]
 [  3   5  24   5 873  32  17   3  25  13]
 [  2  15   3  18   5 927  12   0   1  17]
 [ 12  10  16  10  14  48 877   2   3   8]
 [  2   0   3   1   2  21   4 947   1  19]
 [  0   0   3   3  43   8   2   5 924  12]
 [  1   4   2   3   7  23   2   8  11 939]]
Time usage: 0:00:01

  • Conclusion 1: the Teacher model is 409.2MB; the two Student models are 11.3MB and 23.1MB.

  • Conclusion 2: the Teacher model's test accuracy is 93.64%; the Student models reach 89.89% and 91.25%.



Step 2: Compare the inference speed of the Teacher and Student models

  • Inference speed and accuracy are usually measured on the test set.

  • 1: Evaluate the Teacher model's speed and accuracy on the GPU:

  • Add the test code file /home/ec2-user/toutiao/bert_distil/src/test_eval.py
# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
import math
from utils import get_time_dif

def test(config, model, test_iter):
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = "Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}"
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)

    average_time = 0.0

    with torch.no_grad():
        for texts, labels in data_iter:
            start_time = time.time()
            outputs = model(texts)
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            end_time = time.time() - start_time
            average_time += end_time
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    average = average_time / 10000  # average over the 10000 test samples
    print('Average predict time is:', average*1000, 'ms')
    print('*********************************')
    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all,predict_all,target_names=config.class_list,digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
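  • Note that time.time() around a CUDA forward pass measures the (asynchronous) kernel-launch time; a stricter measurement would synchronize the GPU before reading the clock. A sketch of such a helper (torch.cuda.synchronize() is a standard PyTorch call; model and batch here are placeholders):

def timed_forward(model, batch):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        out = model(batch)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    return out, time.time() - start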

  • Add the test entry-point file /home/ec2-user/toutiao/bert_distil/src/test.py
# coding: UTF-8
import numpy as np
import torch
import time
from test_eval import test
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_dif, build_dataset_CNN

parser = argparse.ArgumentParser(description="Chinese Text Classification")
parser.add_argument("--task", type=str, required=True, help="choose a task: test_bert, or test_kd")
args = parser.parse_args()

if __name__ == "__main__":
    dataset = "toutiao"

    if args.task == "test_bert":
        model_name = "bert"
        x = import_module("models." + model_name)
        config = x.Config(dataset)
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)
        torch.backends.cudnn.deterministic = True  # make results reproducible across runs

        print("Loading data for Bert Model...")
        train_data, dev_data, test_data = build_dataset(config)
        test_iter = build_iterator(test_data, config)

        model = x.Model(config).to(config.device)
        test(config, model, test_iter)

  • Run:
# change to the directory containing the test script
cd /home/ec2-user/toutiao/bert_distil/src/

# run the test task from the command line
python test.py --task test_bert

  • Output:
Loading data for Bert Model...
180000it [00:36, 4876.33it/s]
10000it [00:02, 4993.56it/s]
10000it [00:02, 4996.31it/s]
Average predict time is: 1.5909206867218018 ms
*********************************
Test Loss:   0.2,  Test Acc: 93.64%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9246    0.9320    0.9283      1000
       realty     0.9484    0.9370    0.9427      1000
       stocks     0.8787    0.8980    0.8882      1000
    education     0.9511    0.9730    0.9619      1000
      science     0.9236    0.8950    0.9091      1000
      society     0.9430    0.9270    0.9349      1000
     politics     0.9267    0.9100    0.9183      1000
       sports     0.9780    0.9780    0.9780      1000
         game     0.9514    0.9600    0.9557      1000
entertainment     0.9390    0.9540    0.9464      1000

     accuracy                         0.9364     10000
    macro avg     0.9365    0.9364    0.9364     10000
 weighted avg     0.9365    0.9364    0.9364     10000

Confusion Matrix...
[[932  10  37   2   5   5   7   1   1   0]
 [ 13 937  11   2   4  10   5   5   5   8]
 [ 49  12 898   1  19   1  15   0   2   3]
 [  1   1   0 973   0   8   7   0   1   9]
 [  4   4  28   7 895  10  12   2  27  11]
 [  2   8   4  16   5 927  18   1   5  14]
 [  3   8  34  12   9  19 910   0   0   5]
 [  2   3   2   1   1   1   4 978   1   7]
 [  0   2   4   0  24   1   3   1 960   5]
 [  2   3   4   9   7   1   1  12   7 954]]
Time usage: 0:00:16

  • Conclusion: the Teacher model reaches 93.64% accuracy on the test set, with an average inference time of 1.591ms per sample.

  • 2: Evaluate the Student model's speed and accuracy on the GPU:

  • Append the following branch to test.py (the shared header is shown again for context):
# coding: UTF-8
import numpy as np
import torch
import time
from test_eval import test
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_dif, build_dataset_CNN

parser = argparse.ArgumentParser(description="Chinese Text Classification")
parser.add_argument("--task", type=str, required=True, help="choose a task: test_bert, or test_kd")
args = parser.parse_args()

if __name__ == "__main__":
    dataset = "toutiao"

    if args.task == "test_kd":
        model_name = "textCNN"
        cnn_module = import_module("models." + model_name)
        cnn_config = cnn_module.Config(dataset)

        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)
        torch.backends.cudnn.deterministic = True  # make results reproducible across runs

        # build the CNN dataset
        vocab, cnn_train_data, cnn_dev_data, cnn_test_data = build_dataset_CNN(cnn_config)
        cnn_test_iter = build_iterator(cnn_test_data, cnn_config)
        cnn_config.n_vocab = len(vocab)

        # build the student model; test() below restores the trained weights from cnn_config.save_path
        cnn_model = cnn_module.Model(cnn_config).to(cnn_config.device)
        test(cnn_config, cnn_model, cnn_test_iter)

  • Run:
# change to the directory containing the test script
cd /home/ec2-user/toutiao/bert_distil/src/

# run the test task from the command line
python test.py --task test_kd

  • Output:
Vocab size: 4762
180000it [00:02, 72839.73it/s]
10000it [00:00, 82158.93it/s]
10000it [00:00, 83238.81it/s]
Data loaded, now load teacher model
Teacher and student models loaded, start training
Average predict time is: 0.06290757656097412 ms
*********************************
Test Loss:  0.62,  Test Acc: 89.89%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9297    0.8730    0.9005      1000
       realty     0.9341    0.9070    0.9203      1000
       stocks     0.8183    0.8780    0.8471      1000
    education     0.9564    0.9430    0.9496      1000
      science     0.8964    0.8220    0.8576      1000
      society     0.8359    0.9220    0.8768      1000
     politics     0.8920    0.8590    0.8752      1000
       sports     0.9436    0.9540    0.9488      1000
         game     0.9263    0.9050    0.9155      1000
entertainment     0.8736    0.9260    0.8990      1000

     accuracy                         0.8989     10000
    macro avg     0.9006    0.8989    0.8991     10000
 weighted avg     0.9006    0.8989    0.8991     10000

Confusion Matrix...
[[873   9  68   1   6  19  14   3   3   4]
 [ 15 907  13   2   4  18  13   6   2  20]
 [ 38  24 878   1  18  10  23   2   4   2]
 [  1   2   4 943   4  19   6   5   3  13]
 [  2   3  54   5 822  30  26   2  36  20]
 [  1  14   4  19   6 922  14   1   2  17]
 [  8   8  35   9  11  47 859   5   2  16]
 [  1   1   3   1   3  12   3 954   1  21]
 [  0   0   9   2  37   6   5  15 905  21]
 [  0   3   5   3   6  20   0  18  19 926]]
Time usage: 0:00:01

  • Conclusion: the Student model reaches 89.89% accuracy on the test set, with an inference time of 0.063ms per sample.

  • 3: Evaluate the tuned Student model's speed and accuracy on the GPU:
# first change to the models directory
cd /home/ec2-user/toutiao/bert_distil/src/models/

# restore the hyperparameters in the Config class of textCNN.py to the second (tuned) run,
# i.e. self.filter_sizes = (2, 3, 4, 5) and self.num_filters = 1024;
# leave all other parameters unchanged

# change to the checkpoint directory and restore the Student model from the second (tuned) run
cd /home/ec2-user/toutiao/bert_distil/src/saved_dict/
cp textCNN_9125.pt textCNN.pt

# change to the directory containing the test script
cd /home/ec2-user/toutiao/bert_distil/src/

# run the test task from the command line
python test.py --task test_kd

  • Output:
Vocab size: 4762
180000it [00:02, 73571.48it/s]
10000it [00:00, 82791.09it/s]
10000it [00:00, 82700.32it/s]
Average predict time is: 0.1240086317062378 ms
*********************************
Test Loss:   0.6,  Test Acc: 91.25%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9105    0.9050    0.9077      1000
       realty     0.9311    0.9320    0.9315      1000
       stocks     0.8912    0.8440    0.8670      1000
    education     0.9532    0.9570    0.9551      1000
      science     0.8836    0.8730    0.8783      1000
      society     0.8306    0.9270    0.8762      1000
     politics     0.9041    0.8770    0.8904      1000
       sports     0.9733    0.9470    0.9600      1000
         game     0.9467    0.9240    0.9352      1000
entertainment     0.9108    0.9390    0.9247      1000

     accuracy                         0.9125     10000
    macro avg     0.9135    0.9125    0.9126     10000
 weighted avg     0.9135    0.9125    0.9126     10000

Confusion Matrix...
[[905  10  38   4   5  19  11   3   0   5]
 [ 13 932  13   2   3  17   6   3   4   7]
 [ 54  23 844   1  32   6  33   1   4   2]
 [  2   2   1 957   4  15   6   1   3   9]
 [  3   5  24   5 873  32  17   3  25  13]
 [  2  15   3  18   5 927  12   0   1  17]
 [ 12  10  16  10  14  48 877   2   3   8]
 [  2   0   3   1   2  21   4 947   1  19]
 [  0   0   3   3  43   8   2   5 924  12]
 [  1   4   2   3   7  23   2   8  11 939]]
Time usage: 0:00:01

  • Conclusion: after tuning, the Student model reaches 91.25% accuracy on the test set, with an inference time of 0.124ms per sample.

  • Overall conclusions after knowledge distillation (a quick arithmetic check of these ratios follows below):
    • The model size drops dramatically.
      • BERT model: 409.2MB; best textCNN model: 23.1MB.
      • The model is compressed to 5.65% of its original size, i.e. 17.7x smaller.
    • Test accuracy drops by only 2.39%.
      • BERT model accuracy: 93.64%.
      • textCNN trained from scratch for 50 epochs: 89.58%.
      • textCNN after knowledge distillation for 30 epochs: 91.25%.
      • Relative to BERT, the gap shrinks from 4.06% to 2.39%, a 41.1% relative improvement.
    • GPU inference is roughly 13x-25x faster.
      • Average inference time of the BERT model: 1.591ms per sample.
      • textCNN at 89.89% accuracy: 0.063ms per sample, a 25.3x speedup.
      • textCNN at 91.25% accuracy: 0.124ms per sample, a 12.8x speedup.
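  • A quick arithmetic check of the ratios quoted above (sizes in MB, times in ms, values taken from the results in this section):

bert_mb, cnn_mb = 409.2, 23.1
print(round(cnn_mb / bert_mb * 100, 2), "% of the original size,", round(bert_mb / cnn_mb, 1), "x smaller")  # 5.65 %, 17.7 x

bert_ms, cnn_ms_small, cnn_ms_tuned = 1.591, 0.063, 0.124
print(round(bert_ms / cnn_ms_small, 1), "x speedup at 89.89% accuracy")  # 25.3 x
print(round(bert_ms / cnn_ms_tuned, 1), "x speedup at 91.25% accuracy")  # 12.8 x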