跳转至

3 循环网络层

学习目标

  • 掌握RNN网络原理
  • 掌握PyTorch RNN api

我们前面学习了词嵌入层,可以将文本数据映射为数值向量,进而能够送入到网络进行计算。但是,还存在一个问题,文本数据是具有序列特性的,例如: "我爱你", 这串文本就是具有序列关系的,"爱" 需要在 "我" 之后,"你" 需要在 "爱" 之后, 如果颠倒了顺序,那么可能就会表达不同的意思。

为了能够表示出数据的序列关系我们需要使用循环神经网络(Recurrent Nearal Networks, RNN) 来对数据进行建模,RNN 是一个具有记忆功能的网络,它作用于处理带有序列特点的样本数据。

本小节,我们将会带着大家深入学习 RNN 循环网络层的原理、计算过程,以及在 PyTorch 中如何使用 RNN 层。

1. RNN 网络原理

当我们希望使用循环网络来对 "我爱你" 进行语义提取时,RNN 是如何计算过程是什么样的呢?

上图中 h 表示隐藏状态, 每一次的输入都会有包含两个值: 上一个时间步的隐藏状态、当前状态的输入值,输出当前时间步的隐藏状态。

上图中,为了更加容易理解,虽然我画了 3 个神经元, 但是实际上只有一个神经元,"我爱你" 三个字是重复输入到同一个神经元中。

接下来,我们举个例子来理解上图的工作过程,假设我们要实现文本生成,也就是输入 "我爱" 这两个字,来预测出 "你",其如下图所示:

我们将上图展开成不同时间步的形式,如下图所示:

我们首先初始化出第一个隐藏状态,一般都是全0的一个向量,然后将 "我" 进行词嵌入,转换为向量的表示形式,送入到第一个时间步,然后输出隐藏状态 h1,然后将 h1 和 "爱" 输入到第二个时间步,得到隐藏状态 h2, 将 h2 送入到全连接网络,得到 "你" 的预测概率。

那么,你可能会想,循环网络只能有一个神经元吗?

我们的循环网络网络可以有多个神经元,如下图所示:

我们依次将 "我爱你" 三个字分别送入到每个神经元进行计算,假设词嵌入时,"我爱你" 的维度为 128,经过循环网络之, "我爱你" 三个字的词向量维度就会变成 4. 所以, 我们理解了循环神经网络的的神经元个数会影响到输出的数据维度。

每个神经元内部是如何计算的呢?

上述公式中:

  1. Wih 表示输入数据的权重
  2. bih 表示输入数据的偏置
  3. Whh 表示输入隐藏状态的权重
  4. bhh 表示输入隐藏状态的偏置

最后对输出的结果使用 tanh 激活函数进行计算,得到该神经元你的输出。

2. PyTorch RNN 层的使用

接下来,我们学习 PyTorch 的 RNN 层的用法.

注意: RNN 层输入的数据为三个维度: (seq_len, batch_size, input_size).

示例代码如下:

import torch
import torch.nn as nn


# 1. RNN 送入单个数据
def test01():

    # 输入数据维度 128, 输出维度 256
    rnn = nn.RNN(input_size=128, hidden_size=256)

    # 第一个数字: 表示句子长度
    # 第二个数字: 批量个数
    # 第三个数字: 表示数据维度
    inputs = torch.randn(1, 1, 128)
    hn = torch.zeros(1, 1, 256)

    output, hn = rnn(inputs, hn)
    print(output.shape)
    print(hn.shape)


# 2. RNN层送入批量数据
def test02():

    # 输入数据维度 128, 输出维度 256
    rnn = nn.RNN(input_size=128, hidden_size=256)

    # 第一个数字: 表示句子长度
    # 第二个数字: 批量个数
    # 第三个数字: 表示数据维度
    inputs = torch.randn(1, 32, 128)
    hn = torch.zeros(1, 32, 256)

    output, hn = rnn(inputs, hn)
    print(output.shape)
    print(hn.shape)


if __name__ == '__main__':
    test01()
    test02()

程序输出结果:

torch.Size([1, 1, 256])
torch.Size([1, 1, 256])
torch.Size([1, 32, 256])
torch.Size([1, 32, 256])

3. 小节

在本章节中我们学习了 RNN 层及其原理,并学习了 PyTorch 中 RNN 网络层的基本使用。