优化器(一)torch.optim.SGD-随机梯度下降法

torch.optim.SGD-随机梯度下降法

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss
    print(running_loss)
相关推荐
龙亘川5 分钟前
医院通用人工智能平台设计与落地实践(2026)—— 面向智慧医院的 AI 操作系统架构解析
人工智能·医院通用人工智能平台技术白皮书
SelectDB技术团队6 分钟前
SelectDB Enterprise 4.0.5:强化安全与治理,构建企业级实时分析与 AI 数据底座
数据库·人工智能·apache doris
輕華9 分钟前
LSTM实战:遗忘门、输入门与输出门解决长期依赖
人工智能·rnn·lstm
Li emily9 分钟前
解决了美股api历史数据调用不稳定问题
人工智能·api·fastapi
weixin_5134499612 分钟前
PCA、SVD 、 ICP 、kd-tree算法的简单整理总结
c++·人工智能·学习·算法·机器人
code_pgf22 分钟前
Qwen2.5-VL 算法解析
人工智能·深度学习·算法·transformer
xiaotao13138 分钟前
01-编程基础与数学基石:概率与统计
人工智能·python·numpy·pandas
云烟成雨TD38 分钟前
Spring AI Alibaba 1.x 系列【23】短期记忆
java·人工智能·spring
竹之却1 小时前
【Agent-阿程】OpenClaw v2026.4.15 版本更新全解析
人工智能·ai·openclaw
嵌入式小企鹅1 小时前
DeepSeek-V4昇腾首发、国芯抗量子MCU突破、AI编程Agent抢班夺权
人工智能·学习·ai·程序员·算力·risc-v