PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
相关推荐
DX_水位流量监测8 分钟前
水库水雨情监测系统:水位、雨量、流量等参数全天候实时监测
大数据·开发语言·前端·网络·人工智能·信息可视化
苦瓜汤补钙8 分钟前
文本区域提取和分析——Python版本
开发语言·图像处理·python·计算机视觉
warren@伟_18 分钟前
Event-Based Visible and Infrared Fusion via Multi-Task Collaboration
人工智能·python·数码相机·计算机视觉
我叫czc22 分钟前
【Python高级374】正则表达式
python·mysql·正则表达式
dundunmm23 分钟前
【论文阅读】SCGC : Self-supervised contrastive graph clustering
论文阅读·人工智能·算法·数据挖掘·聚类·深度聚类·图聚类
古-月25 分钟前
【计算机视觉】单目深度估计模型-Depth Anything-V2
人工智能·计算机视觉
麦田里的稻草人w33 分钟前
【pyqt】(四)Designer布局
python·pyqt
胜天半月子1 小时前
Python | 学习type()方法动态创建类
开发语言·python·学习
鳄鱼的眼药水2 小时前
TT100K数据集, YOLO格式, COCO格式
人工智能·python·yolo·yolov5·yolov8
Tomorrow'sThinker2 小时前
25年1月更新。Windows 上搭建 Python 开发环境:Python + PyCharm 安装全攻略(文中有安装包不用官网下载)
开发语言·python·pycharm