跳转至

3 张量类型转换

学习目标

  • 掌握张量类型转换方法

张量的类型转换也是经常使用的一种操作,是必须掌握的知识点。在本小节,我们主要学习如何将 numpy 数组和 PyTorch Tensor 的转化方法.

1. 张量转换为 numpy 数组

使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。

# 1. 将张量转换为 numpy 数组
def test01():

    data_tensor = torch.tensor([2, 3, 4])
    # 使用张量对象中的 numpy 函数进行转换
    data_numpy = data_tensor.numpy()

    print(type(data_tensor))
    print(type(data_numpy))

    # 注意: data_tensor 和 data_numpy 共享内存
    # 修改其中的一个,另外一个也会发生改变
    # data_tensor[0] = 100
    data_numpy[0] = 100

    print(data_tensor)
    print(data_numpy)


# 2. 对象拷贝避免共享内存
def test02():

    data_tensor = torch.tensor([2, 3, 4])
    # 使用张量对象中的 numpy 函数进行转换
    data_numpy = data_tensor.numpy()

    print(type(data_tensor))
    print(type(data_numpy))

    # 注意: data_tensor 和 data_numpy 共享内存
    # 修改其中的一个,另外一个也会发生改变
    # data_tensor[0] = 100
    data_numpy[0] = 100

    print(data_tensor)
    print(data_numpy)

2. numpy 转换为张量

  1. 使用 from_numpy 可以将 ndarray 数组转换为 Tensor,默认共享内存,使用 copy 函数避免共享。
  2. 使用 torch.tensor 可以将 ndarray 数组转换为 Tensor,默认不共享内存。
# 1. 使用 from_numpy 函数
def test01():

    data_numpy = np.array([2, 3, 4])
    # 将 numpy 数组转换为张量类型
    # 1. from_numpy
    # 2. torch.tensor(ndarray)

    # 浅拷贝
    data_tensor = torch.from_numpy(data_numpy)

    # nunpy 和 tensor 共享内存
    # data_numpy[0] = 100
    data_tensor[0] = 100

    print(data_tensor)
    print(data_numpy)


# 2. 使用 torch.tensor 函数
def test02():

    data_numpy = np.array([2, 3, 4])

    data_tensor = torch.tensor(data_numpy)

    # nunpy 和 tensor 不共享内存
    # data_numpy[0] = 100
    data_tensor[0] = 100

    print(data_tensor)
    print(data_numpy)

3. 标量张量和数字的转换

对于只有一个元素的张量,使用 item 方法将该值从张量中提取出来。

# 3. 标量张量和数字的转换
def test03():

    # 当张量只包含一个元素时, 可以通过 item 函数提取出该值
    data = torch.tensor([30,])
    print(data.item())

    data = torch.tensor(30)
    print(data.item())


if __name__ == '__main__':
    test03()

程序输出结果:

30
30

小节

在本小节中, 我们主要学习了 numpy 和 tensor 互相转换的规则, 以及标量张量与数值之间的转换规则。