Pytorch个人学习记录总结 10

目录

优化器


优化器

官方文档地址:torch.optimhttps://pytorch.org/docs/stable/optim.html

Debug过程中查看的grad所在的位置:

model --> Protected Atributes --> _modules --> 'model' --> Protected Atributes --> _modules --> '0'(任选一个conv层) --> weight(查看weight下的data和grad的变化)

简易训练代码,添加了Loss、Optim。

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):  # 模型前向传播
        return self.model(x)


model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数
optim = torch.optim.SGD(model.parameters(), lr=0.01)  # lr不能过大或者过小。刚开始的lr可设置得较大一点,后面再对lr进行调节
len = len(dataloader)

for epoch in range(20):
    total_loss = 0.0
    for imgs, targets in dataloader:
        outputs = model(imgs)
        res_loss = loss_cross(outputs, targets)

        optim.zero_grad()  # 优化器对model中的每一个参数进行梯度清零
        res_loss.backward()  # 损失反向传播
        optim.step()  # 对model参数开始调优
        total_loss += res_loss
    print('epoch:{}\ttotal_loss:{}\tmean_loss:{}.'.format(epoch, total_loss, total_loss / len))
# epoch:0	total_loss:9374.806640625	mean_loss:1.8749613761901855.
# epoch:1	total_loss:7721.240234375	mean_loss:1.544248104095459.
# epoch:2	total_loss:6830.775390625	mean_loss:1.3661550283432007.
相关推荐
看海天一色听风起雨落1 分钟前
Python学习之装饰器
开发语言·python·学习
小憩-2 分钟前
【机器学习】吴恩达机器学习笔记
人工智能·笔记·机器学习
却道天凉_好个秋29 分钟前
深度学习(二):神经元与神经网络
人工智能·神经网络·计算机视觉·神经元
UQI-LIUWJ30 分钟前
unsloth笔记:运行&微调 gemma
人工智能·笔记·深度学习
XiaoMu_00131 分钟前
基于Python+Streamlit的旅游数据分析与预测系统:从数据可视化到机器学习预测的完整实现
python·信息可视化·旅游
THMAIL33 分钟前
深度学习从入门到精通 - 生成对抗网络(GAN)实战:创造逼真图像的魔法艺术
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·cnn
却道天凉_好个秋35 分钟前
计算机视觉(八):开运算和闭运算
人工智能·计算机视觉·开运算与闭运算
无风听海36 分钟前
神经网络之深入理解偏置
人工智能·神经网络·机器学习·偏置
JoinApper38 分钟前
目标检测系列-Yolov5下载及运行
人工智能·yolo·目标检测
speop1 小时前
llm的一点学习笔记
笔记·学习