Pytorch个人学习记录总结 10

目录

优化器


优化器

官方文档地址:torch.optimhttps://pytorch.org/docs/stable/optim.html

Debug过程中查看的grad所在的位置:

model --> Protected Atributes --> _modules --> 'model' --> Protected Atributes --> _modules --> '0'(任选一个conv层) --> weight(查看weight下的data和grad的变化)

简易训练代码,添加了Loss、Optim。

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):  # 模型前向传播
        return self.model(x)


model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数
optim = torch.optim.SGD(model.parameters(), lr=0.01)  # lr不能过大或者过小。刚开始的lr可设置得较大一点,后面再对lr进行调节
len = len(dataloader)

for epoch in range(20):
    total_loss = 0.0
    for imgs, targets in dataloader:
        outputs = model(imgs)
        res_loss = loss_cross(outputs, targets)

        optim.zero_grad()  # 优化器对model中的每一个参数进行梯度清零
        res_loss.backward()  # 损失反向传播
        optim.step()  # 对model参数开始调优
        total_loss += res_loss
    print('epoch:{}\ttotal_loss:{}\tmean_loss:{}.'.format(epoch, total_loss, total_loss / len))
# epoch:0	total_loss:9374.806640625	mean_loss:1.8749613761901855.
# epoch:1	total_loss:7721.240234375	mean_loss:1.544248104095459.
# epoch:2	total_loss:6830.775390625	mean_loss:1.3661550283432007.
相关推荐
bestcxx1 分钟前
0.2、AI Agent 开发中 ReAct 和 MAS 的概念
人工智能·python·dify·ai agent
Q一件事9 分钟前
arcgis重采样插值方法的选择
人工智能·arcgis
楼田莉子18 分钟前
C++学习:异常及其处理
开发语言·c++·学习·visual studio
fsnine21 分钟前
Python Web框架对比与模型部署
开发语言·前端·python
能不能别报错27 分钟前
K8s学习笔记(二十) 亲和性、污点、容忍、驱逐
笔记·学习·kubernetes
Xxtaoaooo33 分钟前
Sora文生视频技术拆解:Diffusion Transformer架构与时空建模原理
人工智能·架构·音视频·transformer·sora
kuniqiw33 分钟前
远程处理器协议框架学习
学习
lisw0534 分钟前
数字化科技简化移民流程的 5 种方式
大数据·人工智能·机器学习
空白到白43 分钟前
Transformer-解码器_编码器部分
人工智能·深度学习·transformer
悟乙己43 分钟前
PandasAI :使用 AI 优化你的分析工作流
人工智能·pandas·pandasai