PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
lifloveyou1 分钟前
table接口结构
python
阿乔外贸日记2 分钟前
2026尼日利亚五项清关政策更新,拉高能源装备进口综合成本
大数据·人工智能·搜索引擎·智能手机·云计算·能源
民乐团扒谱机22 分钟前
【AI笔记】短时纯音时长对音高感知偏移效应研究综述
人工智能·笔记
侃谈科技圈32 分钟前
破除数据中台落地困境:2026数据治理平台差异化能力与选型决策指南
大数据·人工智能
大象说1 小时前
Python多进程共享队列无报错僵死 120G Nginx访问日志清洗踩坑全记录
人工智能·自然语言处理
Cosolar1 小时前
AutoGen 精通教程:从零到企业级多 Agent 系统架构师
人工智能·后端·面试
甲维斯1 小时前
Claude Code 省钱小妙招!200K和自动压缩
人工智能
DO_Community1 小时前
DigitalOcean 的 AI 推理路由器是如何构建的
人工智能·开源·agent·claude·deepseek
Elastic 中国社区官方博客1 小时前
Elasticsearch DiskBBQ:使用原生 SIMD Blocks 实现快 40% 的向量评分计算
大数据·人工智能·elasticsearch·搜索引擎·ai·全文检索·diskbbq
Axis tech1 小时前
爱迪斯通携手智元机器人亮相COMPUTEX 2026大会
人工智能·机器人