
7.2 In-Depth Analysis of New Generative Models

The MASS Model


MASS Architecture

  • The original paper << MASS: Masked Sequence to Sequence Pre-training for Language Generation >> was published by Microsoft Research Asia at a top conference in 2019 (ICML).

  • The basic architecture of MASS follows the classic Encoder-Decoder pattern, strengthening the model's capability on NLG. MASS achieves significant improvements on a variety of language generation tasks, including machine translation, text summarization, and multi-turn dialogue.


  • The key architectural operation of MASS is to randomly mask a contiguous fragment of length k in the sentence, and then predict this length-k fragment through the Encoder-Attention-Decoder structure, as shown in the figure below:

  • As shown in the figure above, the tokens x3 x4 x5 x6 of the 8-token input sequence are masked. Note that the model predicts only the masked fragment x3 x4 x5 x6, and the decoder is fed x3 x4 x5 as input at positions 4-6, while all other positions take the special token [M] as input. Since this method applies to any encoder-decoder neural network framework, and given that the Transformer has achieved SOTA performance on sequence learning, the original paper chose the Transformer as the backbone of both the Encoder and the Decoder.

  • In fact, both the MLM-based BERT model and the standard language model GPT can be viewed as special cases of MASS (see the sketch after this list):
    • When k=1, MASS reduces to BERT, as shown in figure (a) below: the Encoder masks only x5, the Decoder receives no input and predicts x5, so the whole structure effectively uses only the Encoder part.
    • When k=m (where m is the sentence length), MASS reduces to GPT, as shown in figure (b) below: the Encoder input is entirely masked, so the whole structure effectively uses only the Decoder part.
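
  • To make the masking scheme concrete, the following is a minimal sketch (not the original implementation; the token names and the helper function are purely illustrative) that builds the encoder input, decoder input, and prediction targets for a given mask length k:

# Illustrative sketch of MASS-style span masking (not the paper's code).
def mass_mask(tokens, start, k, mask_token="[M]"):
    span = tokens[start:start + k]                     # the fragment to be predicted
    # Encoder input: the span is replaced by k copies of [M].
    enc_input = tokens[:start] + [mask_token] * k + tokens[start + k:]
    # Decoder input: the span shifted right by one position (teacher forcing);
    # all other positions receive the special token [M].
    dec_input = [mask_token] * len(tokens)
    for i in range(1, k):
        dec_input[start + i] = span[i - 1]
    return enc_input, dec_input, span


tokens = ["x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8"]
enc, dec, tgt = mass_mask(tokens, start=2, k=4)        # mask x3..x6, as in the figure
print(enc)  # ['x1', 'x2', '[M]', '[M]', '[M]', '[M]', 'x7', 'x8']
print(dec)  # ['[M]', '[M]', '[M]', 'x3', 'x4', 'x5', '[M]', '[M]']
print(tgt)  # ['x3', 'x4', 'x5', 'x6']

# Special cases: k=1 degenerates to a BERT-style masked LM (the decoder gets no real
# input), while k=len(tokens) degenerates to a GPT-style language model (the encoder
# input is fully masked).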


MASS Model Optimizations

  • The main optimizations in the MASS model (a small sketch follows the list):
    • 1: A contiguous span of tokens is masked and replaced with [M]; the starting position is chosen at random.
    • 2: As in the BERT paper, 80% of the selected tokens are replaced with the mask symbol, 10% are replaced with random tokens, and 10% are kept unchanged.
    • 3: The length of the masked contiguous span is set to 50% of the sentence length, in order to balance the workload between the Encoder and the Decoder.
    • 4: To reduce memory and time consumption, the padding (masked positions) is removed from the Decoder input while the positional embeddings are kept, which cuts the Decoder's computation by roughly 50%.
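
  • As referenced above, here is a minimal sketch of how these rules could be combined into a single corruption routine; the function, vocabulary, and return layout are illustrative assumptions, not the original implementation:

import random

def mass_corrupt(tokens, vocab, mask_token="[M]"):
    """Illustrative sketch: mask a contiguous span covering ~50% of the sentence,
    starting at a random position, with BERT-style 80/10/10 replacement."""
    m = len(tokens)
    k = max(1, m // 2)                       # span length = 50% of the sentence length
    start = random.randint(0, m - k)         # random starting position
    enc_input = list(tokens)
    for i in range(start, start + k):
        p = random.random()
        if p < 0.8:
            enc_input[i] = mask_token                # 80%: replace with [M]
        elif p < 0.9:
            enc_input[i] = random.choice(vocab)      # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    # Decoder side: only the (right-shifted) masked fragment is kept as input; the [M]
    # padding positions are dropped while positional embeddings are preserved, which
    # roughly halves the decoder's computation.
    dec_input = tokens[start:start + k - 1]
    targets = tokens[start:start + k]
    return enc_input, dec_input, targets

print(mass_corrupt(["x1", "x2", "x3", "x4", "x5", "x6"], vocab=["a", "b", "c"]))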

  • MASS shows a clear advantage in machine translation; the experimental results under the BLEU metric are as follows:


  • MASS also shows a clear advantage in text summarization; the experimental results under the ROUGE metric are as follows:


  • Conclusion: MASS is the first model to jointly pre-train the Encoder and Decoder under the Transformer architecture; it subsumes BERT and GPT as special cases and is shown to outperform both on sequence generation tasks. The core of MASS is bringing BERT-style masked pre-training into the seq2seq architecture.


The BART Model

BART Architecture

  • Facebook's NLP team released the BART model in 2019, described in the original paper << BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension >>.

  • Like MASS, the core goal of BART is to improve NLG, so it also adopts the Encoder-Decoder architecture. The core procedure has two steps:

    • 1: Corrupt the original sentence sequence with an arbitrary noising function.
    • 2: The Decoder's objective is to reconstruct the original sentence sequence.
  • In the figure below, (a) represents the BERT model, (b) the GPT model, and (c) the BART model, whose masking strategy is more distinctive:


  • The paper compares several noising functions (a sketch of the two adopted ones follows this list):
    • 1: Token Masking: as in BERT, randomly select some tokens and replace them with [MASK].
    • 2: Token Deletion: directly delete some tokens from the input.
    • 3: Text Infilling: mask a contiguous span. Unlike SpanBERT, the span length is drawn from a Poisson distribution, and the entire span is replaced with a single [MASK] symbol; the authors explain that this helps the model learn to predict how many tokens a span contains.
    • 4: Sentence Permutation: shuffle the order of the sentences.
    • 5: Document Rotation: rotate the document so that it begins with a randomly chosen token.
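
  • As a concrete reference for the list above, below is a minimal sketch of the two noising functions the paper ultimately adopts, Text Infilling (Poisson lambda=3, as in the paper) and Sentence Permutation; the helper names, the mask ratio, and the simple delimiter-based sentence splitting are illustrative assumptions:

import random
import numpy as np


def text_infilling(tokens, mask_ratio=0.3, poisson_lambda=3.0, mask_token="[MASK]"):
    """Illustrative Text Infilling: sample span lengths from a Poisson distribution and
    replace each sampled span with a single [MASK] token."""
    out = list(tokens)
    budget = int(len(tokens) * mask_ratio)
    masked = 0
    while masked < budget and len(out) > 1:
        span_len = np.random.poisson(poisson_lambda)
        start = random.randint(0, max(0, len(out) - span_len - 1))
        out[start:start + span_len] = [mask_token]   # the whole span becomes one [MASK]
        masked += max(span_len, 1)                   # count 0-length insertions so the loop terminates
    return out


def sentence_permutation(text, delimiter="。"):
    """Illustrative Sentence Permutation: split the document into sentences and shuffle them."""
    sentences = [s for s in text.split(delimiter) if s.strip()]
    random.shuffle(sentences)
    return delimiter.join(sentences) + delimiter


tokens = list("人工智能正在深刻改变自然语言处理的研究范式")
print(text_infilling(tokens))
print(sentence_permutation("今天天气很好。我们去公园散步。晚上一起看电影。"))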


  • Note: after extensive ablation experiments, the authors ultimately adopted the combination of Text Infilling and Sentence Permutation.

  • Although BART was proposed to address BERT's weakness on NLG tasks, the authors also demonstrate experimentally that BART achieves strong results on NLU tasks. The figure below gives examples of applying BART to NLU and NLG tasks:


  • The figure below shows BART's performance on several natural language understanding benchmarks under different pre-training strategies. The best scores are obtained with Text Infilling and Sentence Shuffling (i.e., Sentence Permutation); this is a typical ablation result, so later pre-training simply adopts the combination of these two schemes:


  • On NLU tasks, BART performs roughly on par with RoBERTa:


  • On NLG tasks, BART is clearly superior to the BERT family:



Applications of the BART Model

Applying BART to a Classification Task

  • Files of the pre-trained BART model used for the classification task (bart_chinese_NLU); a quick loading check follows the listing:
-rw-r--r-- 1 ec2-user ec2-user      1115 Feb 22 10:26 config.json
-rw-r--r-- 1 ec2-user ec2-user 400912567 Feb 22 10:30 pytorch_model.bin
-rw-r--r-- 1 ec2-user ec2-user      2512 Feb 22 10:26 README.md
-rw-r--r-- 1 ec2-user ec2-user       112 Feb 22 10:26 special_tokens_map.json
-rw-r--r-- 1 ec2-user ec2-user       430 Feb 22 10:26 tokenizer_config.json
-rw-r--r-- 1 ec2-user ec2-user    109540 Feb 22 10:26 vocab.txt
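
  • The checkpoint can be sanity-checked before training with a few lines of code; the local path "./bart_chinese_NLU" below is an assumption (point it at wherever the files above were downloaded):

from transformers import BartConfig, BartForConditionalGeneration, BertTokenizer

# Quick sanity check of the downloaded checkpoint (adjust the path as needed).
path = "./bart_chinese_NLU"
config = BartConfig.from_pretrained(path)
tokenizer = BertTokenizer.from_pretrained(path)
model = BartForConditionalGeneration.from_pretrained(path)

print(config.d_model, config.encoder_layers, config.decoder_layers)
print(tokenizer.tokenize("今天天气真不错"))
print(sum(p.numel() for p in model.parameters()))   # total parameter count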

  • In the 投满分 project, replace the pre-trained model used for transfer learning with BART:
import torch
import torch.nn as nn
import os
from transformers import BertTokenizer, BartForConditionalGeneration


class Config(object):
    def __init__(self, dataset):
        self.model_name = "bart"
        self.data_path = "/home/ec2-user/ec2-user/zhudejun/bert/bart/data/data/"
        self.train_path = self.data_path + "train.txt"  # training set
        self.dev_path = self.data_path + "dev.txt"  # validation set
        self.test_path = self.data_path + "test.txt"  # test set
        self.class_list = [x.strip() for x in open(self.data_path + "class.txt").readlines()]  # list of class names

        self.save_path = "/home/ec2-user/ec2-user/zhudejun/bert/bart/src/saved_dic"
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)
        self.save_path += "/" + self.model_name + ".pt"  # path of the trained model checkpoint

        # For training + prediction, keep the next line enabled to run on GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device
        # For model quantization, enable the next line to run on CPU.
        # self.device = 'cpu'

        self.require_improvement = 1000  # stop early if no improvement after 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 5  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded/truncated to this length
        self.learning_rate = 5e-5  # learning rate
        self.bart_path = "/home/ec2-user/ec2-user/zhudejun/bert/bart/data/bart_chinese_NLU"
        self.tokenizer = BertTokenizer.from_pretrained(self.bart_path)
        self.hidden_size = 768


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bart = BartForConditionalGeneration.from_pretrained(config.bart_path)

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # input token ids
        mask = x[2]  # attention mask over padding, same size as the sentence: 1 for real tokens, 0 for padding, e.g. [1, 1, 1, 1, 0, 0]
        # Note: this tuple unpacking assumes a transformers version in which the forward
        # returns a plain tuple (return_dict=False); `pooled` is expected to be a
        # [batch, seq_len, hidden] tensor (otherwise use return_dict=True and
        # outputs.encoder_last_hidden_state instead).
        _, pooled = self.bart(context, attention_mask=mask)
        out = self.fc(pooled[:, 0, :])  # classify on the hidden state of the first token
        return out
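
  • For reference, a minimal sketch of wiring the Config and Model above together for one training step. This is illustrative only: it assumes the checkpoint and data paths in Config exist and that Model.forward runs under the installed transformers version; the fake batch merely mimics the (ids, seq_len, mask) tuple layout that forward expects.

import torch
import torch.nn.functional as F

config = Config(dataset=None)            # the dataset argument is not used by this Config
model = Model(config).to(config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# A fake batch with the (token ids, sequence lengths, attention mask) layout.
ids = torch.randint(0, 1000, (4, config.pad_size), device=config.device)
seq_len = torch.full((4,), config.pad_size, device=config.device)
mask = torch.ones(4, config.pad_size, dtype=torch.long, device=config.device)
labels = torch.randint(0, config.num_classes, (4,), device=config.device)

logits = model((ids, seq_len, mask))     # forward only uses x[0] and x[2]
loss = F.cross_entropy(logits, labels)   # standard classification loss
loss.backward()
optimizer.step()
print(logits.shape, loss.item())         # expected shape: (4, num_classes)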

  • Training output:
Loading data for Bert Model...
180000it [00:31, 5787.64it/s]
10000it [00:01, 5173.45it/s]
10000it [00:01, 5977.91it/s]
Epoch [1/5]
 14%|█████████▏                                                       | 200/1407 [01:02<06:22,  3.16it/s]Iter:    200,  Train Loss:   0.6,  Train Acc: 78.91%,  Val Loss:  0.48,  Val Acc: 85.81%,  Time: 0:01:19 *
 28%|██████████████████▍                                              | 400/1407 [02:25<05:58,  2.81it/s]Iter:    400,  Train Loss:  0.55,  Train Acc: 80.47%,  Val Loss:  0.38,  Val Acc: 88.29%,  Time: 0:02:45 *
 43%|███████████████████████████▋                                     | 600/1407 [03:52<04:32,  2.96it/s]Iter:    600,  Train Loss:  0.41,  Train Acc: 85.16%,  Val Loss:  0.34,  Val Acc: 88.92%,  Time: 0:04:13 *
 57%|████████████████████████████████████▉                            | 800/1407 [05:21<03:28,  2.91it/s]Iter:    800,  Train Loss:  0.45,  Train Acc: 85.94%,  Val Loss:  0.32,  Val Acc: 89.97%,  Time: 0:05:42 *
 71%|█████████████████████████████████████████████▍                  | 1000/1407 [06:51<02:27,  2.76it/s]Iter:   1000,  Train Loss:  0.25,  Train Acc: 90.62%,  Val Loss:  0.31,  Val Acc: 90.03%,  Time: 0:07:11 *
 85%|██████████████████████████████████████████████████████▌         | 1200/1407 [08:19<01:13,  2.82it/s]Iter:   1200,  Train Loss:  0.27,  Train Acc: 90.62%,  Val Loss:  0.29,  Val Acc: 90.81%,  Time: 0:08:40 *
100%|███████████████████████████████████████████████████████████████▋| 1400/1407 [09:48<00:02,  3.05it/s]Iter:   1400,  Train Loss:  0.39,  Train Acc: 89.06%,  Val Loss:  0.28,  Val Acc: 91.17%,  Time: 0:10:08 *
100%|████████████████████████████████████████████████████████████████| 1407/1407 [10:10<00:00,  2.31it/s]
Epoch [2/5]
 14%|█████████▏                                                       | 200/1407 [01:08<07:11,  2.80it/s]Iter:    200,  Train Loss:  0.29,  Train Acc: 90.62%,  Val Loss:  0.28,  Val Acc: 91.04%,  Time: 0:11:37 
 28%|██████████████████▍                                              | 400/1407 [02:34<05:38,  2.97it/s]Iter:    400,  Train Loss:  0.37,  Train Acc: 87.50%,  Val Loss:  0.26,  Val Acc: 91.48%,  Time: 0:13:05 *
 43%|███████████████████████████▋                                     | 600/1407 [04:03<04:39,  2.89it/s]Iter:    600,  Train Loss:  0.25,  Train Acc: 92.97%,  Val Loss:  0.26,  Val Acc: 91.70%,  Time: 0:14:34 *
 57%|████████████████████████████████████▉                            | 800/1407 [05:31<03:21,  3.01it/s]Iter:    800,  Train Loss:  0.18,  Train Acc: 93.75%,  Val Loss:  0.26,  Val Acc: 91.97%,  Time: 0:16:02 *
 71%|█████████████████████████████████████████████▍                  | 1000/1407 [07:00<02:20,  2.90it/s]Iter:   1000,  Train Loss:  0.15,  Train Acc: 93.75%,  Val Loss:  0.26,  Val Acc: 91.83%,  Time: 0:17:31 *
 85%|██████████████████████████████████████████████████████▌         | 1200/1407 [08:28<01:08,  3.00it/s]Iter:   1200,  Train Loss:  0.23,  Train Acc: 89.84%,  Val Loss:  0.26,  Val Acc: 91.73%,  Time: 0:18:57 
100%|███████████████████████████████████████████████████████████████▋| 1400/1407 [09:55<00:02,  3.21it/s]Iter:   1400,  Train Loss:   0.4,  Train Acc: 86.72%,  Val Loss:  0.25,  Val Acc: 91.96%,  Time: 0:20:26 *
100%|████████████████████████████████████████████████████████████████| 1407/1407 [10:17<00:00,  2.28it/s]
Epoch [3/5]
 14%|█████████▏                                                       | 200/1407 [01:09<06:56,  2.90it/s]Iter:    200,  Train Loss:  0.25,  Train Acc: 90.62%,  Val Loss:  0.26,  Val Acc: 92.07%,  Time: 0:21:54 
 28%|██████████████████▍                                              | 400/1407 [02:36<05:53,  2.84it/s]Iter:    400,  Train Loss:  0.27,  Train Acc: 90.62%,  Val Loss:  0.25,  Val Acc: 92.16%,  Time: 0:23:24 *
 43%|███████████████████████████▋                                     | 600/1407 [04:04<04:55,  2.73it/s]Iter:    600,  Train Loss:  0.15,  Train Acc: 95.31%,  Val Loss:  0.24,  Val Acc: 92.46%,  Time: 0:24:52 *
 57%|████████████████████████████████████▉                            | 800/1407 [05:32<03:29,  2.90it/s]Iter:    800,  Train Loss: 0.077,  Train Acc: 98.44%,  Val Loss:  0.25,  Val Acc: 92.33%,  Time: 0:26:18 
 71%|█████████████████████████████████████████████▍                  | 1000/1407 [06:58<02:16,  2.98it/s]Iter:   1000,  Train Loss:  0.18,  Train Acc: 93.75%,  Val Loss:  0.25,  Val Acc: 92.28%,  Time: 0:27:44 
 85%|██████████████████████████████████████████████████████▌         | 1200/1407 [08:24<01:13,  2.81it/s]Iter:   1200,  Train Loss:  0.19,  Train Acc: 91.41%,  Val Loss:  0.26,  Val Acc: 91.99%,  Time: 0:29:09 
100%|███████████████████████████████████████████████████████████████▋| 1400/1407 [09:50<00:02,  2.74it/s]Iter:   1400,  Train Loss:  0.26,  Train Acc: 94.53%,  Val Loss:  0.25,  Val Acc: 92.18%,  Time: 0:30:36 
100%|████████████████████████████████████████████████████████████████| 1407/1407 [10:10<00:00,  2.30it/s]
Epoch [4/5]
 14%|█████████▏                                                       | 200/1407 [01:09<06:39,  3.02it/s]Iter:    200,  Train Loss:  0.23,  Train Acc: 92.97%,  Val Loss:  0.26,  Val Acc: 92.43%,  Time: 0:32:05 
 28%|██████████████████▍                                              | 400/1407 [02:35<05:57,  2.82it/s]Iter:    400,  Train Loss:   0.2,  Train Acc: 92.97%,  Val Loss:  0.24,  Val Acc: 92.50%,  Time: 0:33:31 
 43%|███████████████████████████▋                                     | 600/1407 [04:01<04:47,  2.81it/s]Iter:    600,  Train Loss:  0.17,  Train Acc: 93.75%,  Val Loss:  0.25,  Val Acc: 92.21%,  Time: 0:34:57 
 57%|████████████████████████████████████▉                            | 800/1407 [05:27<03:36,  2.80it/s]Iter:    800,  Train Loss: 0.043,  Train Acc: 99.22%,  Val Loss:  0.25,  Val Acc: 92.58%,  Time: 0:36:23 
 71%|█████████████████████████████████████████████▍                  | 1000/1407 [06:53<02:07,  3.18it/s]Iter:   1000,  Train Loss: 0.086,  Train Acc: 97.66%,  Val Loss:  0.26,  Val Acc: 92.34%,  Time: 0:37:49 
 85%|██████████████████████████████████████████████████████▌         | 1200/1407 [08:20<01:14,  2.76it/s]Iter:   1200,  Train Loss:  0.11,  Train Acc: 96.09%,  Val Loss:  0.26,  Val Acc: 92.42%,  Time: 0:39:16 
100%|███████████████████████████████████████████████████████████████▋| 1400/1407 [09:46<00:02,  2.72it/s]Iter:   1400,  Train Loss:  0.23,  Train Acc: 91.41%,  Val Loss:  0.25,  Val Acc: 92.35%,  Time: 0:40:42 
100%|████████████████████████████████████████████████████████████████| 1407/1407 [10:05<00:00,  2.32it/s]
Epoch [5/5]
 14%|█████████▏                                                       | 200/1407 [01:09<07:21,  2.73it/s]Iter:    200,  Train Loss:  0.17,  Train Acc: 93.75%,  Val Loss:  0.27,  Val Acc: 92.15%,  Time: 0:42:11 
 28%|██████████████████▍                                              | 400/1407 [02:35<05:55,  2.83it/s]Iter:    400,  Train Loss:  0.15,  Train Acc: 95.31%,  Val Loss:  0.25,  Val Acc: 92.71%,  Time: 0:43:37 
 43%|███████████████████████████▋                                     | 600/1407 [04:01<04:33,  2.95it/s]Iter:    600,  Train Loss:  0.14,  Train Acc: 96.09%,  Val Loss:  0.26,  Val Acc: 92.37%,  Time: 0:45:03 
 57%|████████████████████████████████████▉                            | 800/1407 [05:27<03:40,  2.76it/s]Iter:    800,  Train Loss:  0.04,  Train Acc: 98.44%,  Val Loss:  0.27,  Val Acc: 92.60%,  Time: 0:46:30 
 71%|█████████████████████████████████████████████▍                  | 1000/1407 [06:54<02:21,  2.88it/s]Iter:   1000,  Train Loss:  0.08,  Train Acc: 96.88%,  Val Loss:  0.26,  Val Acc: 92.32%,  Time: 0:47:56 
 85%|██████████████████████████████████████████████████████▌         | 1200/1407 [08:19<01:12,  2.84it/s]Iter:   1200,  Train Loss:  0.15,  Train Acc: 92.97%,  Val Loss:  0.28,  Val Acc: 92.10%,  Time: 0:49:22 
100%|███████████████████████████████████████████████████████████████▋| 1400/1407 [09:45<00:02,  2.86it/s]Iter:   1400,  Train Loss:  0.19,  Train Acc: 95.31%,  Val Loss:  0.27,  Val Acc: 92.19%,  Time: 0:50:48 
100%|████████████████████████████████████████████████████████████████| 1407/1407 [10:05<00:00,  2.32it/s]
Test Loss:  0.23,  Test Acc: 92.71%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9287    0.9120    0.9203      1000
       realty     0.9307    0.9530    0.9417      1000
       stocks     0.9087    0.8460    0.8762      1000
    education     0.9563    0.9630    0.9596      1000
      science     0.8547    0.8940    0.8739      1000
      society     0.9317    0.9270    0.9293      1000
     politics     0.9031    0.9130    0.9080      1000
       sports     0.9838    0.9700    0.9768      1000
         game     0.9436    0.9370    0.9403      1000
entertainment     0.9327    0.9560    0.9442      1000

     accuracy                         0.9271     10000
    macro avg     0.9274    0.9271    0.9270     10000
 weighted avg     0.9274    0.9271    0.9270     10000

Confusion Matrix...
[[912  16  38   2  11   5  11   1   2   2]
 [  9 953   7   1   8   6   6   4   1   5]
 [ 46  22 846   1  48   2  27   0   6   2]
 [  1   0   1 963   7  13   6   1   1   7]
 [  4   5  13   3 894  15  13   1  37  15]
 [  0  15   0  18   5 927  19   0   4  12]
 [  5   5  21  10  21  17 913   0   0   8]
 [  0   4   1   1   4   3   6 970   0  11]
 [  2   0   4   3  36   6   3   2 937   7]
 [  3   4   0   5  12   1   7   7   5 956]]

  • Conclusion: after replacing the pre-trained transfer-learning model with BART, the model reaches 92.71% accuracy (weighted F1 of 92.70%) on the test set, which is below BERT's performance. This suggests that BART's strength lies in generative tasks, and it is best avoided for pure classification tasks!


Applying BART to Generative Tasks

  • Files of the pre-trained BART model used for the generative task (bart_chinese_NLG):
-rw-r--r-- 1 ec2-user ec2-user      1109 Feb 22 10:34 config.json
-rw-r--r-- 1 ec2-user ec2-user 533261527 Feb 22 10:37 pytorch_model.bin
-rw-r--r-- 1 ec2-user ec2-user      3597 Feb 22 10:34 README.md
-rw-r--r-- 1 ec2-user ec2-user       112 Feb 22 10:34 special_tokens_map.json
-rw-r--r-- 1 ec2-user ec2-user       430 Feb 22 10:34 tokenizer_config.json
-rw-r--r-- 1 ec2-user ec2-user    109540 Feb 22 10:34 vocab.txt

  • Write a short test script to verify BART's generation ability:
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline

tokenizer = BertTokenizer.from_pretrained('./bart_chinese_NLG')

model = BartForConditionalGeneration.from_pretrained('./bart_chinese_NLG')

generator = Text2TextGenerationPipeline(model, tokenizer)

print(generator("中国的首都是[MASK]京", max_length=50, do_sample=False))

  • Output:
[{'generated_text': '中 国 的 首 都 是 北 京'}]
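
  • The same test can also be run without the pipeline wrapper by calling generate() directly, which is often more convenient for batch inference; a minimal sketch with the same assumed checkpoint path:

from transformers import BertTokenizer, BartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained('./bart_chinese_NLG')
model = BartForConditionalGeneration.from_pretrained('./bart_chinese_NLG')

# Encode the masked sentence, generate greedily, and decode back to text.
inputs = tokenizer("中国的首都是[MASK]京", return_tensors="pt")
output_ids = model.generate(inputs["input_ids"], max_length=50, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))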

  • Conclusion: BART is stronger than BERT at cloze-style completion and generative tasks, and is the recommended choice for such tasks going forward!