Pytorch个人学习记录总结 08

目录

神经网络-搭建小实战和Sequential的使用

版本1------未用Sequential

版本2------用Sequential


神经网络-搭建小实战和Sequential的使用

  1. torch.nn.Sequential官方文档地址,模块将按照它们在构造函数中传递的顺序添加。
  2. 代码实现的是下图:

版本1------未用Sequential

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 3,32,32 ---> 32,32,32
        self.conv1 = Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2)
        # 32,32,32 ---> 32,16,16
        self.maxpool1 = MaxPool2d(kernel_size=2, stride=2)
        # 32,16,16 ---> 32,16,16
        self.conv2 = Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        # 32,16,16 ---> 32,8,8
        self.maxpool2 = MaxPool2d(kernel_size=2, stride=2)
        # 32,8,8 ---> 64,8,8
        self.conv3 = Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        # 64,8,8 ---> 64,4,4
        self.maxpool3 = MaxPool2d(kernel_size=2, stride=2)
        # 64,4,4 ---> 1024
        self.flatten = Flatten()  # 因为start_dim默认为1,所以可不再另外设置
        # 1024 ---> 64
        self.linear1 = Linear(1024, 64)
        # 64 ---> 10
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


model = Model()
print(model)

input = torch.ones((64, 3, 32, 32))
out = model(input)
print(out.shape)	# torch.Size([64, 10])

版本2------用Sequential

代码更简洁,而且会给每层自动从0开始编序。

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)


model = Model()
print(model)

input = torch.ones((64, 3, 32, 32))
out = model(input)
print(out.shape)	# torch.Size([64, 10])

在代码最末尾加上writer.add_gragh(model, input)就可看到模型计算图,可放大查看。

python 复制代码
writer = SummaryWriter('./logs/Seq')
writer.add_graph(model, input)
writer.close()
相关推荐
木叶子---2 分钟前
Spring 枚举转换器冲突问题分析与解决
java·python·spring
亚空间仓鼠3 分钟前
网络学习实例:网络理论知识
网络·学习·智能路由器
lizz66611 分钟前
Hermes-Agent:配置gateway网关,chat交互入口(钉钉Dingtalk)
人工智能
༒࿈南林࿈༒14 分钟前
链家二手房数据自动化点选验证码
python·自动化·点选验证码
财经汇报15 分钟前
从AI到抗量子:下一代金融基础设施正在发生什么变化?
人工智能·量子计算
IT_陈寒31 分钟前
Vite静态资源加载把我坑惨了
前端·人工智能·后端
后端小肥肠33 分钟前
我把自己蒸馏成小肥肠.skill,相关答疑全能做,一人公司终于能聚焦核心业务
人工智能·agent
天一生水water1 小时前
Time-Series-Library 仓库的使用
人工智能
HeteroCat1 小时前
DeepSeek V4 来了:我熬了一中午,把技术报告啃完了
人工智能
阿杰学AI1 小时前
AI核心知识135—大语言模型之 OpenClaw(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·ai编程·openclaw