7.4 T5模型深度解析
T5模型¶
T5模型核心源代码分析¶
- T5源代码注意力机制分析:
class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set()
self.gradient_checkpointing = False
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
# Prune linear layers
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
self.v = prune_linear_layer(self.v, index)
self.o = prune_linear_layer(self.o, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.key_value_proj_dim * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
- T5全连接层代码分析:
class T5DenseReluDense(nn.Module):
def __init__(self, config):
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
hidden_states = self.wi(hidden_states)
hidden_states = nn.functional.relu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
T5模型在下游任务上的调优¶
- T5模型在今日头条投满分项目上可以做分类任务的调优.
- 总体思路和微调BERT以及AlBERT模型一样, 只需要替换预训练模型, 然后开启训练即可.
import torch
import torch.nn as nn
import os
from transformers import T5Tokenizer, T5EncoderModel
class Config(object):
def __init__(self, dataset):
self.model_name = "t5"
self.data_path = "/home/ec2-user/toutiao/t5/data/data/"
self.train_path = self.data_path + "train.txt" # 训练集
self.dev_path = self.data_path + "dev.txt" # 验证集
self.test_path = self.data_path + "test.txt" # 测试集
self.class_list = [x.strip() for x in open(self.data_path + "class.txt").readlines()] # 类别名单
self.save_path = "/home/ec2-user/toutiao/t5/src/saved_dic"
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
self.save_path += "/" + self.model_name + ".pt" # 模型训练结果
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练
self.num_classes = len(self.class_list) # 类别数
self.num_epochs = 3 # epoch数
self.batch_size = 128 # mini-batch大小
self.pad_size = 32 # 每句话处理成的长度(短填长切)
self.learning_rate = 5e-5 # 学习率
self.model_path = '/home/ec2-user/toutiao/t5/src/models/t5_chinese_base'
self.tokenizer = T5Tokenizer.from_pretrained(self.model_path)
self.hidden_size = 768
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.t5 = T5EncoderModel.from_pretrained(config.model_path)
self.fc = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
context = x[0] # 输入的句子
mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0]
output = self.t5(context, attention_mask=mask)
pooled = output.last_hidden_state
out = self.fc(pooled)
return out
- 输出结果:
Loading data for Bert Model...
0it [00:00, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.737 seconds.
Prefix dict has been built successfully.
180000it [01:04, 2794.84it/s]
10000it [00:03, 2677.98it/s]
10000it [00:03, 2897.69it/s]
Epoch [1/3]
14%|█████████▌ | 200/1407 [02:13<13:52, 1.45it/s]Iter: 200, Train Loss: 2.3, Train Acc: 16.41%, Val Loss: 2.3, Val Acc: 18.36%, Time: 0:02:34 *
28%|███████████████████ | 400/1407 [04:53<11:48, 1.42it/s]Iter: 400, Train Loss: 0.47, Train Acc: 85.16%, Val Loss: 0.43, Val Acc: 87.69%, Time: 0:05:17 *
43%|████████████████████████████▌ | 600/1407 [07:36<09:26, 1.43it/s]Iter: 600, Train Loss: 0.4, Train Acc: 89.06%, Val Loss: 0.35, Val Acc: 89.61%, Time: 0:08:00 *
57%|██████████████████████████████████████ | 800/1407 [10:19<07:07, 1.42it/s]Iter: 800, Train Loss: 0.26, Train Acc: 93.75%, Val Loss: 0.31, Val Acc: 90.51%, Time: 0:10:43 *
71%|██████████████████████████████████████████████▉ | 1000/1407 [13:02<04:45, 1.43it/s]Iter: 1000, Train Loss: 0.24, Train Acc: 91.41%, Val Loss: 0.3, Val Acc: 90.73%, Time: 0:13:26 *
85%|████████████████████████████████████████████████████████▎ | 1200/1407 [15:45<02:24, 1.43it/s]Iter: 1200, Train Loss: 0.27, Train Acc: 91.41%, Val Loss: 0.28, Val Acc: 91.11%, Time: 0:16:09 *
100%|█████████████████████████████████████████████████████████████████▋| 1400/1407 [18:28<00:04, 1.43it/s]Iter: 1400, Train Loss: 0.34, Train Acc: 89.06%, Val Loss: 0.27, Val Acc: 91.46%, Time: 0:18:51 *
100%|██████████████████████████████████████████████████████████████████| 1407/1407 [18:54<00:00, 1.24it/s]
Epoch [2/3]
14%|█████████▌ | 200/1407 [02:20<14:07, 1.42it/s]Iter: 200, Train Loss: 0.29, Train Acc: 89.84%, Val Loss: 0.27, Val Acc: 91.80%, Time: 0:21:38 *
28%|███████████████████ | 400/1407 [05:02<11:46, 1.42it/s]Iter: 400, Train Loss: 0.38, Train Acc: 89.84%, Val Loss: 0.26, Val Acc: 91.83%, Time: 0:24:21 *
43%|████████████████████████████▌ | 600/1407 [07:45<09:25, 1.43it/s]Iter: 600, Train Loss: 0.27, Train Acc: 92.19%, Val Loss: 0.25, Val Acc: 92.09%, Time: 0:27:02 *
57%|██████████████████████████████████████ | 800/1407 [10:27<07:06, 1.42it/s]Iter: 800, Train Loss: 0.13, Train Acc: 95.31%, Val Loss: 0.25, Val Acc: 92.22%, Time: 0:29:42
71%|██████████████████████████████████████████████▉ | 1000/1407 [13:06<04:46, 1.42it/s]Iter: 1000, Train Loss: 0.18, Train Acc: 92.19%, Val Loss: 0.25, Val Acc: 92.26%, Time: 0:32:21
85%|████████████████████████████████████████████████████████▎ | 1200/1407 [15:46<02:25, 1.43it/s]Iter: 1200, Train Loss: 0.18, Train Acc: 94.53%, Val Loss: 0.24, Val Acc: 92.46%, Time: 0:35:04 *
100%|█████████████████████████████████████████████████████████████████▋| 1400/1407 [18:28<00:04, 1.43it/s]Iter: 1400, Train Loss: 0.24, Train Acc: 93.75%, Val Loss: 0.23, Val Acc: 92.60%, Time: 0:37:46 *
100%|██████████████████████████████████████████████████████████████████| 1407/1407 [18:55<00:00, 1.24it/s]
Epoch [3/3]
14%|█████████▌ | 200/1407 [02:20<14:07, 1.42it/s]Iter: 200, Train Loss: 0.23, Train Acc: 92.97%, Val Loss: 0.23, Val Acc: 92.67%, Time: 0:40:30
28%|███████████████████ | 400/1407 [04:59<11:46, 1.43it/s]Iter: 400, Train Loss: 0.27, Train Acc: 92.19%, Val Loss: 0.23, Val Acc: 92.45%, Time: 0:43:10
43%|████████████████████████████▌ | 600/1407 [07:39<09:26, 1.42it/s]Iter: 600, Train Loss: 0.21, Train Acc: 93.75%, Val Loss: 0.23, Val Acc: 92.57%, Time: 0:45:50
57%|██████████████████████████████████████ | 800/1407 [10:19<07:07, 1.42it/s]Iter: 800, Train Loss: 0.13, Train Acc: 94.53%, Val Loss: 0.24, Val Acc: 92.60%, Time: 0:48:29
71%|██████████████████████████████████████████████▉ | 1000/1407 [12:58<04:46, 1.42it/s]Iter: 1000, Train Loss: 0.11, Train Acc: 97.66%, Val Loss: 0.23, Val Acc: 92.69%, Time: 0:51:12 *
85%|████████████████████████████████████████████████████████▎ | 1200/1407 [15:41<02:25, 1.43it/s]Iter: 1200, Train Loss: 0.12, Train Acc: 96.09%, Val Loss: 0.23, Val Acc: 92.66%, Time: 0:53:52
100%|█████████████████████████████████████████████████████████████████▋| 1400/1407 [18:20<00:04, 1.42it/s]Iter: 1400, Train Loss: 0.22, Train Acc: 92.97%, Val Loss: 0.23, Val Acc: 92.81%, Time: 0:56:34 *
100%|██████████████████████████████████████████████████████████████████| 1407/1407 [18:47<00:00, 1.25it/s]
Test Loss: 0.2, Test Acc: 93.49%
Precision, Recall and F1-Score...
precision recall f1-score support
finance 0.9223 0.9260 0.9242 1000
realty 0.9435 0.9520 0.9477 1000
stocks 0.8852 0.8870 0.8861 1000
education 0.9700 0.9690 0.9695 1000
science 0.9176 0.8800 0.8984 1000
society 0.9371 0.9240 0.9305 1000
politics 0.9092 0.9210 0.9151 1000
sports 0.9692 0.9740 0.9716 1000
game 0.9577 0.9510 0.9543 1000
entertainment 0.9369 0.9650 0.9507 1000
accuracy 0.9349 10000
macro avg 0.9349 0.9349 0.9348 10000
weighted avg 0.9349 0.9349 0.9348 10000
Confusion Matrix...
[[926 13 38 1 6 5 8 1 1 1]
[ 9 952 9 3 1 8 7 3 1 7]
[ 46 15 887 0 17 1 24 4 3 3]
[ 2 1 1 969 1 8 7 1 0 10]
[ 4 2 35 5 880 11 19 1 32 11]
[ 6 13 2 10 8 924 22 3 1 11]
[ 6 6 22 7 8 19 921 4 1 6]
[ 3 2 3 1 3 3 1 974 1 9]
[ 1 2 4 0 27 5 1 2 951 7]
[ 1 3 1 3 8 2 3 12 2 965]]
Time usage: 0:00:19
- 结论: 采用T5预训练模型, 可以在测试集上得到F1=93.49%的优异表现. 相比较于BERT的93.64%只有非常轻微的下降, 考虑到T5更擅长的领域是在生成式任务上, 所以分类任务得到几乎等同于BERT的表现已经非常棒了!
T5 PEGASUS模型¶
PEGASUS介绍¶
-
T5模型可以做生成式任务的绝佳预训练模型.
- Google在2020年提出基于T5的新版模型T5 Pegasus (T5天马模型), 在生成式任务上获得了SOTA效果.
-
<< PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization >>
- Pegasus模型的架构图如下:
- PRGASUS模型的主要工作:
- 1: 提出了为摘要生成定制的预训练目标GSG(Gap Sentences Generation).
- 2: 研究了多种GSG方式, 即如何在预训练目标中选择重要的句子.
- 3: 研究了多种MASK策略, 在token mask和sentence mask中产生最佳的组合.
-
首先来看Gap Sentences Generation (GSG)
-
基本公理假设: 假设预训练目标和下游任务越接近, 那么finetuning会带来更好更快的表现.
-
基于摘要提取的目的, 论文提出了类似摘要模型的输入文本, 为了能在大量文本上进行无监督的预训练, 需要设计一个无摘要的seq2seq自监督目标. 这个目标的设定非常关键, 直接影响模型的能力和任务. 一个简单的想法是将预训练作为摘要提取器, 然而这样导致的结果是, 训练的模型得到的是原始文本中的句子, 并不适合生成式的任务!
-
受到span mask的启发, 论文中选择了mask文本中的句子段, 并且拼接这些gap-sentences形成伪摘要. 相应位置的gap-sentences用[MASK1]替换. 而正常的token级别的遮掩则用[MASK2]替换.
- 选择GSG的三大策略:
- 1: Random - 随机选择m个句子.
- 2: Lead - 选择文档中前面的m个句子.
- 3: Principal - 根据重要性, 选择得分最高的top-m个句子.
- 关于Principal分成两种选择策略:
- 1: 独立性选择(Ind): 句子重要性课根据选中句和其他句子集合的rouge1-F1来计算, 最终选择得分最高的m个句子.
- 2: 通过贪心最大化选中句子集合, 和其他句子集合的rouge1-F1来计算, 如下图展示.
-
这其中计算rouge1-F1的方式也分成2种, Uniq和Orig.
- Uniq: 先将句子集合处理, 即去除掉重复的n-gram, 再计算rouge1-F1.
- Orig: 保留原始句子, 容许重复n-gram出现.
-
基于最后的实验结果表明, 选择了文档30%的句子作为gap sentences, 而Principal策略在PEGASUS中采用了ind-Orig的方式.
-
Masked Language Model (MLM)
-
论文中MASK的方式一共三种:
- 1: 只做MLM, 类似BERT, 以输入文本的15%的tokens, 将80%替换成[MASK2], 10%随机替换, 10%保持不变. (策略1时, 在finetuning下游任务时, Transformer Decoder部分共享Encoder的参数)
- 2: 不采用MLM, 只采用GSG, 将选中的sentence用[MASK1]替换掉.
- 3: 采用GSG的方式MASK选中的句子, 然后对于未选中的句子, 15%的token利用MLM去MASK掉.
-
基于最后的实验结果表明, 仅采用MLM效果最差, 预训练100-200k steps时, MLM&GSG的效果在提升 但此后包含MLM效果在下降. 因此最终PEGASUS-large仅采用GSG, PEGASUS-base采用了MLM&GSG.
- PEGASUS-BASE: 223M Parameters, L=12, H=768, F=3072, A=12
- PEGASUS-LARGE: 568M Parameters, L=16, H=1024, F=4096, A=16
-
消融实验(ablation experiments)
- 预训练语料
- 预训练目标
- 词表影响
-
预训练语料: 实验结果表明, 两个新闻相关的语料(XSum和CNN/DailyMail)在HugeNews上表现更好, 非新闻相关的语料在C4上表现更好. 这说明预训练语料与下游任务越接近, 则预训练产生的迁移学习效果越好.
-
预训练目标: 实验结果表明, Ind-Orig的表现最佳. 而且最佳的GSR总是低于50%, 所以PEGASUS最终选择了30%.
-
词表影响: 比较了两种token方式BPE(Byte-pair-encoding algorithm)和Unigram(SentencePiece Unigram algorithm), 实验结果表明不同的下游任务最佳结果并不一致, 因此作者选择了相对较优的方式, 在PEGASUS-large中选择了unigram 96K.
-
Low-Resource Summarization
-
在实际工程中, 很难收集到足量的标注语料用来训练摘要模型或者finetuning, 在少样本情况下, 进行了低资源模拟实验.
- PEGASUS最大的惊喜: 实验结果表明, 在12个摘要提取任务上, 仅需要1000个样本, 模型表现就可以超越人类!!!
PEGASUS中文模型介绍¶
- T5模型于2020提出后, 国内研究员迫切需要中文版本的天马, 于是在2021年T5终于可以在中文上天马行空了.
-
Tokenizer的优化
-
mT5所使用的Tokenizer是sentencepiece, 这是一个C++编写的分词库, 具有高效轻便的优点, 但是对于中文场景并不友好.
- 1: sentencepiece会把某些全角符号强制转化为半角符号, 这在某些情况下是难以接受的, 并且还可能影响评测结果.
- 2: sentencepiece内置的算法虽然有能力分出中文词来, 但对于中文分词来说还是不够智能.
- 3: sentencepiece是用C++编写的, 虽然开源了但对于用惯Python的人来说C++就相当于黑箱, 难以阅读源代码, 改写优化起来也不容易.
-
基于上述3点原因, 在中文本的PEGASUS中将Tokenizer进行了优化改写. 最主要的是基于BERT的Tokenizer添加了分词功能, 并进一步完善了vocab.txt
-
具体来说, 往原始中文BERT的token_dict里边加入结巴分词的前20万个词, 然后修改Tokenizer的逻辑, 使得它能够切分出中文词来. 用这个修改后的Tokenizer去遍历切分准备好的预训练语料, 统计各个token的频数, 最后只保留最高频的5万个token, 得到一个规模为5万的vocab.txt来构建最终的Tokenizer.
-
预训练任务的优化
-
中文版PEGASUS希望更加接近自然语言生成(而不是像T5那样的只预测挖空部分), 并且尽可能具有实用价值. PEGASUS在其论文称是专门为摘要定制的预训练模型, 但更主要的是作为通用的生成式预训练任务而呈现. PEGASUS的大体思路是通过最长公共子序列的方式来得到摘要类似的数据对, T5 PEGASUS并没有完全复现PEGASUS的做法, 只是借鉴了PEGASUS的思路做语料构建.
- 具体来说, 假设一个文档有n个句子, 我们从中挑出大约n/4个句子(可以不连续), 使得这n/4个句子拼起来的文本, 跟剩下的3n/4个句子拼起来的文本, 最长公共子序列尽可能长. 然后我们将3n/4个句子拼起来的文本视为原文, n/4个句子拼起来的文本视为摘要, 这样就构成了一个“(原文, 摘要)”的伪摘要数据对了. 就用这些数据对去训练Seq2Seq模型即可.
- 注意: 如果文档里没有重复句子的话, 那么原文跟摘要的句子是不会有交集的, 所以这样的生成任务并非是原文的简单复制, 因此还是有一定难度的.
- 搜索算法则是通过如下的贪心算法逐步搜索至满足长度要求:
1: 先找出1个句子, 使得它跟生成的n−1个句子的最长公共子序列最长.
2: 假设已经找到了k个句子, 那么继续找第k+1个句子, 使得这k+1个句子拼起来的文本, 跟剩下的n−k−1个句子拼起来的文本的最长公共子序列最长.
- 在主流摘要生成任务的评测中, 我们发现T5 PEGASUS获得了最优的表现:
- 非常惊喜的是: T5 PEGASUS有着非常出色的小样本学习能力, 在几千的样本量级时, 全部展示了最佳的表现!
-
当标注样本数等于10的时候, 模型生成的效果演示如下:
-
样例1:
- 样例2:
- 样例3:
- 结论: 哪怕标注样本很少, 但依然能够得到可读性较好的生成结果, 这得益于PEGASUS的伪摘要预训练任务与下游任务非常贴近!!!
小节总结¶
- 本小节学习了T5模型的深入细节.
- T5模型在分类任务上的调优, 也可以获得完全媲美BERT的效果.
- 原版PEGASUS的技术细节.
- T5 PEGASUS在中文场景下的优化和优异效果展示.