跳转至

7 张量运算函数

学习目标

  • 掌握张量相关运算函数

1. 常见运算函数

PyTorch 为每个张量封装很多实用的计算函数,例如计算均值、平方根、求和等等

import torch


def test():

    data = torch.randint(0, 10, [2, 3], dtype=torch.float64)
    print(data)
    print('-' * 50)

    # 1. 计算均值
    # 注意: tensor 必须为 Float 或者 Double 类型
    print(data.mean())
    print(data.mean(dim=0))  # 按列计算均值
    print(data.mean(dim=1))  # 按行计算均值
    print('-' * 50)

    # 2. 计算总和
    print(data.sum())
    print(data.sum(dim=0))
    print(data.sum(dim=1))
    print('-' * 50)

    # 3. 计算平方
    print(data.pow(2))
    print('-' * 50)

    # 4. 计算平方根
    print(data.sqrt())
    print('-' * 50)

    # 5. 指数计算, e^n 次方
    print(data.exp())
    print('-' * 50)

    # 6. 对数计算
    print(data.log())  # 以 e 为底
    print(data.log2())
    print(data.log10())


if __name__ == '__main__':
    test()

程序运行结果:

tensor([[4., 0., 7.],
        [6., 3., 5.]], dtype=torch.float64)
--------------------------------------------------
tensor(4.1667, dtype=torch.float64)
tensor([5.0000, 1.5000, 6.0000], dtype=torch.float64)
tensor([3.6667, 4.6667], dtype=torch.float64)
--------------------------------------------------
tensor(25., dtype=torch.float64)
tensor([10.,  3., 12.], dtype=torch.float64)
tensor([11., 14.], dtype=torch.float64)
--------------------------------------------------
tensor([[16.,  0., 49.],
        [36.,  9., 25.]], dtype=torch.float64)
--------------------------------------------------
tensor([[2.0000, 0.0000, 2.6458],
        [2.4495, 1.7321, 2.2361]], dtype=torch.float64)
--------------------------------------------------
tensor([[5.4598e+01, 1.0000e+00, 1.0966e+03],
        [4.0343e+02, 2.0086e+01, 1.4841e+02]], dtype=torch.float64)
--------------------------------------------------
tensor([[1.3863,   -inf, 1.9459],
        [1.7918, 1.0986, 1.6094]], dtype=torch.float64)
tensor([[2.0000,   -inf, 2.8074],
        [2.5850, 1.5850, 2.3219]], dtype=torch.float64)
tensor([[0.6021,   -inf, 0.8451],
        [0.7782, 0.4771, 0.6990]], dtype=torch.float64)