优化器(一)torch.optim.SGD-随机梯度下降法

torch.optim.SGD-随机梯度下降法

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss
    print(running_loss)
相关推荐
耘瞳科技26 分钟前
喜讯 | 耘瞳科技视觉检测与测量装备荣膺“2024机器视觉创新产品TOP10”
人工智能·科技·视觉检测
__Benco3 小时前
OpenHarmony子系统开发 - DFX(一)
人工智能·harmonyos
小西几哦3 小时前
3D点云配准RPM-Net模型解读(附论文+源码)
人工智能·pytorch·3d
CareyWYR3 小时前
每周AI论文速递(250331-250404)
人工智能
码视野3 小时前
基于快速开发平台与智能手表的区域心电监测与AI预警系统(源码+论文+部署讲解等)
人工智能·智能手表·毕业论文·计算机论文·物联网论文
skywalk81633 小时前
OpenRouter开源的AI大模型路由工具,统一API调用
服务器·前端·人工智能·openrouter
ejinxian4 小时前
大模型应用初学指南
人工智能·大模型·向量数据库
秋94 小时前
使用人工智能大模型kimi,如何免费高效制作PPT?
人工智能·kimi·制作ppt
IT古董4 小时前
【漫话机器学习系列】181.没有免费的午餐定理(NFL)
人工智能·机器学习