Pytorch个人学习记录总结 08

目录

神经网络-搭建小实战和Sequential的使用

版本1------未用Sequential

版本2------用Sequential


神经网络-搭建小实战和Sequential的使用

  1. torch.nn.Sequential官方文档地址,模块将按照它们在构造函数中传递的顺序添加。
  2. 代码实现的是下图:

版本1------未用Sequential

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 3,32,32 ---> 32,32,32
        self.conv1 = Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2)
        # 32,32,32 ---> 32,16,16
        self.maxpool1 = MaxPool2d(kernel_size=2, stride=2)
        # 32,16,16 ---> 32,16,16
        self.conv2 = Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        # 32,16,16 ---> 32,8,8
        self.maxpool2 = MaxPool2d(kernel_size=2, stride=2)
        # 32,8,8 ---> 64,8,8
        self.conv3 = Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        # 64,8,8 ---> 64,4,4
        self.maxpool3 = MaxPool2d(kernel_size=2, stride=2)
        # 64,4,4 ---> 1024
        self.flatten = Flatten()  # 因为start_dim默认为1,所以可不再另外设置
        # 1024 ---> 64
        self.linear1 = Linear(1024, 64)
        # 64 ---> 10
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


model = Model()
print(model)

input = torch.ones((64, 3, 32, 32))
out = model(input)
print(out.shape)	# torch.Size([64, 10])

版本2------用Sequential

代码更简洁,而且会给每层自动从0开始编序。

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)


model = Model()
print(model)

input = torch.ones((64, 3, 32, 32))
out = model(input)
print(out.shape)	# torch.Size([64, 10])

在代码最末尾加上writer.add_gragh(model, input)就可看到模型计算图,可放大查看。

python 复制代码
writer = SummaryWriter('./logs/Seq')
writer.add_graph(model, input)
writer.close()
相关推荐
zhousenshan2 分钟前
Python爬虫常用框架
开发语言·爬虫·python
非门由也3 分钟前
《sklearn机器学习——管道和复合估计器》回归中转换目标
机器学习·回归·sklearn
茯苓gao4 分钟前
STM32G4 速度环开环,电流环闭环 IF模式建模
笔记·stm32·单片机·嵌入式硬件·学习
dlraba8028 分钟前
基于 OpenCV 的信用卡数字识别:从原理到实现
人工智能·opencv·计算机视觉
是誰萆微了承諾17 分钟前
【golang学习笔记 gin 】1.2 redis 的使用
笔记·学习·golang
IMER SIMPLE23 分钟前
人工智能-python-深度学习-经典神经网络AlexNet
人工智能·python·深度学习
CodeCraft Studio39 分钟前
国产化Word处理组件Spire.DOC教程:使用 Python 将 Markdown 转换为 HTML 的详细教程
python·html·word·markdown·国产化·spire.doc·文档格式转换
DKPT1 小时前
Java内存区域与内存溢出
java·开发语言·jvm·笔记·学习
aaaweiaaaaaa1 小时前
HTML和CSS学习
前端·css·学习·html
专注API从业者1 小时前
Python/Java 代码示例:手把手教程调用 1688 API 获取商品详情实时数据
java·linux·数据库·python