8.3 BERT之参数详解
BERT模型参数详解¶
BERT模型中参数量的计算¶
-
详细理解BERT模型中参数量的计算, 可以更加深入细致的掌握BERT结构, 从细节到宏观都有更好的把握. 分三个部分进行计算:
- Embedding层的参数计算.
- Encoder层的参数计算.
- Pooling层的参数计算.
-
计算举例以bert-base-case模型为参考, 核心参数如下:
- 层数layer = 12
- 隐藏层维度hidden_size = 768
- 多头注意力数head = 12
- 参数总量 = 110M
Embedding层¶
- BERT的Embedding分为三个部分Token Embedding, Segment Embedding, Position Embedding. 其中Token Embedding包括词表V有30522个, 对应30522个单词(或token, 不同的语言模型数量不同). Segment Embedding包括2个取值, 分别表示当前token属于第1个句子, 还是第2个句子, 这是和BERT预训练的NSP任务直接相关的. Position Embedding包括512个取值(因为BERT要求编码序列的长度不超过512). 最后每种embedding都会把token映射到H维(当前默认为768)的隐向量中.
- 这3部分词嵌入的参数量为(30522 + 512 + 2) * 768 = 23835648
- 在完成词嵌入后, 每个位置的隐向量维度都是768, 还要经过一层LayerNorm, LayerNorm的参数就是均值和方差, 所以这个模块的参数量是768 * 2.
- 综上所述: Embedding层的参数总量就是(30522 + 512 + 2) * 768 + 768 * 2 = 23837184
Encoder层¶
- BERT中的Encoder是由12个Encoder Block堆叠在一起的, 而每一个Encoder Block的内部结构完全一样, 从下到上依次是:
- Multi-head Attention
- Add & Norm
- Feed Forward
- Add & Norm
- Multi-head Attention
- 每个Block包含12个head, 每个head拥有不同的3个自注意力矩阵Q, K, V, 通过矩阵乘法将上一层的输出跟着3个矩阵分别相乘, 得到新的Q, K, V向量. 需要注意的是这里有12个head, 所以每个head的每个矩阵都会把上一层的768维度向量的输出, 映射成768 / 12 = 64维的新向量, 最后通过concat操作进行拼接, 重新得到新的768维隐向量.
- 因此这里每个head的每个矩阵的参数包括weights = 768 * (768 / 12), bias = 768 / 12. 每个head同时拥有Q, K, V这3个矩阵, 每个Block又有12个head, 因此参数量 = 12 * 3 * (768 * (768 / 12) + 768 / 12) = 1771776.
- 在将12个head的输出concat到一起后, 还会经过一个全连接层的操作, 本质上是一个方阵映射, 这部分weight = 768 * 768, bias = 768.
- 因此整个Multi-head Attention模块的参数量 = 12 * 3 * (768 * (768 / 12) + 768 / 12) + 768 * 768 + 768 = 2362368
- Add & Norm
- Add本质上是跨层连接, 起到残差连接的作用, 没有额外的参数.
- Norm在这里指代LayerNorm操作, 参数包括均值和方差, 它接收的是上一层Multi-head Attention的768维的输出结果, 处理后的输出张量维度不变, 所以Norm总共有768 * 2个参数
- Feed Forward
- Feed Forward是前馈全连接层, 这里包括2层全连接层, 第1层全连接层会进行升维, 会把输入从当前维度(768)映射成4倍当前维度的中间层(3072), 第2层全连接层会进行降维, 会把4倍初始输入维度的中间层结果(3072)映射到初始维度(768).
- 第1层全连接层的weight = 768 * (768 * 4), bias = 768 * 4; 第2层全连接层的weight = (768 * 4) * 768, bias = 768.
- 综上所述, Feed Forward层的参数量 = 768 * (768 * 4) + 768 * 4 + (768 * 4) * 768 + 768 = 4722432
- Add & Norm
- 这里的Add & Norm和前面的论述一样, 参数量768 * 2.
- 综上所述, 整个Encoder部分拥有12个Block, 每个Block的参数前面已经详细计算过了, 因为总量是12 * (2362368 + 1536 + 4722432 + 1536) = 85054464
Pooling层¶
- Pooling层本质是一层全连接层, 它的输入是Encoder层输出的768维隐向量, 输出保持维度不变. 因此只包括weight = 768 * 768, bias = 768, 参数量 = 768 * 768 + 768 = 590592
结论: 整个BERT模型的参数量为前述3部分的总和, Embedding + Encoder + Pooling = 23837184 + 85054464 + 590592 = 109482240.