Pytorch个人学习记录总结 10

目录

优化器


优化器

官方文档地址:torch.optimhttps://pytorch.org/docs/stable/optim.html

Debug过程中查看的grad所在的位置:

model --> Protected Atributes --> _modules --> 'model' --> Protected Atributes --> _modules --> '0'(任选一个conv层) --> weight(查看weight下的data和grad的变化)

简易训练代码,添加了Loss、Optim。

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):  # 模型前向传播
        return self.model(x)


model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数
optim = torch.optim.SGD(model.parameters(), lr=0.01)  # lr不能过大或者过小。刚开始的lr可设置得较大一点,后面再对lr进行调节
len = len(dataloader)

for epoch in range(20):
    total_loss = 0.0
    for imgs, targets in dataloader:
        outputs = model(imgs)
        res_loss = loss_cross(outputs, targets)

        optim.zero_grad()  # 优化器对model中的每一个参数进行梯度清零
        res_loss.backward()  # 损失反向传播
        optim.step()  # 对model参数开始调优
        total_loss += res_loss
    print('epoch:{}\ttotal_loss:{}\tmean_loss:{}.'.format(epoch, total_loss, total_loss / len))
# epoch:0	total_loss:9374.806640625	mean_loss:1.8749613761901855.
# epoch:1	total_loss:7721.240234375	mean_loss:1.544248104095459.
# epoch:2	total_loss:6830.775390625	mean_loss:1.3661550283432007.
相关推荐
橙子小哥的代码世界2 分钟前
【机器学习】【KMeans聚类分析实战】用户分群聚类详解——SSE、CH 指数、SC全解析,实战电信客户分群案例
人工智能·python·机器学习·kmeans·数据科学·聚类算法·肘部法
计算机徐师兄2 分钟前
Python基于Flask的豆瓣Top250电影数据可视化分析与评分预测系统(附源码,技术说明)
python·flask·豆瓣top250电影数据可视化·豆瓣top250电影评分预测·豆瓣电影数据可视化分析系统·豆瓣电影评分预测系统·豆瓣电影数据
k layc7 分钟前
【论文解读】《Training Large Language Models to Reason in a Continuous Latent Space》
人工智能·python·机器学习·语言模型·自然语言处理·大模型推理
im长街8 分钟前
Ubuntu22.04 - brpc的安装和使用
学习
知识分享小能手13 分钟前
Html5学习教程,从入门到精通,HTML5 简介语法知识点及案例代码(1)
开发语言·前端·javascript·学习·前端框架·html·html5
代码猪猪傻瓜coding15 分钟前
【模块】 ASFF 模块
人工智能·深度学习
阿正的梦工坊21 分钟前
Sliding Window Attention(滑动窗口注意力)解析: Pytorch实现并结合全局注意力(Global Attention )
人工智能·pytorch·python
喜-喜1 小时前
Python pip 缓存清理:全面方法与操作指南
python·缓存·pip
rgb2gray1 小时前
GeoHD - 一种用于智慧城市热点探测的Python工具箱
人工智能·python·智慧城市
火车叼位1 小时前
5个Why、SWOT, 5W2H等方法论总结,让你的提示词更加精炼
人工智能