PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
摸鱼仙人~13 分钟前
深度对比:Prompt Tuning、P-tuning 与 Prefix Tuning 有何不同?
人工智能·prompt
塔能物联运维34 分钟前
隧道照明“智能进化”:PLC 通信 + AI 调光守护夜间通行生命线
大数据·人工智能
瑶光守护者34 分钟前
【AI经典论文解读】《Denoising Diffusion Implicit Models(去噪扩散隐式模型)》论文深度解读
人工智能
wwwzhouhui37 分钟前
2026年1月18日-Obsidian + AI,笔记效率提升10倍!一键生成Canvas和小红书风格笔记
人工智能·obsidian·skills
我星期八休息43 分钟前
MySQL数据可视化实战指南
数据库·人工智能·mysql·算法·信息可视化
UR的出不克1 小时前
使用 Python 爬取 Bilibili 弹幕数据并导出 Excel
java·python·excel
wuk9981 小时前
基于遗传算法优化BP神经网络实现非线性函数拟合
人工智能·深度学习·神经网络
码农三叔1 小时前
(1-3)人形机器人的发展历史、趋势与应用场景:人形机器人关键技术体系总览
人工智能·机器人
Arms2061 小时前
python时区库学习
开发语言·python·学习
白日做梦Q1 小时前
深度学习中的正则化技术全景:从Dropout到权重衰减的优化逻辑
人工智能·深度学习