跳转至

7.7 Reformer模型深度解析

Reformer模型


Reformer模型背景

  • Reformer模型是Google于2020年提出的最新版本模型, 针对Transformer处理超长文本并压缩内存的优化模型. 原始论文<< REFORMER: THE EFFICIENT TRANSFORMER >>.


  • 理解序列数据, 如语言, 音乐或视频是一项具有挑战性的任务, 特别是当它依赖于大量的周围环境时. 例如, 如果一个人或一个物体在视频中消失, 很久以后又重新出现, 许多模型就会忘记它的样子. 在自然语言处理领域, LSTM神经网络覆盖了足够的上下文来逐句翻译. 在这种情况下, 上下文窗口(在翻译过程中需要考虑的数据范围), 从几十个词到大约100个词不等. 最新的Transformer模型不仅改进了逐句翻译的性能, 还可以通过多文档摘要生成整个Wikipedia的文章. 这主要依赖于Transformer的强大能力, 它使用的上下文窗口可以扩展到数千个单词. 有了这样一个大的上下文窗口, Transformer可以用于文本以外的应用, 包括像素或音符, 使其能够用于生成音乐和图像.

  • 但是将Transformer扩展到更大的上下文窗口会遇到限制. Transformer的能力来自于注意力机制, 在这个过程中, 它考虑上下文窗口中所有可能的单词对, 以理解它们之间的联系. 因此, 对于100K个单词的文本, 这需要评估100K x 100K个单词对, 或者每一步100亿对, 这对于计算复杂度来说是不切实际的. 另一个问题是存储每个模型层输出的标准实践. 对于使用大型上下文窗口的应用程序, 存储多个模型层的输出的内存需求很快变得非常大(从只有几层的GB字节到有数千层的模型的TB字节), 这意味着使用许多层的实际的Transformer模型只能用于几段文本或生成简短的音乐片段.

  • Reformer模型的本质是一个基于Transformer的模型, 设计用于处理最多100万个单词的上下文窗口, 所有这些都在一个单一的加速器上, 并且只使用了16GB的内存. 它结合了两种关键技术来解决注意力和内存分配问题.



Reformer模型的四大优化点

优化一: Axial Positional Embedding

  • 参考论文里的实验, 其中enwik8-64K的序列长度是64K! 假设位置向量维度是512, 那么仅仅位置编码矩阵就有3000+万个参数, 这明显是不行了.

  • 参考Reformer的源代码, 就会发现使用的是Axial Positional Embedding(APE). APE可用于输入是多维数组(Tensor)的场景, 不过NLP的输入基本都是一维的序列, 我们假设位置向量维度是d, 输入序列的长度是L, 并且L = m * n, 我们可以将序列排列成矩阵P, 如下图所示, 则矩阵元素P(i, j)对应原始序列中第(i - 1) * m + j个位置(i和j的下标都是从1开始).

  • 为矩阵P的每一行都创建一个向量r_i, 同时为每一列都创建一个向量c_j, 向量维度分别是d1, d2. 这样矩阵元素P(i, j)就关联了两个向量(r_i, c_j), 如果再令d1 + d2 = d, 就可以通过concat(r_i, c_j)来表示P(i, j)的位置向量, 维度等于d, 这就是一维场景下的APE.

  • 分析一下参数量: 传统位置编码矩阵大小是L * d, 使用APE方法会得到两个矩阵, 大小分别是m * d1, n * d2. 以enwik8-64K为例, L = 10000, d = 512, 如果令m = 253, n = 253, d1 = 256, d2 = 256, 参数量从3000+万减少到13万!!!

优化二: 分段处理全连接层(FFN)

  • 全连接层一直是神经网络中的参数量大户, Transformer中的FFN计算公式为FFF(x) = max(0, x*W1 + b1)*W2 + b2. 其中W1和W2的维度是embedding_dim * ff_dim, 这里面ff_dim一般从2048到4096.

  • 注意Transformer中的FFN全称是Position-wise Feed-Forward Networks, 重点就是这个position-wise. 区别于普通的全连接网络, 这里FFN的输入是序列中每个位置上的元素, 而不是整个序列, 所以每个元素完全可以独立计算, 最极端节省内存的做法是遍历序列, 每次只取一个元素得到FFN的结果. 但是这样做时间消耗太大, "分段"的含义就是做下折中, 将序列分成N段, 也就是N个子序列, 每次读取一个子序列进行FFN计算, 最后将N份的结果拼接, 如下述公式:


  • 分段FFN只是一种计算上的技巧, 计算结果和原始FFN完全一致, 所以不会影响到模型效果. 优点是不需要一次性将整个序列(batch_size, seq_length, embedding_dim)读入内存, 缺点是会增加额外的时间开销.

优化三: 局部敏感性哈希算法(LSH)

  • attention的计算时间复杂度O(L*L)的问题, 自从2017年以来一直都是研究热点, 优化方法层出不穷. 由于softmax先指数缩放再归一化的本质, 使其极易受到极大值影响, 特别是对于长序列, 得到的分布几乎总是稀疏的. 这种稀疏性分布是有现实含义的: 序列中的某个元素一般只会和少数几个元素具有较高的相似性/关联性.

  • 一种通用的思路就是如果我们能为每个query(q_i)找到最相似的K个key, 只和它们计算点积, 再取softmax, 只要寻找相似key集合速度够快, 并且K相比于L足够小, 这种方式不论是时间还是空间都是有优势的. 那么问题来了, 如何快速找到和q_i最相似的key呢?

  • 这个问题可以转化为向量检索问题, 向量检索又可以分为精确(exact)检索和近似(approximate)检索两类, 后者在学术界被称为approximated nearest neighbor search(ANN)问题, 在工业界有非常广泛的应用, 毕竟能提供检索服务的场景数据量一般都非常大, 如何快速召回就显得格外重要, ANN问题的解决方案包括局部敏感性哈希(Locality Sensitive Hashing, LSH), 树方法和Product Quantization.


  • 在Reformer模型中选择的就是LSH. 局部敏感性哈希函数是一类特殊的哈希函数,特殊在“局部敏感”上面,如果一个哈希函数具有如下的特点,则它就属于局部敏感性哈希函数:
A hash function is locality-sensitive if its probability of collision is higher for "nearby" points than for points that are "far apart".

  • 举例: 现在有三个点x1, x2, x3, 如果x1和x2相邻, x1和x3相隔较远, 同时Hash(x1)和Hash(x2)碰撞的概率要比Hash(x1)和Hash(x3)碰撞的概率大得多, 我们就说Hash()函数属于LSH, 也就是它保留了原始数据之间的距离属性.

  • LSH有多种算法, Reformer用到的是基于随机投影(random projections)的哈希方法, 实现非常简单. 假设我们希望有b个哈希结果, 也就是哈希函数对应b个桶, 首先创建大小是(embedding_dim, b/2)的矩阵R, 矩阵元素值服从标准正态分布(torch.randn((embedding_dim, b/2))), 向量x的哈希值等于argmax(concat([xR;, -xR])).

  • 在Transformer中, Q, K, V是通过三个矩阵变换后得到的, 由于经过了不同的线性映射, 即使是同一个位置的q_i和k_i都很难保证哈希值相同, 那如何去找相似呢? Reformer的解决方案是令Q = K, 实验证明这样做不会影响模型效果.

  • 当Q = K时, q_i和k_i相同, 则k_i必然属于最相似的K个key, 并且q_i和k_i的点积结果也一定是最大的, 这同样会造成softmax稀疏, 所以LSH的self-attention会mask掉自己, 如下图所示:


  • 在得到每个q_i的哈希值后, 如何快速得到每个q_i的最相似的key集合? 也就是和q_i落入同一个桶内的Seti(q_i) = {q_j, j!=i}? Reformer选择了排序, 也就是下图中的"Sort by LSH bucket", 将落入同一个桶的q排列在一起, 下面就可以进行点积计算了. 这时候又有一个问题, 有的桶元素个数多, 有的桶元素个数少, 如果以桶数据组成batch, 明显各个batch size不同, 不便于批处理, 那就进行分块(chunk)处理, 假设序列长度L, 一共有b个桶, 则平均下来每个桶内有L/b个元素, 考虑到有的桶元素个数会大于L/b, 为了让同一个桶内的q_i都属于Set(q_i)肯定要让每一段序列长度大于L/b, Reformer设置的子序列长度是2L/b, 平均一个子序列包含两个桶的数据, 具体在计算点积时, q_i会在它所在的子序列和前一个子序列中找同一个桶内的q_j进行点积, 然后计算softmax, 再得到v_i.


  • 注意: 考虑到LSH毕竟有误差, 有可能很相似的q_i和q_j没有落入同一个桶内, 那就多来几轮哈希, 将每一轮哈希得到的Set(q_i)取并集作为最终的key集合.

优化四: 可逆(Reversible)残差连接Transformer

  • 可逆残差网络是一种节省内存的技巧, 下图首先展示正常的Transformer中的残差连接:


  • 可逆残差网络的特点是在前向计算过程中不保存bn, ReLU等中间值, 只保存残差块的输出x + F(x)的值, 也就是计算过程在with torch.no_grad()范围内. 优点是不存储中间临时变量, 自然节省了内存. 但是缺点也很明显, 在反向传播过程中如果只记录了x + F(x)的值, 下面的两个问题如何解决呢?
    • 问题一: 如何得到x?
    • 问题二: 如何计算F'(x)?

  • 问题一: 如何得到x?

  • 可逆残差网络对残差块(residual block)进行了修改, 首先将输入x分割为两部分x1和x2, 至于如何分割, 论文中尝试了多种方法, 包括按行、按列、按通道等, 最终发现根据通道来分割效果最好. 也就是x包含C个通道, x1和x2各一半, F和G结构相同, 就是普通的残差块结构, 可逆残差块的输出是concat([y1, y2]), 公式如下:


  • 在反向传播中, x1和x2可以通过如下方法计算得到:


  • 问题二: 如何计算F'(x)?

  • 有了x1和x2, 就可以再计算一遍F和G的函数输出值, 只不过这一次要在with torch.grad()中, 然后计算该模块的权重梯度, 如下图所示:


  • 将Transformer中的残差连接, 改造成可逆残差连接, 做法其实很简单, 计算公式如下:


  • 解析: Reformer如何分割得到X1和X2? 在第一层之前, 将输入copy了一份, 也就是X = X1 = X2, 其余层的X1与X2则对应Y1与Y2. 在前向计算过程中, 每一个block计算都是with toch.no_grad(), 只输出[Y1(i), Y2(i)], 在作为下一层输入计算得到下一层的[Y1(i+1), Y2(i+1)]后就可以del([Y1(i), Y2(i)]). 这样即使Reformer有N层, 在前向过程中我们也只需保留最后一层的[Y1(n), Y2(n)]即可.

  • 注意: 可逆残差块的结构和原始Transformer中的残差块是不同的!!! 原始Transformer中经过残差计算后, 在反向传播中无法还原出X1, X2, 也就是不可逆.



Reformer模型在CV领域的应用

  • 在Reformer中, 这两种新方法的应用使其具有很高的效率, 使其能够仅使用16GB内存在单个GPU上处理长度高达100万字的文本序列. 由于Reformer具有如此高的效率, 它可以直接应用于上下文窗口比几乎所有当前最先进的文本域数据集大得多的数据. 也许Reformer处理如此大的数据集的能力将刺激社区创建它们.

  • 大上下文数据的一个不足之处是图像生成, 因此论文中对图像进行了Reformer的实验. 在这篇论文中, 作者将举例说明如何使用Reformer来"完成"部分图像. 从下图最上面一行的图像片段开始, Reformer可以逐像素地生成全帧图像(紧邻的下面一行, 以此类推):


  • 顶部: 图像片段用作Reformer的输入.

  • 底部: "完全体"的全帧图像的生成.


  • 虽然Reformer在图像和视频任务上的应用潜力巨大, 但在文本上的应用更令人兴奋. Reformer可以一次性在单一的设备中处理整个小说. 将来, 当有更多的数据集需要训练长文本时, Reformer模型可能会使生成长连贯的文本成为可能!


Reformer模型在NLP领域的应用

  • 下载谷歌已经在<< crime and punishment >>上预训练好的模型:
-rw-r--r-- 1 ec2-user ec2-user     1151 Feb 19 04:04 config.json
-rw-r--r-- 1 ec2-user ec2-user 11013576 Feb 19 04:04 pytorch_model.bin
-rw-r--r-- 1 ec2-user ec2-user     1270 Feb 19 04:04 README.md
-rw-r--r-- 1 ec2-user ec2-user 11014133 Feb 19 04:04 rust_model.ot
-rw-r--r-- 1 ec2-user ec2-user   241801 Feb 19 04:04 spiece.model
-rw-r--r-- 1 ec2-user ec2-user   323306 Feb 19 04:04 tokenizer.json

  • 展示Reformer不带语言头的纯粹模型应用:
from transformers import ReformerTokenizer, ReformerModel
import torch


tokenizer = ReformerTokenizer.from_pretrained('./reformer-crime-and-punishment')
model = ReformerModel.from_pretrained('./reformer-crime-and-punishment')

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
output = model( **inputs)

last_hidden_states = output.last_hidden_state
print(last_hidden_states.shape)
print('-------------------------------------')
print(last_hidden_states)

  • 输出结果:
torch.Size([1, 11, 512])
-------------------------------------
tensor([[[-0.5412,  0.9889,  0.6175,  ..., -1.9236,  2.4747, -0.0368],
         [-0.5621,  0.3776, -0.0137,  ..., -0.0730,  1.3191,  0.0162],
         [-1.0310,  0.0343,  0.8384,  ...,  1.4163,  0.6293, -0.1870],
         ...,
         [ 0.6405, -0.1191, -0.2453,  ...,  1.0976, -0.2549, -0.4543],
         [ 0.3263,  0.1651,  0.5058,  ..., -0.1725,  0.5778, -1.6924],
         [ 1.0301,  0.7234, -1.0143,  ...,  0.2901, -0.8706,  1.1020]]],
       grad_fn=<NativeLayerNormBackward>)

  • 展示Reformer带语言头的模型应用:
import torch
from transformers import ReformerTokenizer, ReformerModelWithLMHead


tokenizer = ReformerTokenizer.from_pretrained('./reformer-crime-and-punishment')
model = ReformerModelWithLMHead.from_pretrained('./reformer-crime-and-punishment')

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model( **inputs, labels=inputs["input_ids"])

loss = outputs.loss
logits = outputs.logits
print(loss)
print('---------------------------------------')
print(logits)

  • 输出结果:
tensor(8.4955, grad_fn=<NllLossBackward>)
---------------------------------------
tensor([[[ -7.6074, -18.0610, -17.1554,  ...,  -5.6912,  -9.3542, -13.0199],
         [ -6.1629, -19.8652, -19.5180,  ...,  -5.1157, -10.2902, -11.3418],
         [ -8.0223, -15.7445, -15.9723,  ...,  -9.1973, -11.9560, -14.7962],
         ...,
         [ -5.5516, -15.2925, -15.8361,  ...,  -2.4272,  -7.5874, -12.2855],
         [ -5.0697, -12.5760, -13.3897,  ...,  -9.8692,  -5.3596, -11.3801],
         [ -9.8683, -13.4770, -13.7801,  ..., -10.7918, -10.0629, -10.2269]]],
       grad_fn=<AddBackward0>)


小节总结

  • 本小节学习了Reformer模型的深入细节.

  • 本小节学习了如何加载和应用Reformer模型.