PyTorch入门学习(十四):优化器

目录

一、优化器的重要性

[二、PyTorch 中的深度学习](#二、PyTorch 中的深度学习)

三、优化器的选择


一、优化器的重要性

深度学习模型通常包含大量的参数,因此训练过程涉及到优化这些参数以减小损失函数的值。这个过程类似于找到函数的最小值,但由于模型通常非常复杂,所以需要依赖数值优化算法,即优化器。优化器的任务是调整模型参数,以最小化损失函数,从而提高模型的性能。

二、PyTorch 中的深度学习

PyTorch 是一个流行的深度学习框架,它提供了广泛的工具和库,用于创建、训练和部署深度学习模型。下面我们将通过一个简单的示例来了解如何使用 PyTorch 构建一个图像分类模型并训练它。

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

# 加载 CIFAR-10 数据集
dataset = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset=dataset, batch_size=1)

# 定义一个简单的神经网络模型
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        x = self.model1(x)
        return x

tudui = Tudui()

# 使用交叉熵损失函数
loss_cross = nn.CrossEntropyLoss()

# 使用随机梯度下降(SGD)优化器
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)

# 训练模型
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, labels = data
        outputs = tudui(imgs)
        results = loss_cross(outputs, labels)
        optim.zero_grad()
        results.backward()
        optim.step()
        running_loss = running_loss + results
    print(running_loss)

在上面的代码中,使用 PyTorch 创建了一个名为 Tudui 的神经网络模型,并使用 CIFAR-10 数据集进行训练。在训练过程中,使用了随机梯度下降(SGD)作为优化器来调整模型的参数,以降低交叉熵损失函数的值。

三、优化器的选择

在深度学习中,有多种不同类型的优化器可供选择,每种都有其独特的特点。常见的优化器包括随机梯度下降(SGD)、Adam、RMSprop 等。选择合适的优化器通常取决于具体问题的性质和模型的结构。在上述示例中,使用了SGD,但可以根据需要尝试不同的优化器,以找到最适合的问题的那一个。

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关推荐
DreamLife☼9 小时前
OpenBCI-脑机接口在康复医疗中的应用
深度学习·cnn·脑电·康复·fes·openbci·外骨骼
硅谷秋水10 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶
郭泽斌之心10 小时前
MQL5 EA 怎么和外部程序通信?文件三件套协议:参数热更新不重启、状态心跳、远程触发
人工智能·经验分享·深度学习·ea·fay数字人·easydeal
AI人工智能+10 小时前
智能文档抽取系统以专业的文档解析底座和大模型智能语义理解能力为核心,洞察文档的语义内涵与逻辑结构
深度学习·自然语言处理·ocr·文档抽取
nap-joker11 小时前
用于转录组信息精确肿瘤学和药物机制分析的多模态可解释深度学习
人工智能·深度学习·药物敏感性·多层级生物网络·细胞异质性·可解释性多模态
YOLO数据集集合11 小时前
无人机山地灾害巡检数据集 | 滑坡多区域实例分割 遥感影像解译 地质灾害预警深度学习数据10296期
人工智能·深度学习·目标检测·计算机视觉·无人机
袁小皮皮不皮12 小时前
1.HCIP BFD 学习笔记(优化版)
服务器·网络·笔记·网络协议·学习·智能路由器·ip
手写码匠12 小时前
手写 GraphRAG:从零实现图增强检索增强生成系统
人工智能·深度学习·算法·aigc
装不满的克莱因瓶13 小时前
【自动驾驶领域】学习 Cityscapes 数据集——城市街景语义理解的标准基准
人工智能·pytorch·python·深度学习·学习·机器学习·自动驾驶
清辞85313 小时前
产品经理需求推进流程
大数据·深度学习·学习·产品经理