8.2 BERT优化经验谈
BERT优化经验汇总¶
学习目标¶
- 理解BERT的若干优化技巧.
- 掌握BERT的具体优化方法.
BERT优化之数据¶
-
BERT非常厉害, 各个领域都拿到过SOTA, 虽然已经发布好几年了, 但是现在依旧是标杆级别的存在, 可是现在我用了我的数据, BERT不会好, 到底是怎么回事?
-
往扎心的说, 其实就是一味地万能钥匙不灵了, 所以无从下手.
-
NLP的一大好处就是数据本身是可解释可推理, 人本身也可以推测的, 而分析数据, 往往可以知道一些非常直接的问题.
- 首先需要看的是测试集, 这个相当于是考试的答卷, 做错了那些题一目了然, 通过这些数据我们能知道模型做错了那些事, 强如BERT, 也是可能出错的, 常见的错误是这些:
- 标注的质量: 在现实很多场景, 其实是很容易出现标注错误的, 很多NLP的问题准确率天花板都停留在90左右, 大都是因为标注质量问题, 说白了就是标错, 这些标错的数据很可能是模型预测对了标注错误了导致的正确, 这就导致指标不好看, "效果不好了".
- 标注的分布: 尤其是分类问题, 正负样本是否符合预期, 各个类目的数据是否达到了统计意义, 实际占比和与其占比是否一致等.
- 数据可靠性: 可能比较少见, 有的人做的测试集和训练集压根就不是一回事, 值得关心.
- bad case分析: 这块非常细致, 放在后面单独一个模块详细说.
- 测试集: 测试集是考试题, 本质是考验模型能力的, 所以这里核心是要保质保量, 质在于题目可靠, 真的能考验出模型的真实能力, 量在于统计意义, 现在的指标都是用的统计指标, 所以数据本身必须具有统计意义, 例如某个类只有2条数据, 两条全错能证明有问题吗, 其实也不太好说.
- 其次需要看的是训练集, 训练集内同样可能有问题, 而有些问题其实可通过训练集的问题来体现:
- 标注的质量: 同理于测试集.
- 标注的分布: 这块的核心是数据不平衡问题的解决方案.
- 样本数量: 越是复杂的模型, 对数据的渴求度越大, 尤其是场景比较偏的, 需要更多数据集才行, 少数据不足以让模型对你的数据有足够的了解.
- 样本领域性: 这里单独谈, 很多领域专业性强, 是需要更多数据支撑的, 例如医学, 另外是名词性比较强的, 对数据有特殊的依赖性.
- 训练集: 学习资料和练习题要足够, 才能让模型学得会, 学得好. 数据分布问题, 不能偏科, 各个类型的数据最好都能覆盖. 领域性的问题, 最好由领域性的数据选择, 甚至是用这些数据做MLM的任务来微调.
- 数据层面的问题, 很可能是导致BERT效果不好的根本原因, 他的背后其实是场景问题, 场景的数据可没有实验室的那么理想, 各有各的特色. 在实验室中BERT的效果确实会比常规的textcnn, biltm-crf, ESIM等小模型效果好, 但是在很多现实场景优势没那么明显, 甚至会不如, 大家可以持乐观态度, 但请别成为信仰!
- 数据增强: 增强本质不是增多, 不是所有缺数据的问题都是因为数量不足, 模型要泛化能力, 他的泛化能力来源于数据的泛化, 很多时候数据提供的不足那就不会有这么强的泛化, 得到的反而是过拟合. 很多时候你需要的可能是更多地挖掘数据, 从日志, 从更多渠道去找, 这个可能比增强本身要好.
- 模型 VS 规则: 值得强调的是, 对于名词性比较强的任务, 最好不要用BERT了, 甚至是模型都不要用, 举个例子, 现在告诉你"非诚勿扰", "中国诗词大会", "脱口秀大会"这些是综艺, 你能推断出"哈哈哈"也是综艺节目吗? 你这不就是难为人家模型了吗? 所以此处词典的作用会比模型本身要大, 哪怕是为了提升泛化能力要用模型, 那也需要和模型结合着来做.
- 样本不平衡问题: 分好几个方面去衡量, 下面一一解析.
- 1: 第一个就要检查一下不平衡的程度, 如果不同类别差别到数量级的程度, 那真得想办法了, 可以通过loss之类的手段来调整, 让loss更关注那些类别样本少的类.
- 2: 第二个就要检查是否有的类别样本数绝对量少, 比如只有几十个样本, 这个时候强烈考虑上规则, 或者被迫进行样本挖掘.
- 3: 第三个就要检查测试集是否存在样本不平衡, 表现差的那一类的准确率和召回率的情况. 如果每个类都差, 那基本确定和类别不平衡无关, 可以进行数据优化和模型层面的优化. 如果样本多的类差, 样本少的类好, 那也可以确定不是样本不平衡的问题. 当出现样本少的类差, 样本多的类好, 那这就是典型的样本不平衡问题, 具体按照bad case的方案来做就好.
样本不平衡测试¶
- 目的是在头条投满分项目上进行样本不平衡的测试, 以衡量在不同数量, 不同平衡比例的情况下, 对模型表现的影响:
- 情况1: 按照最高18000, 最低9000, 不平衡比例不超过1倍, 同时样本数量充足.
- 情况2: 按照最高18000, 最低1800, 不平衡比例达到10倍, 同时样本数量较充足.
- 情况3: 按照最高9000, 最低1800, 不平衡比例不超过5倍, 同时样本数量较充足.
- 情况4: 按照最高3000, 最低300, 不平衡比例达到10倍, 同时样本数量不足.
-
情况1: 按照最高18000, 最低9000, 不平衡比例不超过1倍, 同时样本数量充足.
-
数据情况展示:
# 针对10个类别的采样条数, 采样前通过random.shuffle()进行打乱, 随机采样
choice_num = [12000, 9000, 11000, 18000, 15000, 17000, 13000, 18000, 9000, 14000]
- 调用:
# 唯一修改的就是训练集数据, 对bert.py文件进行修改
self.data_path = "/home/ec2-user/ec2-user/zhudejun/bert/toutiao/data/data1/"
self.train_path = self.data_path + "train1.txt" # 训练集
self.dev_path = self.data_path + "dev.txt" # 验证集
self.test_path = self.data_path + "test.txt" # 测试集
- 输出结果
Loading data for Bert Model...
136000it [00:32, 4243.21it/s]
10000it [00:02, 4412.64it/s]
10000it [00:02, 4379.13it/s]
Epoch [1/5]
38%|████████████████████████▍ | 400/1063 [04:11<07:13, 1.53it/s]Iter: 400, Train Loss: 0.11, Train Acc: 97.66%, Val Loss: 0.3, Val Acc: 90.77%, Time: 0:04:31 *
75%|████████████████████████████████████████████████▉ | 800/1063 [08:50<02:50, 1.54it/s]Iter: 800, Train Loss: 0.2, Train Acc: 91.41%, Val Loss: 0.25, Val Acc: 92.34%, Time: 0:09:12 *
100%|████████████████████████████████████████████████████████████████| 1063/1063 [12:02<00:00, 1.47it/s]
Epoch [2/5]
38%|████████████████████████▍ | 400/1063 [04:20<07:12, 1.53it/s]Iter: 400, Train Loss: 0.085, Train Acc: 96.88%, Val Loss: 0.24, Val Acc: 92.44%, Time: 0:16:45 *
75%|████████████████████████████████████████████████▉ | 800/1063 [09:02<02:51, 1.53it/s]Iter: 800, Train Loss: 0.077, Train Acc: 96.09%, Val Loss: 0.25, Val Acc: 92.84%, Time: 0:21:24
100%|████████████████████████████████████████████████████████████████| 1063/1063 [12:11<00:00, 1.45it/s]
Epoch [3/5]
38%|████████████████████████▍ | 400/1063 [04:20<07:11, 1.53it/s]Iter: 400, Train Loss: 0.069, Train Acc: 96.88%, Val Loss: 0.27, Val Acc: 92.05%, Time: 0:28:54
75%|████████████████████████████████████████████████▉ | 800/1063 [08:59<02:51, 1.53it/s]Iter: 800, Train Loss: 0.05, Train Acc: 98.44%, Val Loss: 0.26, Val Acc: 92.65%, Time: 0:33:33
100%|████████████████████████████████████████████████████████████████| 1063/1063 [12:08<00:00, 1.46it/s]
Epoch [4/5]
38%|████████████████████████▍ | 400/1063 [04:20<07:10, 1.54it/s]Iter: 400, Train Loss: 0.023, Train Acc: 100.00%, Val Loss: 0.26, Val Acc: 92.97%, Time: 0:41:02
75%|████████████████████████████████████████████████▉ | 800/1063 [08:59<02:52, 1.53it/s]Iter: 800, Train Loss: 0.079, Train Acc: 97.66%, Val Loss: 0.27, Val Acc: 92.67%, Time: 0:45:42
100%|████████████████████████████████████████████████████████████████| 1063/1063 [12:08<00:00, 1.46it/s]
Epoch [5/5]
38%|████████████████████████▍ | 400/1063 [04:20<07:12, 1.53it/s]Iter: 400, Train Loss: 0.046, Train Acc: 99.22%, Val Loss: 0.3, Val Acc: 92.67%, Time: 0:53:11
75%|████████████████████████████████████████████████▉ | 800/1063 [08:59<02:50, 1.54it/s]Iter: 800, Train Loss: 0.039, Train Acc: 98.44%, Val Loss: 0.32, Val Acc: 92.57%, Time: 0:57:49
100%|████████████████████████████████████████████████████████████████| 1063/1063 [12:08<00:00, 1.46it/s]
Test Loss: 0.29, Test Acc: 92.44%
Precision, Recall and F1-Score...
precision recall f1-score support
finance 0.9414 0.8840 0.9118 1000
realty 0.9535 0.9220 0.9375 1000
stocks 0.7991 0.9150 0.8531 1000
education 0.9572 0.9620 0.9596 1000
science 0.8738 0.9070 0.8901 1000
society 0.9138 0.9430 0.9281 1000
politics 0.9143 0.8850 0.8994 1000
sports 0.9858 0.9740 0.9799 1000
game 0.9609 0.9340 0.9473 1000
entertainment 0.9704 0.9180 0.9435 1000
accuracy 0.9244 10000
macro avg 0.9270 0.9244 0.9250 10000
weighted avg 0.9270 0.9244 0.9250 10000
Confusion Matrix...
[[884 10 74 4 11 7 7 1 2 0]
[ 9 922 29 2 8 12 10 1 3 4]
[ 34 14 915 1 22 0 12 0 1 1]
[ 0 0 3 962 4 13 12 0 0 6]
[ 2 1 27 7 907 18 18 0 14 6]
[ 0 9 5 11 13 943 12 0 3 4]
[ 7 5 62 9 9 21 885 0 0 2]
[ 2 2 8 0 2 4 4 974 0 4]
[ 0 1 5 4 43 8 3 1 934 1]
[ 1 3 17 5 19 6 5 11 15 918]]
- 结论: 数据进行9000 - 18000之间采样, 测试集上的表现从93.64%下降到92.44%, 下降1.2%, 已经属于显著性下降了. 同时发现在最少的两个类别9000样本的情况下, F1值反倒高于总体值, 说明情况1的不平衡并没有对单一类别造成大的影响, 只是影响了总体的表现.
-
情况2: 按照最高18000, 最低1800, 不平衡比例达到10倍, 同时样本数量较充足.
-
数据情况展示:
# 针对10个类别的采样条数, 采样前通过random.shuffle()进行打乱, 随机采样
choice_num = [10000, 6000, 13000, 1800, 3000, 18000, 15000, 1800, 9000, 14000]
- 调用:
# 唯一修改的就是训练集数据, 对bert.py文件进行修改
self.data_path = "/home/ec2-user/ec2-user/zhudejun/bert/toutiao/data/data1/"
self.train_path = self.data_path + "train2.txt" # 训练集
self.dev_path = self.data_path + "dev.txt" # 验证集
self.test_path = self.data_path + "test.txt" # 测试集
- 输出结果:
Loading data for Bert Model...
91600it [00:20, 4378.95it/s]
10000it [00:02, 4112.24it/s]
10000it [00:02, 4377.83it/s]
Epoch [1/5]
56%|████████████████████████████████████▊ | 400/716 [04:10<03:25, 1.54it/s]Iter: 400, Train Loss: 0.19, Train Acc: 92.97%, Val Loss: 0.34, Val Acc: 90.03%, Time: 0:04:31 *
100%|██████████████████████████████████████████████████████████████████| 716/716 [07:55<00:00, 1.51it/s]
Epoch [2/5]
56%|████████████████████████████████████▊ | 400/716 [04:19<03:25, 1.54it/s]Iter: 400, Train Loss: 0.12, Train Acc: 96.09%, Val Loss: 0.31, Val Acc: 90.65%, Time: 0:12:36 *
100%|██████████████████████████████████████████████████████████████████| 716/716 [08:05<00:00, 1.48it/s]
Epoch [3/5]
56%|████████████████████████████████████▊ | 400/716 [04:19<03:25, 1.54it/s]Iter: 400, Train Loss: 0.065, Train Acc: 97.66%, Val Loss: 0.34, Val Acc: 90.86%, Time: 0:20:39
100%|██████████████████████████████████████████████████████████████████| 716/716 [08:02<00:00, 1.49it/s]
Epoch [4/5]
56%|████████████████████████████████████▊ | 400/716 [04:19<03:24, 1.54it/s]Iter: 400, Train Loss: 0.062, Train Acc: 97.66%, Val Loss: 0.38, Val Acc: 90.37%, Time: 0:28:41
100%|██████████████████████████████████████████████████████████████████| 716/716 [08:02<00:00, 1.49it/s]
Epoch [5/5]
56%|████████████████████████████████████▊ | 400/716 [04:19<03:24, 1.54it/s]Iter: 400, Train Loss: 0.089, Train Acc: 96.88%, Val Loss: 0.41, Val Acc: 90.11%, Time: 0:36:43
100%|██████████████████████████████████████████████████████████████████| 716/716 [08:02<00:00, 1.49it/s]
Test Loss: 0.34, Test Acc: 91.80%
Precision, Recall and F1-Score...
precision recall f1-score support
finance 0.9221 0.9110 0.9165 1000
realty 0.9600 0.9110 0.9348 1000
stocks 0.8164 0.9250 0.8673 1000
education 0.9690 0.9380 0.9533 1000
science 0.9519 0.7710 0.8519 1000
society 0.8519 0.9660 0.9053 1000
politics 0.8898 0.9120 0.9007 1000
sports 0.9915 0.9370 0.9635 1000
game 0.9452 0.9490 0.9471 1000
entertainment 0.9195 0.9600 0.9393 1000
accuracy 0.9180 10000
macro avg 0.9217 0.9180 0.9180 10000
weighted avg 0.9217 0.9180 0.9180 10000
Confusion Matrix...
[[911 9 57 1 2 11 6 1 0 2]
[ 14 911 21 2 3 22 10 3 1 13]
[ 34 10 925 1 11 1 15 0 1 2]
[ 1 1 4 938 0 36 15 0 0 5]
[ 13 6 68 4 771 33 44 0 47 14]
[ 2 4 2 9 0 966 10 0 0 7]
[ 6 1 39 7 2 28 912 0 0 5]
[ 3 3 11 1 1 9 5 937 2 28]
[ 2 2 5 1 16 13 3 1 949 8]
[ 2 2 1 4 4 15 5 3 4 960]]
- 结论: 数据不平衡度达到了10倍后, 最少样本1800, 但是在最少的两个类别中, 分类F1值均高于0.95. 并且总体的F1值继续下降到91.80%, 说明这种不平衡依然对少样本类吴影响, 对总体有影响.
-
情况3: 按照最高9000, 最低1800, 不平衡比例不超过5倍, 同时样本数量较充足.
-
数据情况展示:
# 针对10个类别的采样条数, 采样前通过random.shuffle()进行打乱, 随机采样
choice_num = [2000, 9000, 7000, 1800, 2500, 9000, 6000, 8000, 4000, 5000]
- 调用:
# 唯一修改的就是训练集数据, 对bert.py文件进行修改
self.data_path = "/home/ec2-user/ec2-user/zhudejun/bert/toutiao/data/data1/"
self.train_path = self.data_path + "train3.txt" # 训练集
self.dev_path = self.data_path + "dev.txt" # 验证集
self.test_path = self.data_path + "test.txt" # 测试集
- 输出结果:
Loading data for Bert Model...
54300it [00:12, 4266.69it/s]
10000it [00:02, 4414.53it/s]
10000it [00:02, 4185.29it/s]
Epoch [1/5]
94%|██████████████████████████████████████████████████████████████ | 400/425 [04:11<00:16, 1.53it/s]Iter: 400, Train Loss: 0.26, Train Acc: 92.19%, Val Loss: 0.35, Val Acc: 90.01%, Time: 0:04:31 *
100%|██████████████████████████████████████████████████████████████████| 425/425 [04:46<00:00, 1.48it/s]
Epoch [2/5]
94%|██████████████████████████████████████████████████████████████ | 400/425 [04:20<00:16, 1.53it/s]Iter: 400, Train Loss: 0.093, Train Acc: 97.66%, Val Loss: 0.34, Val Acc: 90.50%, Time: 0:09:29 *
100%|██████████████████████████████████████████████████████████████████| 425/425 [04:57<00:00, 1.43it/s]
Epoch [3/5]
94%|██████████████████████████████████████████████████████████████ | 400/425 [04:21<00:16, 1.54it/s]Iter: 400, Train Loss: 0.071, Train Acc: 97.66%, Val Loss: 0.38, Val Acc: 90.25%, Time: 0:14:24
100%|██████████████████████████████████████████████████████████████████| 425/425 [04:54<00:00, 1.44it/s]
Epoch [4/5]
94%|██████████████████████████████████████████████████████████████ | 400/425 [04:20<00:16, 1.54it/s]Iter: 400, Train Loss: 0.082, Train Acc: 97.66%, Val Loss: 0.44, Val Acc: 89.75%, Time: 0:19:18
100%|██████████████████████████████████████████████████████████████████| 425/425 [04:53<00:00, 1.45it/s]
Epoch [5/5]
94%|██████████████████████████████████████████████████████████████ | 400/425 [04:20<00:16, 1.53it/s]Iter: 400, Train Loss: 0.06, Train Acc: 98.44%, Val Loss: 0.44, Val Acc: 89.73%, Time: 0:24:12
100%|██████████████████████████████████████████████████████████████████| 425/425 [04:54<00:00, 1.44it/s]
Test Loss: 0.37, Test Acc: 91.07%
Precision, Recall and F1-Score...
precision recall f1-score support
finance 0.9473 0.8450 0.8932 1000
realty 0.9219 0.9440 0.9328 1000
stocks 0.8134 0.8980 0.8536 1000
education 0.9606 0.9020 0.9304 1000
science 0.8844 0.8570 0.8705 1000
society 0.8432 0.9570 0.8965 1000
politics 0.9387 0.8730 0.9047 1000
sports 0.9600 0.9850 0.9724 1000
game 0.9737 0.8890 0.9294 1000
entertainment 0.8961 0.9570 0.9255 1000
accuracy 0.9107 10000
macro avg 0.9139 0.9107 0.9109 10000
weighted avg 0.9139 0.9107 0.9109 10000
Confusion Matrix...
[[845 19 90 5 13 14 5 4 1 4]
[ 2 944 12 5 4 14 4 4 0 11]
[ 29 24 898 2 21 3 15 4 0 4]
[ 2 4 5 902 3 50 13 2 2 17]
[ 2 7 45 5 857 27 12 4 18 23]
[ 2 11 2 10 3 957 4 0 1 10]
[ 8 7 47 8 8 35 873 4 0 10]
[ 0 0 0 0 0 5 0 985 0 10]
[ 1 4 5 1 55 15 3 5 889 22]
[ 1 4 0 1 5 15 1 14 2 957]]
- 结论: 数据量总体下降后, 最少样本的第1分类, 第4分类出现了大约2个百分点的下降, 总体F1值继续下降到91.07%, 说明数据量的下降还是对模型表现造成了影响, 但是BERT的鲁棒性和抗抖动性都非常出色!
-
情况4: 按照最高3000, 最低300, 不平衡比例达到10倍, 同时样本数量不足.
-
数据情况展示:
# 针对10个类别的采样条数, 采样前通过random.shuffle()进行打乱, 随机采样
choice_num = [2000, 900, 1000, 1800, 300, 2500, 3000, 800, 3000, 1500]
- 调用:
# 唯一修改的就是训练集数据, 对bert.py文件进行修改
self.data_path = "/home/ec2-user/ec2-user/zhudejun/bert/toutiao/data/data1/"
self.train_path = self.data_path + "train4.txt" # 训练集
self.dev_path = self.data_path + "dev.txt" # 验证集
self.test_path = self.data_path + "test.txt" # 测试集
- 输出结果:
Loading data for Bert Model...
16800it [00:03, 4266.37it/s]
10000it [00:02, 4370.77it/s]
10000it [00:02, 4179.98it/s]
Epoch [1/5]
76%|██████████████████████████████████████████████████ | 100/132 [00:59<00:19, 1.65it/s]Iter: 100, Train Loss: 0.24, Train Acc: 94.53%, Val Loss: 0.45, Val Acc: 86.92%, Time: 0:01:18 *
100%|██████████████████████████████████████████████████████████████████| 132/132 [01:36<00:00, 1.36it/s]
Epoch [2/5]
76%|██████████████████████████████████████████████████ | 100/132 [01:04<00:20, 1.54it/s]Iter: 100, Train Loss: 0.18, Train Acc: 92.97%, Val Loss: 0.49, Val Acc: 86.32%, Time: 0:03:00
100%|██████████████████████████████████████████████████████████████████| 132/132 [01:42<00:00, 1.29it/s]
Epoch [3/5]
76%|██████████████████████████████████████████████████ | 100/132 [01:05<00:20, 1.54it/s]Iter: 100, Train Loss: 0.057, Train Acc: 99.22%, Val Loss: 0.48, Val Acc: 87.34%, Time: 0:04:44
100%|██████████████████████████████████████████████████████████████████| 132/132 [01:43<00:00, 1.27it/s]
Epoch [4/5]
76%|██████████████████████████████████████████████████ | 100/132 [01:05<00:20, 1.54it/s]Iter: 100, Train Loss: 0.1, Train Acc: 97.66%, Val Loss: 0.54, Val Acc: 86.87%, Time: 0:06:28
100%|██████████████████████████████████████████████████████████████████| 132/132 [01:43<00:00, 1.27it/s]
Epoch [5/5]
76%|██████████████████████████████████████████████████ | 100/132 [01:04<00:20, 1.54it/s]Iter: 100, Train Loss: 0.013, Train Acc: 100.00%, Val Loss: 0.57, Val Acc: 87.24%, Time: 0:08:11
100%|██████████████████████████████████████████████████████████████████| 132/132 [01:43<00:00, 1.28it/s]
Test Loss: 0.59, Test Acc: 87.03%
Precision, Recall and F1-Score...
precision recall f1-score support
finance 0.7543 0.9670 0.8475 1000
realty 0.9420 0.8770 0.9083 1000
stocks 0.8713 0.7110 0.7830 1000
education 0.9677 0.9300 0.9485 1000
science 0.9823 0.5540 0.7084 1000
society 0.8442 0.9430 0.8909 1000
politics 0.8693 0.9110 0.8896 1000
sports 0.9922 0.8870 0.9366 1000
game 0.7541 0.9720 0.8493 1000
entertainment 0.8661 0.9510 0.9066 1000
accuracy 0.8703 10000
macro avg 0.8843 0.8703 0.8669 10000
weighted avg 0.8843 0.8703 0.8669 10000
Confusion Matrix...
[[967 5 6 2 0 7 6 0 4 3]
[ 33 877 9 2 0 31 12 2 16 18]
[202 35 711 0 6 3 30 0 8 5]
[ 5 0 1 930 0 27 17 0 6 14]
[ 37 7 57 7 554 43 37 0 225 33]
[ 6 4 1 11 0 943 19 0 3 13]
[ 16 2 25 5 2 30 911 0 2 7]
[ 10 1 4 2 0 11 10 887 29 46]
[ 3 0 2 0 2 10 3 0 972 8]
[ 3 0 0 2 0 12 3 5 24 951]]
- 结论: 在样本最少的类别只有300条, 最终的F1值只有70.84%, 可见低于一定量的样本值对模型的特征学习会造成很大的负面影响. 其他两类样本量低于1000的类别, 也都有几个百分点的下降. 初步可以得出这样的结论, 在我们当前的项目中, 维持1000个左右的样本是底线, 最好能达到3000以上, 以利于模型充分的学习特征!
BERT优化之训练¶
-
BERT的训练其实挺多讲究的, 这里的实验效果要保证对参数的有一定的要求, 所以大家要多去观察训练过程暴露的问题, 训练过程其实就是要观测loss变化, 验证集效果等的问题, 防止没学到, 学飘了之类的问题.
-
比如超参数的设定习惯, 学习率的warmup启动, 关于类似于scheduled sampling, weight tying等训练策略的优化等等.
BERT优化之模型¶
-
第一层级(选定了一个BERT系列模型): 检查有没有bug, 代码整体流程是否有问题, 无论是训练还是推理, 这个就得自己检查和使用了, 这个没法解, 只能自己debug, 找问题然后解决.
-
第二层级(针对具体问题, 具体业务挑选一个合适的BERT系列模型): 这块就是针对于前面介绍过的总共20个左右模型, 大量进行试验, 快速对比baseline, 以及自己的实践经验来进行挑选.
BERT优化之bad case¶
- Bad case分析的意义: 通过基本的指标, 分析出我们现在的算法方案的现状可能并不困难, 但是怎么"提升效果", 却是很多人很迷茫的问题, 这是微观角度, 更宏观地讲, 如何提升自己解决问题的能力, 如何让自己更好更快地解决更多的问题, 这个同样让人非常迷茫, 我们为了提升自己不断学习各种算法, 各种论文, 却发现除了模型逐步内卷之外, 自己解决问题的能力并没有提升, 一个问题试了成千上万的模型却发现提升可谓是杯水车薪, 解决问题的天花板变得很低, 效果提升也变得很不可控.
-
关于模型效果的提升, 有些时候需要站的更高, 看得更远, 跳出现有的维度:
- 大道至简: 简单的方案并不比复杂模型的效果差.
- 从个例中找突破: bad case是破解问题, 探索问题的关键钥匙.
-
如果模型效果不好是病, 那bad case就是症状, 是问题现场留下的线索, 足够了解bad case能让你对问题背后的病因有更多的了解, 方便对症下药.
-
bad case分析是我了解业务, 了解场景的重要途径, 知道当前研究, 实践中容易出现的问题, 往通俗的说就是吸收经验的重要途径.
-
第一步: 明确bad case分析的当前状态.
-
现在有一个基本baseline, baseline自然就有对应的评测集基线方案的效果, 这个所谓的效果就是现状, 好or不好, 是否需要持续优化, 哪个方面需要持续优化, 这就是所谓的现状, 只有明确了宏观的现状, 才能知道我们下一步的动作.
首先要有一个完整意义的测试集, 不能用几条样本来代替总体评估.
第二是由一定统计意义的数据集, 数量和质量都要保证.
第三是数据质量要可靠, 可以通过人工审核的方法来进行, 随机抽取100-500条样本, 纯人工细致审核, 如果准确率达到99%以上算顶尖级精品数据集, 95%以上算靠谱的, 90%就勉强够用, 低于90%的数据集别指望训练出优秀的模型.
- 第二步: 明确bad case分析的当前指标.
- 1: 观测的目标. 首先, 指标的设计必须要考虑观测的目标. 指标是用来体现现状的, 那要体现什么现状, 这就是指标存在的根本. 准确率是用来评估预测的正确性的, 这个大家都知道, 但是我们需要评估的是哪种方面的准确率? 这个就很关键了, 例如我们在线的关键目标是为了避免误召回, 要看误召回的情况, 则准确率很自然地就会想到用"实际是该意图/算法预测该意图正类"来作为准确率指标, 对于负类的预测情况, 我们可能关心的并不多, 所以就可以不设计在计算公式里了, 这就是指标设计和背后的观测目标的关系.
- 2: 指标的口径: 数据分析非常讲究一个点, 就是口径, 即我们取数据的条件, 不同的条件取数据很可能会得到截然不同的结论, 因此我们必须非常关注口径. 举例, 我们在分析和解决bad case的时候, 针对多路召回, 我们是要看每一路的效果情况来归因, 所以我们就要划分口径了. 方案A召回的正类准确率是多少, 方案B召回的准确率又是多少, 所有出错的case中, 哪一路召回的量最大, 通过这样的口径切分和分析, 我们能快速知道是哪个方案出现了问题, 哪个方案的问题是关键问题.
- 3: 多个指标的组合观测: 一个指标能让我们知道一个问题, 但是有的时候单独看一个指标很可能会让我们忽略一切别的关键问题. 例如, 准确率虽然能让我们观测算法方案误召回的情况, 一旦要优化准确率, 召回的数量肯定也会随之降低, 为了避免出现召回过低的问题, 在看准确率的同时我们还要看召回率, 两者平衡提升才是健康可靠的.
-
第三步: 明确bad case分析的当前结论.
-
最核心的问题就是当前的总体表现距离业务要求的还差多少? 这个结论本身是指导我们后续工作的, 好or不好取决于我们是否还需要做进一步的优化, 所以下结论的时候要顾及我们下一步的工作.
-
Bad case分析之解决方案:
-
比如错误样本的长度都是短句, 都是问句, 或者都包含数字串, 或者有正则特征, 等等.
-
例如我们在百科意图里发现很多问句会被误召回, 这时候我们可以查查训练集中带有问句关键词的正负样本的比例是不是有问题, 例如正样本远多于负样本, 此时模型很可能就会认为只要带有疑问词的就是正类, 我们需要加一些带疑问词的负样本进去, 例如"为什么你的心情不好"这种属于闲聊类而非百科类的样本.
-
例如短句预测不正确, 而且一般都比较模糊, 所以我们考虑的是过滤所有短句, 然后再通过白名单的方式补召回, 那我们就可以把query归类为短句.
-
比如只要出现某些关键词的文本都有什么分类倾向, 那么直接字典方案处理即可.
- 很多极其特殊的文本, 就只能通过bad case分析发现问题, 然后靠规则来解决:
- 例如"怎么去", 几乎所有人的反应, 包括模型都认为是"地点意图":
- 怎么去王府井?
- 怎么去颐和园?
- 怎么去鼓楼大街的庆丰包子铺?
- 怎么去天堂? (WHAT???)
- 例如"怎么去", 几乎所有人的反应, 包括模型都认为是"地点意图":
- case分析这块的工作的, 算法的增长点, 创新点都是从这里产生的, 我们很多时候所谓的在努力地刷sota, 是否有真正看过现在所谓的前沿方案, 里面的bad case是什么样的, 我们能通过什么方式让他变得更好, 思路就是从中一点一点找出来的, 所谓的明察秋毫, 就是需要我们观察, bad case就是我们的线索, 要突破必须要把这块东西给分析清楚.
小节总结¶
- 本小节学习了BERT优化的方方面面, 既有心法, 又有具体策略, 值得仔细体味!