使用 PyTorch Tensor 的相关数据处理

目录

一、前言

二、Tensor是什么

三、Tensor的创建方式

[3.1 直接创建](#3.1 直接创建)

[3.2 全0 / 全1](#3.2 全0 / 全1)

[3.3 随机数](#3.3 随机数)

[3.4 指定范围](#3.4 指定范围)

[3.5 等间隔数值](#3.5 等间隔数值)

四、Tensor索引与切片

[4.1 基础索引](#4.1 基础索引)

[4.2 切片操作](#4.2 切片操作)

[4.3 子矩阵提取](#4.3 子矩阵提取)

[4.4 条件筛选](#4.4 条件筛选)

五、Tensor形状变换(非常重要)

[5.1 查看形状](#5.1 查看形状)

[5.2 reshape / view](#5.2 reshape / view)

[5.3 flatten(展平)](#5.3 flatten(展平))

[5.4 transpose(转置)](#5.4 transpose(转置))

[5.5 permute(维度重排)](#5.5 permute(维度重排))

六、Tensor拼接与拆分

[6.1 concat拼接](#6.1 concat拼接)

[6.2 stack堆叠](#6.2 stack堆叠)

[6.3 split拆分](#6.3 split拆分)

[6.4 chunk平均切分](#6.4 chunk平均切分)

七、Tensor广播机制(非常重要)

[7.1 什么是广播](#7.1 什么是广播)

[7.2 示例](#7.2 示例)

[7.3 广播规则](#7.3 广播规则)

八、Tensor数学运算

[8.1 基础运算](#8.1 基础运算)

[8.2 矩阵乘法](#8.2 矩阵乘法)

[8.3 点积](#8.3 点积)

[8.4 求和 / 均值](#8.4 求和 / 均值)

[8.5 最大最小值](#8.5 最大最小值)

九、Tensor类型转换

[9.1 dtype转换](#9.1 dtype转换)

[9.2 numpy互转](#9.2 numpy互转)

[9.3 item提取标量](#9.3 item提取标量)

十、GPU加速操作

[10.1 检查GPU](#10.1 检查GPU)

[10.2 放到GPU](#10.2 放到GPU)

[10.3 模型与数据统一设备](#10.3 模型与数据统一设备)

十一、Tensor梯度相关操作

[11.1 开启梯度](#11.1 开启梯度)

[11.2 计算梯度](#11.2 计算梯度)

[11.3 查看梯度](#11.3 查看梯度)

[11.4 禁用梯度](#11.4 禁用梯度)

十二、常见Tensor实战技巧

[12.1 扩展维度](#12.1 扩展维度)

[12.2 删除维度](#12.2 删除维度)

[12.3 归一化](#12.3 归一化)

[12.4 one-hot编码](#12.4 one-hot编码)

[12.5 mask操作](#12.5 mask操作)

十三、Tensor在深度学习中的作用

十四、完整示例:数据处理流水线

十五、总结


在深度学习中,几乎所有计算都离不开一个核心数据结构:

复制代码
Tensor(张量)

无论是:

复制代码
图像数据

文本向量

音频信号

模型参数

本质上都可以表示为 Tensor。

可以简单理解为:

复制代码
Tensor = 多维数组 + GPU加速能力 + 自动求导支持

本文将系统讲解 PyTorch 中 Tensor 的常见数据处理方法,包括:

复制代码
Tensor创建

索引与切片

形状变换

拼接与拆分

广播机制

数值计算

类型转换

GPU操作

常见实战技巧

二、Tensor是什么

Tensor 是 PyTorch 的核心数据结构:

复制代码
0维:标量(scalar)
1维:向量(vector)
2维:矩阵(matrix)
3维及以上:高维张量

示例:

python 复制代码
import torch

a = torch.tensor(3)          # 0维
b = torch.tensor([1,2,3])    # 1维
c = torch.tensor([[1,2],[3,4]])  # 2维

三、Tensor的创建方式

3.1 直接创建

复制代码
t = torch.tensor([[1,2,3],[4,5,6]])

3.2 全0 / 全1

python 复制代码
torch.zeros(3,3)

torch.ones(2,4)

3.3 随机数

复制代码
torch.rand(3,3)   # 0~1均匀分布

torch.randn(3,3)  # 正态分布

3.4 指定范围

python 复制代码
torch.arange(0, 10, 2)

输出:

复制代码
[0, 2, 4, 6, 8]

3.5 等间隔数值

复制代码
torch.linspace(0, 1, 5)

四、Tensor索引与切片

Tensor支持类似 NumPy 的操作。

4.1 基础索引

python 复制代码
x = torch.tensor([[1,2,3],
                  [4,5,6]])

print(x[0,1])  # 2

4.2 切片操作

复制代码
x[:, 1]

表示:

复制代码
取所有行,第2列

4.3 子矩阵提取

复制代码
x[0:2, 1:3]

4.4 条件筛选

复制代码
x[x > 3]

输出:

复制代码
[4,5,6]

五、Tensor形状变换(非常重要)

5.1 查看形状

复制代码
x.shape

5.2 reshape / view

python 复制代码
x = torch.arange(12)

x.reshape(3,4)

或:

复制代码
x.view(3,4)

5.3 flatten(展平)

复制代码
x.flatten()

5.4 transpose(转置)

复制代码
x.T

或:

复制代码
x.transpose(0,1)

5.5 permute(维度重排)

图像数据常用:

python 复制代码
x = torch.randn(1,3,224,224)

x = x.permute(0,2,3,1)

变为:

复制代码
NHWC格式

六、Tensor拼接与拆分

6.1 concat拼接

复制代码
torch.cat([a,b], dim=0)

6.2 stack堆叠

复制代码
torch.stack([a,b], dim=0)

区别:

复制代码
cat:维度不变

stack:增加新维度

6.3 split拆分

复制代码
torch.split(x, 2, dim=1)

6.4 chunk平均切分

复制代码
torch.chunk(x, 3, dim=0)

七、Tensor广播机制(非常重要)

7.1 什么是广播

不同形状Tensor可以自动计算:

复制代码
小维度自动扩展

7.2 示例

复制代码
a = torch.tensor([[1,2,3],
                  [4,5,6]])

b = torch.tensor([1,2,3])

计算:

复制代码
a + b

结果:

复制代码
每一行都加b

7.3 广播规则

复制代码
维度对齐
小维度自动补1

八、Tensor数学运算

8.1 基础运算

复制代码
a + b
a - b
a * b
a / b

8.2 矩阵乘法

复制代码
torch.matmul(a, b)

或:

复制代码
a @ b

8.3 点积

复制代码
torch.dot(a, b)

8.4 求和 / 均值

复制代码
x.sum()

x.mean()

8.5 最大最小值

复制代码
x.max()

x.min()

九、Tensor类型转换

9.1 dtype转换

复制代码
x.float()

x.int()

9.2 numpy互转

python 复制代码
import numpy as np

a = x.numpy()

x = torch.from_numpy(a)

9.3 item提取标量

复制代码
x.item()

十、GPU加速操作

10.1 检查GPU

复制代码
torch.cuda.is_available()

10.2 放到GPU

python 复制代码
device = torch.device("cuda")

x = x.to(device)

10.3 模型与数据统一设备

复制代码
model = model.to(device)

十一、Tensor梯度相关操作

11.1 开启梯度

python 复制代码
x = torch.tensor([1.0,2.0], requires_grad=True)

11.2 计算梯度

python 复制代码
y = x.sum()

y.backward()

11.3 查看梯度

复制代码
x.grad

11.4 禁用梯度

python 复制代码
with torch.no_grad():
    y = x * 2

十二、常见Tensor实战技巧


12.1 扩展维度

复制代码
x.unsqueeze(0)

用于:

复制代码
增加batch维度

12.2 删除维度

python 复制代码
x.squeeze()

12.3 归一化

python 复制代码
x = (x - x.mean()) / x.std()

12.4 one-hot编码

python 复制代码
torch.nn.functional.one_hot(
    torch.tensor([0,1,2]),
    num_classes=3
)

12.5 mask操作

复制代码
x[x > 0] = 1

十三、Tensor在深度学习中的作用

Tensor贯穿整个训练流程:

复制代码
数据输入 → Tensor
特征提取 → Tensor
卷积运算 → Tensor
损失计算 → Tensor
反向传播 → Tensor
参数更新 → Tensor

可以说:

复制代码
Tensor = 深度学习的"基础语言"

十四、完整示例:数据处理流水线

python 复制代码
import torch

x = torch.randn(4,3,32,32)

# 归一化
x = (x - x.mean()) / x.std()

# 改变形状
x = x.view(4, -1)

# 拼接
x2 = torch.randn(4,96)
x = torch.cat([x, x2], dim=1)

# GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)

print(x.shape)

十五、总结

PyTorch Tensor 是深度学习的核心数据结构,其能力可以概括为:

复制代码
多维表达能力
高效计算能力
GPU加速能力
自动求导能力

常见数据处理能力包括:

复制代码
创建 Tensor
索引与切片
形状变换
拼接与拆分
广播计算
数学运算
类型转换
GPU操作
梯度计算

掌握 Tensor 操作,就等于掌握了 PyTorch 的"基础语法体系"。

可以说:

深度学习不是在操作模型,而是在操作 Tensor;模型只是 Tensor 计算规则的组织方式。

相关推荐
Alluxio1 小时前
Alluxio AI 3.9 正式发布:为任意 AI 训练框架提供 checkpoint 加速能力
人工智能
如烟花的信页1 小时前
易盾点选逆向分析
javascript·爬虫·python·js逆向
诺云小星1 小时前
GEO时代已开启:品牌如何获得AI推荐?
人工智能
youcans_1 小时前
【跟我学 AI 编程】(6) Claude Code 与 IDE 的集成
ide·人工智能·ai编程·claude code
X54先生(人文科技)2 小时前
《元创力》纪实录·桥段陶罐的测绘:当“表演性安全”吞噬星辰
人工智能·开源·开源协议·零知识证明
czzxxxxxx2 小时前
创客匠人AI智能体:知识付费的效率革命与未来图景
人工智能
OpenCSG2 小时前
Cosmos3:NVIDIA 把世界模型做成了“理解、生成、模拟、行动”的统一入口
人工智能·大模型·nvidia·opencsg
IvorySQL2 小时前
PostgreSQL 技术日报 (6月1日)|逻辑复制问题修复,AI 行业动态速览
数据库·人工智能·postgresql
金銀銅鐵2 小时前
用 Tkinter 实现一个简单的罗马数字转化工具
后端·python