PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
计算机学长felix5 分钟前
基于Django的“酒店推荐系统”设计与开发(源码+数据库+文档+PPT)
数据库·python·mysql·django·vue
站大爷IP6 分钟前
Python随机数函数全解析:5个核心工具的实战指南
python
知来者逆8 分钟前
视觉语言模型应用开发——Qwen 2.5 VL模型视频理解与定位能力深度解析及实践指南
人工智能·语言模型·自然语言处理·音视频·视觉语言模型·qwen 2.5 vl
IT_陈寒9 分钟前
Java性能优化:10个让你的Spring Boot应用提速300%的隐藏技巧
前端·人工智能·后端
Android出海13 分钟前
Android 15重磅升级:16KB内存页机制详解与适配指南
android·人工智能·新媒体运营·产品运营·内容运营
cyyt14 分钟前
深度学习周报(9.1~9.7)
人工智能·深度学习
悟乙己14 分钟前
使用 Python 中的强化学习最大化简单 RAG 性能
开发语言·python·agent·rag·n8n
聚客AI17 分钟前
🌸万字解析:大规模语言模型(LLM)推理中的Prefill与Decode分离方案
人工智能·llm·掘金·日新计划
max50060019 分钟前
图像处理:实现多图点重叠效果
开发语言·图像处理·人工智能·python·深度学习·音视频
AI原吾30 分钟前
玩转物联网只需十行代码,可它为何悄悄停止维护
python·物联网·hbmqtt