跳转至

7.5 XLNet模型深度解析

XLNet模型

XLNet模型源代码分析

  • XLNet模型注意力机制代码分析:
class XLNetRelativeAttention(nn.Module):
    """Relative-position multi-head attention parameters and shift helpers for XLNet.

    This excerpt contains the constructor and the relative-shift helpers; the
    attention forward computation is not part of it.

    Args:
        config: model configuration; this constructor reads ``d_model``, ``n_head``,
            ``d_head``, ``layer_norm_eps`` and ``dropout``.

    Raises:
        ValueError: if ``config.d_model`` is not a multiple of ``config.n_head``.
    """

    def __init__(self, config):
        super().__init__()

        if config.d_model % config.n_head != 0:
            raise ValueError(
                f"The hidden size ({config.d_model}) is not a multiple of the number of attention "
                f"heads ({config.n_head})"
            )

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        # 1 / sqrt(d_head); presumably used to scale attention scores (scoring code not shown).
        self.scale = 1 / (config.d_head**0.5)

        # Per-head projection weights, each of shape (d_model, n_head, d_head).
        # torch.FloatTensor(*sizes) allocates UNINITIALIZED storage; presumably an external
        # weight-init step fills these in — confirm before using this module standalone.
        self.q = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))

        # Per-head bias vectors of shape (n_head, d_head); presumably the position (r_r),
        # segment (r_s) and content (r_w) biases of XLNet's relative attention.
        self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        # Shape (2, n_head, d_head): one embedding per segment-relation class.
        self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))

        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def prune_heads(self, heads):
        """Head pruning is not supported by this module."""
        raise NotImplementedError

    @staticmethod
    def rel_shift(x, klen=-1):
        """Perform the relative shift that forms the relative attention score.

        For a 4-D ``x`` of shape ``(d0, d1, d2, d3)``: view it as ``(d1, d0, d2, d3)``,
        drop the first slice along dim 0, view the result as ``(d0, d1 - 1, d2, d3)``
        and keep the first ``klen`` entries along dim 1.

        Args:
            x: 4-D score tensor.
            klen: number of entries to keep along dim 1. A negative int (the
                default) keeps all ``d1 - 1`` entries.

        Returns:
            Tensor of shape ``(d0, klen, d2, d3)``, or ``(d0, d1 - 1, d2, d3)`` when
            ``klen`` is negative.
        """
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        if isinstance(klen, int) and klen < 0:
            # torch.arange(-1) raises, so the old default crashed; treat a negative
            # klen as "keep everything", like a size of -1 in a TF slice.
            klen = x.shape[1]
        # index_select instead of x[:, :klen] so tracing stays valid if klen changes.
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        """Relative shift for scores laid out as ``(b, n, i, j)``.

        View ``x`` of shape ``(d0, d1, d2, d3)`` as ``(d0, d1, d3, d2)``, drop the first
        slice along dim 2, view the result as ``(d0, d1, d2, d3 - 1)`` and keep the
        first ``klen`` entries along dim 3.

        Args:
            x: 4-D score tensor.
            klen: number of entries to keep along dim 3. A negative int (the
                default) keeps all ``d3 - 1`` entries.

        Returns:
            Tensor of shape ``(d0, d1, d2, klen)``, or ``(d0, d1, d2, d3 - 1)`` when
            ``klen`` is negative.
        """
        x_size = x.shape

        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        if isinstance(klen, int) and klen < 0:
            # Same fix as rel_shift: a negative klen keeps all remaining columns.
            klen = x.shape[3]
        # Note: the tensor-slice form x[:, :, :, :klen] was measured faster than
        # index_select, but tracing dislikes the slice and breaks if klen changes
        # during a run, whereas index_select keeps working.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

  • XLNet模型的前馈全连接层代码分析:
class XLNetFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(Dropout(layer_2(Dropout(act(layer_1(inp))))) + inp)``.

    Args:
        config: model configuration; reads ``d_model``, ``d_inner``, ``layer_norm_eps``,
            ``dropout`` and ``ff_activation`` (a key into ``ACT2FN`` or a callable).
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        act = config.ff_activation
        # A string names an activation registered in ACT2FN; anything else is used as-is.
        self.activation_function = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, inp):
        """Apply expand -> activate -> project, then add ``inp`` back and normalize."""
        expanded = self.dropout(self.activation_function(self.layer_1(inp)))
        projected = self.dropout(self.layer_2(expanded))
        return self.layer_norm(projected + inp)

  • XLNet模型的网络层代码分析:
class XLNetLayer(nn.Module):
    """One XLNet block: relative attention followed by a feed-forward sub-layer.

    The attention module returns two streams (``output_h`` and ``output_g``) plus
    optional extra outputs. The feed-forward sub-layer is applied to ``output_h``
    and, when it is not ``None``, to ``output_g``.
    """

    def __init__(self, config):
        super().__init__()
        self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        # NOTE(review): this dropout is not used by forward() in this excerpt.
        self.dropout = nn.Dropout(config.dropout)
        # Chunk size and the dimension handed to apply_chunking_to_forward.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1

    def forward(
        self,
        output_h,
        output_g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run attention, then the feed-forward on each stream.

        Returns:
            ``(output_h, output_g, *extra)``, where ``extra`` is whatever the
            attention module returned after its first two outputs.
        """
        attn_results = self.rel_attn(
            output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat,
            mems=mems,
            target_mapping=target_mapping,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        new_h, new_g = attn_results[:2]

        def feed_forward(stream):
            # Delegate to apply_chunking_to_forward with this layer's chunk settings.
            return apply_chunking_to_forward(
                self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, stream
            )

        # Process output_g before output_h, as before (matters for dropout RNG order).
        if new_g is not None:
            new_g = feed_forward(new_g)
        new_h = feed_forward(new_h)

        # Re-attach any extra outputs (e.g. attentions) that follow the two streams.
        return (new_h, new_g) + attn_results[2:]

    def ff_chunk(self, output_x):
        """Apply the feed-forward sub-layer to one chunk of a stream."""
        return self.ff(output_x)


XLNet模型的应用

  • 在头条投满分项目中, 将预训练模型替换成XLNet, 检验模型在分类任务中的表现:
# coding: UTF-8
import numpy as np
import torch
import time
from train_eval import train_kd, train
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_dif, build_dataset_CNN

parser = argparse.ArgumentParser(description="Chinese Text Classification")
parser.add_argument("--task", type=str, required=True, help="choose a task: xlnet, or kd")

if __name__ == "__main__":
    # Parse here, not at import time, so importing this module has no side effects
    # (previously any import would consume/validate sys.argv and could exit).
    args = parser.parse_args()
    dataset = "toutiao"

    if args.task == "xlnet":
        model_name = "xlnet"
        x = import_module("models." + model_name)
        config = x.Config(dataset)
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)
        torch.backends.cudnn.deterministic = True  # keep results identical across runs

        print("Loading data for XLNet Model...")
        train_data, dev_data, test_data = build_dataset(config)
        train_iter = build_iterator(train_data, config)
        dev_iter = build_iterator(dev_data, config)
        test_iter = build_iterator(test_data, config)

        model = x.Model(config).to(config.device)
        train(config, model, train_iter, dev_iter, test_iter)
    else:
        # Previously any other task silently did nothing; fail loudly instead.
        parser.error(f"unsupported task {args.task!r}: only 'xlnet' is implemented in this script")

  • 输出结果:
Loading data for XLNet Model...
180000it [00:27, 6663.30it/s]
10000it [00:01, 6824.81it/s]
10000it [00:01, 6846.25it/s]
14%|█████████▏                                                       | 200/1407 [04:44<29:31,  1.47s/it]Iter:    200,  Train Loss:  0.48,  Train Acc: 84.38%,  Val Loss:  0.31,  Val Acc: 90.46%,  Time: 0:05:27 *
 28%|██████████████████▍                                              | 400/1407 [10:18<24:29,  1.46s/it]Iter:    400,  Train Loss:  0.43,  Train Acc: 85.16%,  Val Loss:  0.27,  Val Acc: 91.31%,  Time: 0:11:06 *
 43%|███████████████████████████▋                                     | 600/1407 [15:57<19:40,  1.46s/it]Iter:    600,  Train Loss:  0.23,  Train Acc: 91.41%,  Val Loss:  0.22,  Val Acc: 92.40%,  Time: 0:16:45 *
 57%|████████████████████████████████████▉                            | 800/1407 [21:36<14:50,  1.47s/it]Iter:    800,  Train Loss:  0.12,  Train Acc: 92.19%,  Val Loss:  0.22,  Val Acc: 92.78%,  Time: 0:22:24 *
 71%|█████████████████████████████████████████████▍                  | 1000/1407 [27:14<09:55,  1.46s/it]Iter:   1000,  Train Loss:  0.18,  Train Acc: 92.97%,  Val Loss:  0.24,  Val Acc: 92.25%,  Time: 0:27:55 
No optimization for a long time, auto-stopping...
Test Loss:  0.21,  Test Acc: 93.35%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.9406    0.9190    0.9297      1000
       realty     0.9595    0.9230    0.9409      1000
       stocks     0.8904    0.8940    0.8922      1000
    education     0.9593    0.9660    0.9626      1000
      science     0.9086    0.9050    0.9068      1000
      society     0.9277    0.9360    0.9318      1000
     politics     0.9114    0.9160    0.9137      1000
       sports     0.9675    0.9810    0.9742      1000
         game     0.9474    0.9360    0.9416      1000
entertainment     0.9239    0.9590    0.9411      1000

     accuracy                         0.9335     10000
    macro avg     0.9336    0.9335    0.9335     10000
 weighted avg     0.9336    0.9335    0.9335     10000

Confusion Matrix...
[[919  10  40   3   8   3  10   3   1   3]
 [ 10 923   9   2   5  15   9   3   9  15]
 [ 40  16 894   1  22   0  22   1   1   3]
 [  2   1   3 966   0   9  10   1   1   7]
 [  1   2  22   7 905  13  13   2  23  12]
 [  1   4   3  14   5 936  19   2   3  13]
 [  3   3  25  11   9  26 916   1   1   5]
 [  0   3   3   0   0   2   1 981   1   9]
 [  1   0   3   0  37   4   3   4 936  12]
 [  0   0   2   3   5   1   2  16  12 959]]
Time usage: 0:00:40

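  • 结论: 将预训练模型替换为XLNet后, 模型在测试集上的准确率为93.35% (macro avg F1-score 同为0.9335), 相比于其他几个专门处理分类任务的模型来说并不突出. 单一分类任务的结果并不足以直接证明XLNet的适用范围, 但一种合理的解释是: XLNet的设计(如记忆机制mems与相对位置编码)更适合处理超长文本, 其更优秀的表现也可能体现在生成式任务上, 而在本项目这类分类任务中优势难以体现.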


小节总结

  • 本小节学习了XLNet的核心源代码.

  • 本小节学习了XLNet模型在多分类任务中的应用, 并完成了实际代码的编写和测试.