跳转至

5.1 BERT模型微调

模型微调的实践案例


学习目标

  • 掌握基于BERT的微调模型.
  • 掌握对BERT微调模型的参数配置方法.

本项目用到的数据集

  • 投满分项目中用到的20万条数据, 文本长度在20到30之间, 共10个类别, 每类20000条.
    • 类别: 财经, 房产, 股票, 教育, 科技, 社会, 时政, 体育, 游戏, 娱乐.
    • 数据集划分: 训练集180000, 验证集10000, 测试集10000.

  • 本项目中用到是数据集放在/home/bert/bert_finetuning/data/目录下, 该目录下有如下文件信息:
# 下列5个文件
class.txt
test.csv
train.csv
dev.csv
vocab.pkl

  • 本项目中用到的预训练模型放在/home/bert/bert_finetuning/data/bert_pretrain目录下, 该目录下有如下文件信息:
bert_config.json
pytorch_model.bin
vocab.txt


本项目的实现步骤:

  • 第1步: 熟悉数据文件和数据集
  • 第2步: 构建基于BERT微调的多标签分类模型
  • 第3步: 对BERT模型的参数执行微调
  • 第4步: 编写textCNN模型类代码

第1步: 熟悉数据文件和数据集

  • 打开class.txt文件, 文件位置/home/bert/bert_finetuning/data/data/
# 文件中存储的是10种分类标签的英文名称, 也可以按照程序员的喜好换成中文

finance
realty
stocks
education
science
society
politics
sports
game
entertainment

  • 打开train.txt文件, 文件位置/home/bert/bert_finetuning/data/data/
# 这里只截取文件的一部分展示, 文件总共有180000行

中华女子学院:本科层次仅1专业招男生     3
两天价网站背后重重迷雾:做个网站究竟要多少钱    4
东5环海棠公社230-290平2居准现房98折优惠 1
卡佩罗:告诉你德国脚生猛的原因 不希望英德战踢点球       7
82岁老太为学生做饭扫地44年获授港大荣誉院士      5
记者回访地震中可乐男孩:将受邀赴美国参观        5
冯德伦徐若�隔空传情 默认其是女友        9
传郭晶晶欲落户香港战伦敦奥运 装修别墅当婚房     1
《赤壁OL》攻城战诸侯战硝烟又起  8
“手机钱包”亮相科博会    4
上海2010上半年四六级考试报名4月8日前完成        3
李永波称李宗伟难阻林丹取胜 透露谢杏芳有望出战   7
3岁女童下体红肿 自称被幼儿园老师用尺子捅伤      5
金证顾问:过山车行情意味着什么  2
谁料地王如此虚  1
《光环5》Logo泄露 Kinect版几无悬念      8
海淀区领秀新硅谷宽景大宅预计10月底开盘  1
柴志坤:土地供应量不断从紧 地价难现07水平(图)   1

  • 打开dev.csv文件, /home/bert/bert_finetuning/data/data/
# 这里只截取文件的一部分展示, 文件总共有10000行 

体验2D巅峰 倚天屠龙记十大创新概览       8
60年铁树开花形状似玉米芯(组图)  5
同步A股首秀:港股缩量回调       2
中青宝sg现场抓拍 兔子舞热辣表演 8
锌价难续去年辉煌        0
2岁男童爬窗台不慎7楼坠下获救(图)        5
布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击 7
金科西府 名墅天成       1
状元心经:考前一周重点是回顾和整理      3
发改委治理涉企收费每年为企业减负超百亿  6
一年网事扫荡10年纷扰开心网李鬼之争和平落幕      4
2010英国新政府“三把火”或影响留学业      3
俄达吉斯坦共和国一名区长被枪杀  6
朝鲜要求日本对过去罪行道歉和赔偿        6
《口袋妖怪 黑白》日本首周贩售255万      8
图文:借贷成本上涨致俄罗斯铝业净利下滑21%       2
组图:新《三国》再曝海量剧照 火战场面极震撼     9
麻辣点评:如何走出“被留学”的尴尬        3

  • 打开test.csv文件, /home/bert/bert_finetuning/data/data/
# 这里只截取文件的一部分展示, 文件总共有10000行

词汇阅读是关键 08年考研暑期英语复习全指南       3
中国人民公安大学2012年硕士研究生目录及书目      3
日本地震:金吉列关注在日学子系列报道    3
名师辅导:2012考研英语虚拟语气三种用法  3
自考经验谈:自考生毕业论文选题技巧      3
本科未录取还有这些路可以走      3
2009年成人高考招生统一考试时间表        3
去新西兰体验舌尖上的饕餮之旅(组图)      3
四级阅读与考研阅读比较分析与应试策略    3
备考2012高考作文必读美文50篇(一)        3
名师详解考研复试英语听力备考策略        3
热议:艺考合格证是高考升学王牌吗(组图)  3
研究生办替考网站续:幕后老板年赚近百万(图)      3
2011年高考文科综合试题(重庆卷)  3
56所高校预估2009年湖北录取分数线出炉    3
公共英语(PETS)写作中常见的逻辑词汇汇总  3
时评:高考应成为教育公平的“助推器”      3

  • 还有一个文件vocab.pkl文件, 不是可打印文件, 是为后续的textCNN模型服务的.

  • 进入文件夹/home/bert/bert_finetuning/data/bert_pretrain/, 有如下3个文件:
# BERT预训练模型的所有参数
config.json
# 基于bert-base-chinese预训练模型得到的模型文件
pytorch_model.bin
# bert-base-chinese模型的所有词汇表, 共21128个字符
vocab.txt

第2步: 构建基于BERT微调的多标签分类模型

  • 首先编写经典的实现继承BERT预训练模型的分类任务类.
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertConfig

# 构建基于BERT的微调模型类
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        # 导入参数设置对象
        model_config = BertConfig.from_pretrained(config.bert_path, num_labels=config.num_classes)
        # 导入基于bert-base-chinese的预训练模型
        self.bert = BertModel.from_pretrained(config.bert_path, config=model_config)

        # 此处用于调节是否将BERT纳入微调训练, 建议数据量+算力充足的情况下置为True
        # 如果设置为False, 则保持整个BERT网络参数不变, 微调仅仅针对最后的全连接层进行训练
        for param in self.bert.parameters():
            param.requires_grad = True

        # 全连接层的出口维度, 取决于具体的任务
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        # x[0]是输入的具体文本信息
        context = x[0]
        # x[1]是经过tokenizer处理后返回的attention mask张量
        # mask的尺寸size和输入相同, padding部分用0遮掩, 比如[1, 1, 1, 0, 0]
        mask = x[1]
        # x[2]是字符类型id
        token_type_ids = x[2]
        # 利用BERT模型得到输出张量, 并且只保留BertPooler的输出, 即第一个字符CLS对应的输出张量
        _, pooled = self.bert(context, attention_mask=mask, token_type_ids=token_type_id)
        # 再利用微调网络进一步提取特征, 并利用全连接层对特征张量进行维度变换
        out = self.fc(pooled)
        return out


第3步: 对BERT模型的参数执行微调

  • 首先展示BERT模型中的参数命名:
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 将BERT中所有的参数层名字打印出来
        for name, param in self.bert.named_parameters():
            print(name)

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

  • 输出结果:
embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
encoder.layer.0.attention.self.key.bias
encoder.layer.0.attention.self.value.weight
encoder.layer.0.attention.self.value.bias
encoder.layer.0.attention.output.dense.weight
encoder.layer.0.attention.output.dense.bias
encoder.layer.0.attention.output.LayerNorm.weight
encoder.layer.0.attention.output.LayerNorm.bias
encoder.layer.0.intermediate.dense.weight
encoder.layer.0.intermediate.dense.bias
encoder.layer.0.output.dense.weight
encoder.layer.0.output.dense.bias
encoder.layer.0.output.LayerNorm.weight
encoder.layer.0.output.LayerNorm.bias
encoder.layer.1.attention.self.query.weight
encoder.layer.1.attention.self.query.bias
encoder.layer.1.attention.self.key.weight
encoder.layer.1.attention.self.key.bias
encoder.layer.1.attention.self.value.weight
encoder.layer.1.attention.self.value.bias
encoder.layer.1.attention.output.dense.weight
encoder.layer.1.attention.output.dense.bias
encoder.layer.1.attention.output.LayerNorm.weight
encoder.layer.1.attention.output.LayerNorm.bias
encoder.layer.1.intermediate.dense.weight
encoder.layer.1.intermediate.dense.bias
encoder.layer.1.output.dense.weight
encoder.layer.1.output.dense.bias
encoder.layer.1.output.LayerNorm.weight
encoder.layer.1.output.LayerNorm.bias

......
......
......

encoder.layer.11.attention.self.query.weight
encoder.layer.11.attention.self.query.bias
encoder.layer.11.attention.self.key.weight
encoder.layer.11.attention.self.key.bias
encoder.layer.11.attention.self.value.weight
encoder.layer.11.attention.self.value.bias
encoder.layer.11.attention.output.dense.weight
encoder.layer.11.attention.output.dense.bias
encoder.layer.11.attention.output.LayerNorm.weight
encoder.layer.11.attention.output.LayerNorm.bias
encoder.layer.11.intermediate.dense.weight
encoder.layer.11.intermediate.dense.bias
encoder.layer.11.output.dense.weight
encoder.layer.11.output.dense.bias
encoder.layer.11.output.LayerNorm.weight
encoder.layer.11.output.LayerNorm.bias
pooler.dense.weight
pooler.dense.bias

  • 针对BERT模型中的embedding层, 让其中的参数不参与微调, 代码1.
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 希望锁定embeddings层的参数, 不参与更新
        for name, param in self.bert.embeddings.named_parameters():
            print(name)
            param.requires_grad = False

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

  • 输出结果:
word_embeddings.weight
position_embeddings.weight
token_type_embeddings.weight
LayerNorm.weight
LayerNorm.bias

  • 针对BERT中的全连接层, 让其中的weight参数不参与微调, 代码2.
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 希望将全连接层中的.weight部分参数锁定
        for name, param in self.bert.named_parameters():
            if name.endswith('weight'):
                print(name)
                param.requires_grad = False

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

  • 输出结果:
embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.key.weight
encoder.layer.0.attention.self.value.weight
encoder.layer.0.attention.output.dense.weight
encoder.layer.0.attention.output.LayerNorm.weight
encoder.layer.0.intermediate.dense.weight
encoder.layer.0.output.dense.weight
encoder.layer.0.output.LayerNorm.weight
encoder.layer.1.attention.self.query.weight
encoder.layer.1.attention.self.key.weight
encoder.layer.1.attention.self.value.weight
encoder.layer.1.attention.output.dense.weight
encoder.layer.1.attention.output.LayerNorm.weight
encoder.layer.1.intermediate.dense.weight
encoder.layer.1.output.dense.weight
encoder.layer.1.output.LayerNorm.weight

......
......
......

encoder.layer.11.attention.self.query.weight
encoder.layer.11.attention.self.key.weight
encoder.layer.11.attention.self.value.weight
encoder.layer.11.attention.output.dense.weight
encoder.layer.11.attention.output.LayerNorm.weight
encoder.layer.11.intermediate.dense.weight
encoder.layer.11.output.dense.weight
encoder.layer.11.output.LayerNorm.weight
pooler.dense.weight

  • 针对BERT中指定的若干层, 让其中的参数不参与微调, 代码3.
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)

        # 封闭BERT中的第1, 3, 5层参数, 不参与微调
        index_array = [1, 3, 5]
        for name, param in self.bert.named_parameters():
            new_x = name.split('.')[2]
            if new_x in index_array:
                print(name)
                param.requires_grad = False

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

  • 输出结果:
attention.self.query.weight
attention.self.query.bias
attention.self.key.weight
attention.self.key.bias
attention.self.value.weight
attention.self.value.bias
attention.output.dense.weight
attention.output.dense.bias
attention.output.LayerNorm.weight
attention.output.LayerNorm.bias
intermediate.dense.weight
intermediate.dense.bias
output.dense.weight
output.dense.bias
output.LayerNorm.weight
output.LayerNorm.bias
attention.self.query.weight
attention.self.query.bias
attention.self.key.weight
attention.self.key.bias
attention.self.value.weight
attention.self.value.bias
attention.output.dense.weight
attention.output.dense.bias
attention.output.LayerNorm.weight
attention.output.LayerNorm.bias
intermediate.dense.weight
intermediate.dense.bias
output.dense.weight
output.dense.bias
output.LayerNorm.weight
output.LayerNorm.bias
attention.self.query.weight
attention.self.query.bias
attention.self.key.weight
attention.self.key.bias
attention.self.value.weight
attention.self.value.bias
attention.output.dense.weight
attention.output.dense.bias
attention.output.LayerNorm.weight
attention.output.LayerNorm.bias
intermediate.dense.weight
intermediate.dense.bias
output.dense.weight
output.dense.bias
output.LayerNorm.weight
output.LayerNorm.bias


小节总结