学习pytorch13 神经网络-搭建小实战&Sequential的使用

神经网络-搭建小实战&Sequential的使用

B站小土堆pytorch视频学习

官网

https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html#torch.nn.Sequential

sequential 将模型结构组合起来 以逗号分割,按顺序执行,和compose使用方式类似。

模型结构

根据模型结构和数据的输入shape,计算用在模型中的超参数

箭头指向部分还需要一层flatten层,展开输入shape为一维

code

py 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter


class MySeq(nn.Module):
    def __init__(self):
        super(MySeq, self).__init__()
        self.conv1 = Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.maxp1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, kernel_size=5, stride=1, padding=2)
        self.maxp2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.maxp3 = MaxPool2d(2)
        self.flatten1 = Flatten()
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        x = self.conv3(x)
        x = self.maxp3(x)
        x = self.flatten1(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

class MySeq2(nn.Module):
    def __init__(self):
        super(MySeq2, self).__init__()
        self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10)
                                 )

    def forward(self, x):
        x = self.model1(x)
        return x


myseq = MySeq()
input = torch.ones(64, 3, 32, 32)
print(myseq)
print(input.shape)
output = myseq(input)
print(output.shape)

myseq2 = MySeq2()
print(myseq2)
output2 = myseq2(input)
print(output2.shape)

wirter = SummaryWriter('logs')
wirter.add_graph(myseq, input)
wirter.add_graph(myseq2, input)

running log

sh 复制代码
MySeq(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten1): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
torch.Size([64, 3, 32, 32])
torch.Size([64, 10])
MySeq2(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([64, 10])

网络结构可视化

py 复制代码
from torch.utils.tensorboard import SummaryWriter
wirter = SummaryWriter('logs')
wirter.add_graph(myseq, input)
sh 复制代码
tensorboard --logdir=logs

tensorboard 展示图文件, 双击每层网络,可查看层定义细节

相关推荐
0思必得02 分钟前
[Web自动化] Selenium处理动态网页
前端·爬虫·python·selenium·自动化
韩立学长9 分钟前
【开题答辩实录分享】以《基于Python的大学超市仓储信息管理系统的设计与实现》为例进行选题答辩实录分享
开发语言·python
qq_1927798712 分钟前
高级爬虫技巧:处理JavaScript渲染(Selenium)
jvm·数据库·python
u01092727130 分钟前
使用Plotly创建交互式图表
jvm·数据库·python
爱学习的阿磊31 分钟前
Python GUI开发:Tkinter入门教程
jvm·数据库·python
Imm7771 小时前
中国知名的车膜品牌推荐几家
人工智能·python
tudficdew1 小时前
实战:用Python分析某电商销售数据
jvm·数据库·python
陈天伟教授1 小时前
人工智能应用-机器听觉:15. 声纹识别的应用
人工智能·神经网络·机器学习·语音识别
sjjhd6522 小时前
Python日志记录(Logging)最佳实践
jvm·数据库·python
2301_821369612 小时前
用Python生成艺术:分形与算法绘图
jvm·数据库·python