Pytorch个人学习记录总结 08

目录

神经网络-搭建小实战和Sequential的使用

版本1------未用Sequential

版本2------用Sequential


神经网络-搭建小实战和Sequential的使用

  1. torch.nn.Sequential官方文档地址,模块将按照它们在构造函数中传递的顺序添加。
  2. 代码实现的是下图:

版本1------未用Sequential

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 3,32,32 ---> 32,32,32
        self.conv1 = Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2)
        # 32,32,32 ---> 32,16,16
        self.maxpool1 = MaxPool2d(kernel_size=2, stride=2)
        # 32,16,16 ---> 32,16,16
        self.conv2 = Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        # 32,16,16 ---> 32,8,8
        self.maxpool2 = MaxPool2d(kernel_size=2, stride=2)
        # 32,8,8 ---> 64,8,8
        self.conv3 = Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        # 64,8,8 ---> 64,4,4
        self.maxpool3 = MaxPool2d(kernel_size=2, stride=2)
        # 64,4,4 ---> 1024
        self.flatten = Flatten()  # 因为start_dim默认为1,所以可不再另外设置
        # 1024 ---> 64
        self.linear1 = Linear(1024, 64)
        # 64 ---> 10
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


model = Model()
print(model)

input = torch.ones((64, 3, 32, 32))
out = model(input)
print(out.shape)	# torch.Size([64, 10])

版本2------用Sequential

代码更简洁,而且会给每层自动从0开始编序。

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)


model = Model()
print(model)

input = torch.ones((64, 3, 32, 32))
out = model(input)
print(out.shape)	# torch.Size([64, 10])

在代码最末尾加上writer.add_gragh(model, input)就可看到模型计算图,可放大查看。

python 复制代码
writer = SummaryWriter('./logs/Seq')
writer.add_graph(model, input)
writer.close()
相关推荐
江_小_白44 分钟前
自动驾驶之激光雷达
人工智能·机器学习·自动驾驶
yusaisai大鱼2 小时前
TensorFlow如何调用GPU?
人工智能·tensorflow
湫ccc3 小时前
《Python基础》之字符串格式化输出
开发语言·python
Red Red3 小时前
网安基础知识|IDS入侵检测系统|IPS入侵防御系统|堡垒机|VPN|EDR|CC防御|云安全-VDC/VPC|安全服务
网络·笔记·学习·安全·web安全
mqiqe4 小时前
Python MySQL通过Binlog 获取变更记录 恢复数据
开发语言·python·mysql
AttackingLin4 小时前
2024强网杯--babyheap house of apple2解法
linux·开发语言·python
哭泣的眼泪4084 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
珠海新立电子科技有限公司4 小时前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
湫ccc5 小时前
《Python基础》之基本数据类型
开发语言·python
IT古董5 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论