小土堆pytorch

Dataset的使用

  • torch.utils.data.Dataset :数据集的抽象类,需要自定义并实现 __len__(数据集大小)和 __getitem__(按索引获取样本)。

tensorboard 主要是训练过程可视化

transform 图像转化

图片->tensor

python 复制代码
trans = transforms.ToTensor()

用法:首先要创建一个自己的工具 然后再使用

即:先定义再调用

torchvision

dataloader用于从 Dataset 中按批次(batch)加载数据

  • batch_size: 每次加载的样本数量。
  • shuffle: 是否对数据进行洗牌,通常训练时需要将数据打乱。
  • drop_last : 如果数据集中的样本数不能被 batch_size 整除,设置为 True 时,丢弃最后一个不完整的 batch

torch.nn

super是为了告诉 Python "先执行父类的初始化代码,然后再执行我的初始化代码"

卷积操作:使用卷积核(Kernel)在输入图像上滑动,提取特征,生成特征图(Feature Maps)。

一个矩阵+卷机核 通过相乘然后累加

stride就是步长 padding是填充

关于in_channel和out_channel

池化:通常在卷积层之后,通过最大池化或平均池化减少特征图的尺寸,同时保留重要特征,生成池化特征图(Pooled Feature Maps)。

sequential 方便建立层级结构

损失函数与反向传播

损失函数衡量目标和输出的差距 反向传播通过计算梯度 来更新参数实现loss最小化

模型保存

训练套路:准备数据、加载数据、准备模型、设置损失函数、设置优化器、开始训练、最后验证、结果聚合展示

Dataset的使用

torch.utils.data.Dataset是PyTorch中用于表示数据集的抽象类,自定义数据集需继承该类并实现__len____getitem__方法。

TensorBoard训练可视化

TensorBoard通过记录训练过程中的标量(如损失、准确率)、图像、计算图等实现可视化。

复制代码
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")  # 日志目录
writer.add_scalar("Loss/train", loss.item(), epoch)  # 记录损失
writer.close()

图像转换(Transform)

torchvision.transforms提供常见的图像预处理操作,需将图像转换为张量(Tensor)并归一化:

复制代码
from torchvision import transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL图像或NumPy数组 -> Tensor
    transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化
])

DataLoader加载数据

torch.utils.data.DataLoader实现批量加载和数据打乱:

  • batch_size:每批次样本数。

  • shuffle:训练时通常设为True以打乱数据。

  • drop_last:当样本数不能被batch_size整除时,丢弃末尾不完整批次。

    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=False)

继承与super()

在自定义模型时,super().__init__()确保父类的初始化逻辑优先执行:

复制代码
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # 调用父类nn.Module的初始化
        self.conv = nn.Conv2d(3, 64, kernel_size=3)

卷积操作

  • 卷积核(Kernel):滑动窗口提取特征,生成特征图。
  • Stride:滑动步长,影响输出尺寸。
  • Padding:边缘填充,控制输出尺寸。
  • in_channels/out_channels:输入/输出的通道数。

池化层

减少特征图尺寸,保留重要特征:

  • 最大池化(MaxPooling):取窗口内最大值。
  • 平均池化(AvgPooling):取窗口内平均值。

Sequential构建层级结构

nn.Sequential简化网络层堆叠:

复制代码
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(2)
)

损失函数与反向传播

  • 损失函数 :如nn.CrossEntropyLoss衡量预测与目标的差距。
  • 反向传播 :调用loss.backward()计算梯度,优化器通过optimizer.step()更新参数。

模型保存与加载

保存模型参数或整个模型:

复制代码
torch.save(model.state_dict(), "model.pth")  # 仅保存参数
model.load_state_dict(torch.load("model.pth"))  # 加载参数

训练流程模板

  1. 准备数据 :定义DatasetDataLoader
  2. 构建模型 :继承nn.Module或使用Sequential
  3. 设置损失函数与优化器 :如nn.CrossEntropyLosstorch.optim.SGD
  4. 训练循环:前向传播、计算损失、反向传播、参数更新。
  5. 验证与测试:评估模型在验证集/测试集上的表现。
  6. 可视化:使用TensorBoard记录关键指标。
相关推荐
FriendshipT2 小时前
图像分割:PyTorch从零开始实现SegFormer语义分割
人工智能·pytorch·python·深度学习·目标检测·语义分割·实例分割
shelter -唯2 小时前
基于selenium库的爬虫实战:京东手机数据爬取
爬虫·python·selenium
月疯3 小时前
FLASK与JAVA的文件互传并带参数以及流上传(单文件互传亲测)
java·python·flask
雨夜的星光3 小时前
PyCharm 核心快捷键大全 (Windows版)
ide·python·pycharm
my烂笔头3 小时前
cv领域接地气的方向
人工智能·深度学习·计算机视觉
Stream_Silver3 小时前
LangChain入门实践3:PromptTemplate提示词模板详解
java·python·学习·langchain·language model
LaughingZhu3 小时前
Product Hunt 每日热榜 | 2025-10-03
人工智能·经验分享·搜索引擎·产品运营
Godspeed Zhao3 小时前
自动驾驶中的传感器技术65——Navigation(2)
人工智能·机器学习·自动驾驶
智能交通技术3 小时前
iTSTech:智慧物流中自动驾驶、无人机与机器人的协同应用场景分析 2025
人工智能·机器学习·机器人·自动驾驶·无人机