PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
没事学AI8 分钟前
美团搜索推荐统一Agent之交互协议与多Agent协同
人工智能·agent·美团·多agent
傻啦嘿哟26 分钟前
Python3解释器深度解析与实战教程:从源码到性能优化的全路径探索
开发语言·python
霖0027 分钟前
FPGA的PS基础1
数据结构·人工智能·windows·git·算法·fpga开发
Emma歌小白28 分钟前
groupby.agg去重后的展平列表通用方法flatten_unique
python
weixin_4569042732 分钟前
基于Tensorflow2.15的图像分类系统
人工智能·分类·tensorflow
修仙的人1 小时前
【开发环境】 VSCode 快速搭建 Python 项目开发环境
前端·后端·python
hhhh明1 小时前
Windows11 运行IsaacSim GPU Vulkan崩溃
vscode·python
在钱塘江2 小时前
LangGraph构建Ai智能体-12-高级RAG之自适应RAG
人工智能·python
聚客AI2 小时前
🚀碾压传统方案!vLLM与TGI/TensorRT-LLM性能实测对比
人工智能·llm·掘金·日新计划
站大爷IP2 小时前
Python列表基础操作全解析:从创建到灵活应用
python