5 张量索引操作¶
学习目标¶
- 掌握张量不同索引操作
我们在操作张量时,经常需要去进行获取或者修改操作,掌握张量的花式索引操作是必须的一项能力。
1. 简单行、列索引¶
准备数据
import torch
data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)
程序输出结果:
tensor([[0, 7, 6, 5, 9],
[6, 8, 3, 1, 0],
[6, 3, 8, 7, 3],
[4, 9, 5, 3, 1]])
--------------------------------------------------
# 1. 简单行、列索引
def test01():
print(data[0])
print(data[:, 0])
print('-' * 50)
if __name__ == '__main__':
test01()
程序输出结果:
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------
2. 列表索引¶
# 2. 列表索引
def test02():
# 返回 (0, 1)、(1, 2) 两个位置的元素
print(data[[0, 1], [1, 2]])
print('-' * 50)
# 返回 0、1 行的 1、2 列共4个元素
print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
test02()
程序输出结果:
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
[8, 3]])
3. 范围索引¶
# 3. 范围索引
def test03():
# 前3行的前2列数据
print(data[:3, :2])
# 第2行到最后的前2列数据
print(data[2:, :2])
if __name__ == '__main__':
test03()
程序输出结果:
tensor([[0, 7],
[6, 8],
[6, 3]])
tensor([[6, 3],
[4, 9]])
4. 布尔索引¶
# 布尔索引
def test():
# 第三列大于5的行数据
print(data[data[:, 2] > 5])
# 第二行大于5的列数据
print(data[:, data[1] > 5])
if __name__ == '__main__':
test04()
程序输出结果:
tensor([[0, 7, 6, 5, 9],
[6, 3, 8, 7, 3]])
tensor([[0, 7],
[6, 8],
[6, 3],
[4, 9]])
5. 多维索引¶
# 多维索引
def test05():
data = torch.randint(0, 10, [3, 4, 5])
print(data)
print('-' * 50)
print(data[0, :, :])
print(data[:, 0, :])
print(data[:, :, 0])
if __name__ == '__main__':
test05()
程序输出结果:
tensor([[[2, 4, 1, 2, 3],
[5, 5, 1, 5, 0],
[1, 4, 5, 3, 8],
[7, 1, 1, 9, 9]],
[[9, 7, 5, 3, 1],
[8, 8, 6, 0, 1],
[6, 9, 0, 2, 1],
[9, 7, 0, 4, 0]],
[[0, 7, 3, 5, 6],
[2, 4, 6, 4, 3],
[2, 0, 3, 7, 9],
[9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
[5, 5, 1, 5, 0],
[1, 4, 5, 3, 8],
[7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
[9, 7, 5, 3, 1],
[0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
[9, 8, 6, 9],
[0, 2, 2, 9]])