跳转至

5 张量索引操作

学习目标

  • 掌握张量不同索引操作

我们在操作张量时,经常需要去进行获取或者修改操作,掌握张量的花式索引操作是必须的一项能力。

1. 简单行、列索引

准备数据

import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

程序输出结果:

tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
--------------------------------------------------
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

程序输出结果:

tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------

2. 列表索引

# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

程序输出结果:

tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

3. 范围索引

# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

程序输出结果:

tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

4. 布尔索引

# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

程序输出结果:

tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

5. 多维索引

# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

程序输出结果:

tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])