优化器(一)torch.optim.SGD-随机梯度下降法

torch.optim.SGD-随机梯度下降法

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss
    print(running_loss)
相关推荐
m沐沐几秒前
【机器学习】Python 实现垃圾邮件分类(随机森林 + 可视化 + 特征重要性)
人工智能·python·随机森林·机器学习·分类·pycharm·回归算法
程序员cxuan几秒前
这个 6.6 k star 的仓库,我差点删库了。
人工智能·后端·程序员
扫地僧9854 分钟前
一个基于 PyTorch 手语翻译模型Xuanmen_Net
人工智能·pytorch·python
搬砖的小码农_Sky4 分钟前
Windows环境下OpenClaw本地部署完整指南
人工智能·windows·ai·人机交互·agi
风舞雪凌月8 分钟前
【总结】国产AI大模型公司汇总
人工智能
Hali_Botebie9 分钟前
【光流】自动驾驶光流任务 DeFlow: Decoder of Scene Flow Network in Autonomous Driving
人工智能·机器学习·自动驾驶
IT_陈寒12 分钟前
被Vite的HMR坑惨了,原来这样配置才能用对!
前端·人工智能·后端
“码”力全开15 分钟前
解耦安防碎片化:基于 Docker 与边缘计算的 AI 视频中台架构设计(支持 GB28181/RTSP 与源码交付)
人工智能·docker·边缘计算
sali-tec16 分钟前
C# 基于OpenCv的视觉工作流-章80-长短脚
图像处理·人工智能·opencv·算法·计算机视觉
AI科技星17 分钟前
国家重点研发计划项目申报书
人工智能·线性代数·架构·概率论·学习方法