Tensor 基本操作1 | PyTorch 深度学习实战

目录

创建 Tensor

使用 Torch 接口创建 Tensor

import torch

参考:https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html

常用操作

unsqueeze

将多维数组解套,并嵌入新的一层维度。

    data = [[1, 2],[3, 4]]
    x_data = torch.tensor(data)
    print("x_data")
    print(x_data)

    x2_data = x_data.unsqueeze(-1)
    print("x_data>> unsqueeze -1")
    print(x2_data)

    x2_data = x_data.unsqueeze(0)
    print("x_data>> unsqueeze 0")
    print(x2_data)

    x2_data = x_data.unsqueeze(1)
    print("x_data>> unsqueeze 1")
    print(x2_data)

    x2_data = x_data.unsqueeze(2)
    print("x_data>> unsqueeze 2")
    print(x2_data)

结果:

x_data
tensor([[1, 2],
        [3, 4]])
x_data>> unsqueeze -1   # -1 代表最内层,将最内层的数用一个新的维度包起来
tensor([[[1],
         [2]],

        [[3],
         [4]]])
x_data>> unsqueeze 0 # 0 代表最外层,将原来的多维数组整个多套一层
tensor([[[1, 2],
         [3, 4]]])
x_data>> unsqueeze 1 # 代表原来第一维里的每个元素,套一层
tensor([[[1, 2]],

        [[3, 4]]])
x_data>> unsqueeze 2 # 代表原来第二维里的每个元素,套一层
tensor([[[1],        # 当前一共两维,所以效果和 -1 一样
         [2]],

        [[3],
         [4]]])

squeeze

去掉指定或全部的维度中只有一个元素的多维数组。

比如输入为 Ax1xBxCx1xD 维的数组,输出变成了 AxBxCxD 维的数组。

https://pytorch.org/docs/stable/generated/torch.squeeze.html

    data = [[1], [2],[3], [4]]
    x_data = torch.tensor(data)
    print("x_data")
    print(x_data)

    x2_data = x_data.squeeze()
    print("x_data>> squeeze")
    print(x2_data)

    x2_data = x_data.squeeze(1)
    print("x_data>> squeeze 1")
    print(x2_data)

结果:

x_data
tensor([[1],
        [2],
        [3],
        [4]])
x_data>> squeeze
tensor([1, 2, 3, 4])
x_data>> squeeze 1
tensor([1, 2, 3, 4])

Softmax

https://pytorch.org/docs/stable/generated/torch.softmax.html

归一化操作。

代码1
    data = torch.tensor([1,2,3], dtype=torch.float) # 维度 3; 注意,此处 dtype 是 int 或 long 接口报错
    x_data = torch.softmax(data, 0)
    print("x_data")
    print(x_data)

结果:

x_data
tensor([0.0900, 0.2447, 0.6652])  # 维度 3
代码2
    data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 维度 3x1
    x_data2 = torch.softmax(data, 0)
    print("x_data2")
    print(x_data2)

结果:

x_data2  # 维度 3x1
tensor([[0.0900],
        [0.2447],
        [0.6652]])
代码3
    data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 维度 3x1
    x_data2 = torch.softmax(data, 1) # 沿着第一维求
    print("x_data2")
    print(x_data2)

结果:

x_data2
tensor([[1.],
        [1.],
        [1.]])

此时,每维都是 1 个元素,针对自身求 softmax,所以,结果是 1.

argmax

https://pytorch.org/docs/stable/generated/torch.argmax.html

返回一个多维数组的最大值的索引,如果是多维数组,则返回第一维的索引。

item

https://pytorch.org/docs/stable/generated/torch.Tensor.item.html

返回一个 Tensor 中携带的 Python Number 对象。该接口只对 Tensor 是一维的有效。

x = torch.tensor([1.0])
x.item()
相关推荐
XianxinMao6 分钟前
AI发展困境:技术路径与实践约束的博弈
人工智能·语言模型
池央6 分钟前
WGAN - 瓦萨斯坦生成对抗网络
人工智能·神经网络·生成对抗网络
WPG大大通10 分钟前
Pytorch - YOLOv11自定义资料训练
人工智能·机器学习·计算机视觉·视觉检测·大大通
LDG_AGI10 分钟前
【2024 CSDN博客之星】技术洞察类:从DeepSeek-V3的成功,看MoE混合专家网络对深度学习算法领域的影响(MoE代码级实战)
人工智能·深度学习
点云SLAM27 分钟前
CVPR 2024 图像、视频处理总汇(视频字幕、图像超分辨率、图像分类和压缩等)
图像处理·深度学习·计算机视觉·视频处理·3dgs·cvpr2024
OceanBase数据库官方博客35 分钟前
阳振坤:AI 大模型的基础是数据,AI越发达,数据库价值越大
数据库·人工智能·oceanbase·分布式数据库
正在走向自律38 分钟前
Text2Sql:开启自然语言与数据库交互新时代(30/30)
数据库·人工智能·oracle·text2sql·ai智能体
深图智能1 小时前
PyTorch使用教程(11)-cuda的使用方法
人工智能·pytorch·python·深度学习
唯余木叶下弦声1 小时前
Python人脸识别库DeepFace使用教程及源码解析
开发语言·人工智能·python
小胖学前端2 小时前
AIHawk:AI驱动的自动化求职助手,帮你轻松找到理想工作
人工智能