PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
喝拿铁写前端1 小时前
别再让 AI 直接写页面了:一种更稳的中后台开发方式
前端·人工智能
Swizard2 小时前
别再让你的 Python 傻等了:三分钟带你通过 asyncio 实现性能起飞
python
tongxianchao2 小时前
UPDP: A Unified Progressive Depth Pruner for CNN and Vision Transformer
人工智能·cnn·transformer
塔能物联运维3 小时前
设备边缘计算任务调度卡顿 后来动态分配CPU/GPU资源
人工智能·边缘计算
过期的秋刀鱼!3 小时前
人工智能-深度学习-线性回归
人工智能·深度学习
木头左3 小时前
高级LSTM架构在量化交易中的特殊入参要求与实现
人工智能·rnn·lstm
IE064 小时前
深度学习系列84:使用kokoros生成tts语音
人工智能·深度学习
欧阳天羲4 小时前
#前端开发未来3年(2026-2028)核心趋势与AI应用实践
人工智能·前端框架
IE064 小时前
深度学习系列83:使用outetts
人工智能·深度学习
水中加点糖4 小时前
源码运行RagFlow并实现AI搜索(文搜文档、文搜图、视频理解)与自定义智能体(一)
人工智能·二次开发·ai搜索·文档解析·ai知识库·ragflow·mineru