跳转至

4 张量拼接操作

学习目标

  • 掌握torch.cat torch.stack使用

张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在后面将要学习到的残差网络、注意力机制中都使用到了张量拼接。

1. torch.cat 函数的使用

torch.cat 函数可以将两个张量根据指定的维度拼接起来.

import torch


def test():

    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])

    print(data1)
    print(data2)
    print('-' * 50)

    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data.shape)
    print('-' * 50)

    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data.shape)

    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

程序输出结果:

tensor([[[6, 8, 3, 5],
         [1, 1, 3, 8],
         [9, 0, 4, 4],
         [1, 4, 7, 0],
         [5, 1, 4, 8]],

        [[0, 1, 4, 4],
         [4, 1, 8, 7],
         [5, 2, 6, 6],
         [2, 6, 1, 6],
         [0, 7, 8, 9]],

        [[0, 6, 8, 8],
         [5, 4, 5, 8],
         [3, 5, 5, 9],
         [3, 5, 2, 4],
         [3, 8, 1, 1]]])
tensor([[[4, 6, 8, 1],
         [0, 1, 8, 2],
         [4, 9, 9, 8],
         [5, 1, 5, 9],
         [9, 4, 3, 0]],

        [[7, 6, 3, 3],
         [4, 3, 3, 2],
         [2, 1, 1, 1],
         [3, 0, 8, 2],
         [8, 6, 6, 5]],

        [[0, 7, 2, 4],
         [4, 3, 8, 3],
         [4, 2, 1, 9],
         [4, 2, 8, 9],
         [3, 7, 0, 8]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
tensor([[[6, 8, 3, 5, 4, 6, 8, 1],
         [1, 1, 3, 8, 0, 1, 8, 2],
         [9, 0, 4, 4, 4, 9, 9, 8],
         [1, 4, 7, 0, 5, 1, 5, 9],
         [5, 1, 4, 8, 9, 4, 3, 0]],

        [[0, 1, 4, 4, 7, 6, 3, 3],
         [4, 1, 8, 7, 4, 3, 3, 2],
         [5, 2, 6, 6, 2, 1, 1, 1],
         [2, 6, 1, 6, 3, 0, 8, 2],
         [0, 7, 8, 9, 8, 6, 6, 5]],

        [[0, 6, 8, 8, 0, 7, 2, 4],
         [5, 4, 5, 8, 4, 3, 8, 3],
         [3, 5, 5, 9, 4, 2, 1, 9],
         [3, 5, 2, 4, 4, 2, 8, 9],
         [3, 8, 1, 1, 3, 7, 0, 8]]])

2. torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

import torch


def test():

    data1= torch.randint(0, 10, [2, 3])
    data2= torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)

    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

程序输出结果:

tensor([[5, 8, 7],
        [6, 0, 6]])
tensor([[5, 8, 0],
        [9, 0, 1]])
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])
tensor([[[5, 5],
         [8, 8],
         [7, 0]],

        [[6, 9],
         [0, 0],
         [6, 1]]])

3. 小节

张量的拼接操作也是在后面我们经常使用一种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。