学习pytorch13 神经网络-搭建小实战&Sequential的使用

神经网络-搭建小实战&Sequential的使用

B站小土堆pytorch视频学习

官网

https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html#torch.nn.Sequential

sequential 将模型结构组合起来 以逗号分割,按顺序执行,和compose使用方式类似。

模型结构

根据模型结构和数据的输入shape,计算用在模型中的超参数

箭头指向部分还需要一层flatten层,展开输入shape为一维

code

py 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter


class MySeq(nn.Module):
    def __init__(self):
        super(MySeq, self).__init__()
        self.conv1 = Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.maxp1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, kernel_size=5, stride=1, padding=2)
        self.maxp2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.maxp3 = MaxPool2d(2)
        self.flatten1 = Flatten()
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        x = self.conv3(x)
        x = self.maxp3(x)
        x = self.flatten1(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

class MySeq2(nn.Module):
    def __init__(self):
        super(MySeq2, self).__init__()
        self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10)
                                 )

    def forward(self, x):
        x = self.model1(x)
        return x


myseq = MySeq()
input = torch.ones(64, 3, 32, 32)
print(myseq)
print(input.shape)
output = myseq(input)
print(output.shape)

myseq2 = MySeq2()
print(myseq2)
output2 = myseq2(input)
print(output2.shape)

wirter = SummaryWriter('logs')
wirter.add_graph(myseq, input)
wirter.add_graph(myseq2, input)

running log

sh 复制代码
MySeq(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten1): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
torch.Size([64, 3, 32, 32])
torch.Size([64, 10])
MySeq2(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([64, 10])

网络结构可视化

py 复制代码
from torch.utils.tensorboard import SummaryWriter
wirter = SummaryWriter('logs')
wirter.add_graph(myseq, input)
sh 复制代码
tensorboard --logdir=logs

tensorboard 展示图文件, 双击每层网络,可查看层定义细节

相关推荐
刺客-Andy5 分钟前
Python 第二十节 正则表达式使用详解及注意事项
python·mysql·正则表达式
新子y1 小时前
【小白笔记】「while」在程序语言中的角色
笔记·python
java1234_小锋1 小时前
[免费]基于Python的YOLO深度学习垃圾分类目标检测系统【论文+源码】
python·深度学习·yolo·垃圾分类·垃圾分类检测
凌晨一点的秃头猪2 小时前
面向对象和面向过程 编程思想
python
sensen_kiss2 小时前
INT301 Bio-computation 生物计算(神经网络)Pt.2 监督学习模型:感知器(Perceptron)
神经网络·学习·机器学习
总有刁民想爱朕ha2 小时前
银河麒麟v10批量部署Python Flask项目小白教程
开发语言·python·flask·银河麒麟v10
Python×CATIA工业智造3 小时前
Python函数包装技术详解:从基础装饰器到高级应用
python·pycharm
快秃头的码农4 小时前
LazyLLM,(万象应用开发平台 AppStudio)商汤大装置
python
Cathy Bryant4 小时前
线性代数直觉(四):找到特征向量
笔记·神经网络·考研·机器学习·数学建模
離離原上譜5 小时前
python-docx 安装与快速入门
python·word·python-docx·自动化办公·1024程序员节