PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
勇敢一点♂12 分钟前
canal-python的安装与入门
数据库·python
云空18 分钟前
《DeepSeek R1:7b 写一个python程序调用摄像头获取视频并显示》
开发语言·python·音视频
weixin_307779131 小时前
AWS门店人流量数据分析项目的设计与实现
python·数据分析·系统架构·云计算·aws
CodeClimb2 小时前
【华为OD-E卷 - 115 数组组成的最小数字 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
CodeClimb2 小时前
【华为OD-E卷 - 114 找最小数 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
max5006002 小时前
介绍使用 WGAN(Wasserstein GAN)网络对天然和爆破的地震波形图进行分类的实现步骤
人工智能·生成对抗网络·分类
白白糖2 小时前
Day 27 卡玛笔记
python·力扣
风靡晚2 小时前
论文解读:《基于TinyML毫米波雷达的座舱检测、定位与分类》
人工智能·算法·分类·信息与通信·信号处理
亲持红叶2 小时前
Boosting 框架
人工智能·python·机器学习·集成学习·boosting