PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
Q***f6351 分钟前
机器学习书籍
人工智能·机器学习
顾安r3 分钟前
11.20 开源APP
服务器·前端·javascript·python·css3
小毅&Nora13 分钟前
【AI微服务】【Spring AI Alibaba】 ① 技术内核全解析:架构、组件与无缝扩展新模型能力
人工智能·微服务·架构
D***t13119 分钟前
DeepSeek模型在自然语言处理中的创新应用
人工智能·自然语言处理
WWZZ202519 分钟前
快速上手大模型:深度学习10(卷积神经网络2、模型训练实践、批量归一化)
人工智能·深度学习·神经网络·算法·机器人·大模型·具身智能
萧鼎42 分钟前
Python PyTesseract OCR :从基础到项目实战
开发语言·python·ocr
2501_941404311 小时前
绿色科技与可持续发展:科技如何推动环境保护与资源管理
大数据·人工智能
希露菲叶特格雷拉特1 小时前
PyTorch深度学习进阶(四)(数据增广)
人工智能·pytorch·深度学习
强盛小灵通专卖员1 小时前
基于RT-DETR的电力设备过热故障红外图像检测
人工智能·目标检测·sci·研究生·小论文·大论文·延毕
倔强青铜三1 小时前
AI编程革命:React + shadcn/ui 将终结前端框架之战
前端·人工智能·ai编程