优化器(一)torch.optim.SGD-随机梯度下降法

torch.optim.SGD-随机梯度下降法

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss
    print(running_loss)
相关推荐
居然JuRan3 小时前
DeepSeek-R1-Distill-Qwen-7B vLLM 部署调用
人工智能
mwq301233 小时前
GPT:GELU (Gaussian Error Linear Unit) 激活函数详解
人工智能
数据库安全3 小时前
山东省某三甲医院基于分类分级的数据安全防护建设实践
大数据·人工智能
七牛云行业应用3 小时前
从API调用到智能体编排:GPT-5时代的AI开发新模式
大数据·人工智能·gpt·openai·agent开发
StarPrayers.3 小时前
用 PyTorch 搭建 CIFAR10 线性分类器:从数据加载到模型推理全流程解析
人工智能·pytorch·python
Francek Chen4 小时前
【深度学习计算机视觉】13:实战Kaggle比赛:图像分类 (CIFAR-10)
深度学习·计算机视觉·分类
Ro Jace4 小时前
模式识别与机器学习课程笔记(11):深度学习
笔记·深度学习·机器学习
碱化钾4 小时前
Lipschitz连续及其常量
人工智能·机器学习
两万五千个小时4 小时前
LangChain 入门教程:06LangGraph工作流编排
人工智能·后端
渡我白衣4 小时前
深度学习进阶(六)——世界模型与具身智能:AI的下一次跃迁
人工智能·深度学习