PyTorch2 Python深度学习 - 张量(Tensor)的定义与操作

锋哥原创的PyTorch2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1eqxNzXEYc

课程介绍

​基于前面的机器学习Scikit-learn,深度学习Tensorflow2课程,我们继续讲解深度学习PyTorch2,所以有些机器学习,深度学习基本概念就不再重复讲解,大家务必学习好前面两个课程。本课程主要讲解基于PyTorch2的深度学习核心知识,主要讲解包括PyTorch2框架入门知识,环境搭建,张量,自动微分,数据加载与预处理,模型训练与优化,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

PyTorch2 Python深度学习 - 张量(Tensor)的定义与操作

张量是一种特殊的数据结构,与数组和矩阵非常相似。在 PyTorch2 中,我们使用张量来定义模型的输入和输出,以及模型的参数。

1,张量的定义

张量是一个具有相同数据类型的元素的多维矩阵。张量的维度被称为"秩"(rank),张量的形状(shape)决定了它的维度。例如:

  • 标量:零维张量(0D)

  • 向量:一维张量(1D)

  • 矩阵:二维张量(2D)

  • 多维数组:三维或更高维的张量(3D+)

2,创建张量

可以通过多种方式来创建张量,常见的有从列表创建,使用 torch.zeros 创建全零张量,使用 torch.ones 创建全一张量,使用 torch.rand 创建随机张量,从 NumPy 数组转换,下面是示例:

复制代码
import torch
​
# 从 Python 列表创建一个一维张量
tensor_1d = torch.tensor([1, 2, 3, 4, 5])
print(tensor_1d)  # 输出:tensor([1, 2, 3, 4, 5])
​
# 从嵌套列表创建二维张量
tensor_2d = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(tensor_2d)  # 输出:tensor([[1, 2], [3, 4], [5, 6]])
​
# 创建一个形状为 (2, 3) 的全零张量
tensor_zeros = torch.zeros(2, 3)
print(tensor_zeros)  # 输出:tensor([[0., 0., 0.], [0., 0., 0.]])
​
# 创建一个形状为 (3, 3) 的全一张量
tensor_ones = torch.ones(3, 3)
print(tensor_ones)  # 输出:tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]])
​
# 创建一个形状为 (2, 2) 的随机张量,元素值在 [0, 1) 范围内  torch.randn() 是 PyTorch 中用于生成服从标准正态分布(均值为 0,标准差为 1)的随机张量的函数。
tensor_rand = torch.rand(2, 2)
print(tensor_rand)  # 输出:tensor([[0.1353, 0.7184], [0.5225, 0.8931]])
​
import numpy as np
​
# 创建一个 NumPy 数组
np_array = np.array([1, 2, 3])
​
# 将 NumPy 数组转换为 PyTorch 张量
tensor_from_numpy = torch.tensor(np_array)
print(tensor_from_numpy)  # 输出:tensor([1, 2, 3])

运行输出:

复制代码
tensor([1, 2, 3, 4, 5])
tensor([[1, 2],
        [3, 4],
        [5, 6]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[0.5516, 0.5451],
        [0.2829, 0.1066]])
tensor([1, 2, 3])
3,张量的常见操作

张量的常见操作有张量加法,张量乘法,张量维度变换,张量转置,张量拼接,张量切片等。下面是示例:

复制代码
import torch
​
# 两个张量相加
tensor_a = torch.tensor([1, 2, 3])
tensor_b = torch.tensor([4, 5, 6])
sum_tensor = tensor_a + tensor_b
print(sum_tensor)  # 输出:tensor([5, 7, 9])
​
# 张量元素级别的乘法
product_tensor = tensor_a * tensor_b
print(product_tensor)  # 输出:tensor([4, 10, 18])
​
# 矩阵乘法
matrix_a = torch.tensor([[1, 2], [3, 4]])
matrix_b = torch.tensor([[5, 6], [7, 8]])
matrix_product = torch.matmul(matrix_a, matrix_b)
print(matrix_product)
​
# 改变张量的形状(例如,1D → 2D)
tensor_1d = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped_tensor = tensor_1d.view(2, 3)  # 变换为2行3列的矩阵
print(reshaped_tensor)  # 输出:tensor([[1, 2, 3], [4, 5, 6]])
​
# 对一个二维张量进行转置
tensor_2d = torch.tensor([[1, 2], [3, 4], [5, 6]])
transposed_tensor = tensor_2d.T
print(transposed_tensor)  # 输出:tensor([[1, 3, 5], [2, 4, 6]])
​
# 按维度拼接两个张量
tensor_a = torch.tensor([1, 2])
tensor_b = torch.tensor([3, 4])
concatenated_tensor = torch.cat((tensor_a, tensor_b), dim=0)
print(concatenated_tensor)  # 输出:tensor([1, 2, 3, 4])
​
# 获取张量的切片
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(tensor)
slice_tensor = tensor[1:, 1:]
print(slice_tensor)  # 输出:tensor([[5, 6], [8, 9]])

运行输出:

复制代码
tensor([5, 7, 9])
tensor([ 4, 10, 18])
tensor([[19, 22],
        [43, 50]])
tensor([[1, 2, 3],
        [4, 5, 6]])
tensor([[1, 3, 5],
        [2, 4, 6]])
tensor([1, 2, 3, 4])
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[5, 6],
        [8, 9]])
4,张量的常见属性
  • tensor.shape: 返回张量的形状(即维度)

  • tensor.size(): 返回张量的尺寸,和 .shape 类似

  • tensor.device: 返回张量所在的设备(CPU 或 GPU)

  • tensor.dtype: 返回张量的数据类型

示例:

复制代码
import torch
​
tensor_2d = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
print(tensor_2d.shape)  # 输出:torch.Size([2, 2])
print(tensor_2d.device)  # 输出:cpu
print(tensor_2d.dtype)  # 输出:torch.float32

运行输出:

复制代码
torch.Size([2, 2])
cpu
torch.float32
相关推荐
凡人叶枫1 分钟前
Effective C++ 条款41:了解隐式接口和编译期多态
java·开发语言·c++·effective c++
AC赳赳老秦15 分钟前
用 OpenClaw 搭建服务器故障应急响应系统,自动处理 80% 常见运维故障
android·运维·服务器·python·rxjava·deepseek·openclaw
2601_9547064921 分钟前
云手机技术详解+Python实战调用|2026高稳云手机平台推荐
开发语言·python·智能手机
chushiyunen22 分钟前
java中的路径处理、左右斜杠
java·开发语言·python
jay神1 小时前
基于 FastAPI + Vue 的宠物领养管理系统
前端·vue.js·python·毕业设计·fastapi·宠物
重生之后端学习1 小时前
Java入门
java·开发语言·职场和发展
碧海蓝天20221 小时前
C++法则24:在标准 C++ 中,没有任何可移植的方式判断指针 T* pt 指向的内存位置是否已经 构造了对象,程序员必须手动跟踪哪些元素已构造。
java·开发语言·c++
β添砖java1 小时前
深度学习(22)网络中的网络NiN
人工智能·深度学习
代码不加糖1 小时前
Proxy能够监听到对象中的对象的引用吗?
开发语言·前端·javascript
charlie1145141911 小时前
现代C++指南:Lambda,让我们用另一种方式持有函数
开发语言·c++