学习pytorch13 神经网络-搭建小实战&Sequential的使用

神经网络-搭建小实战&Sequential的使用

B站小土堆pytorch视频学习

官网

https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html#torch.nn.Sequential

sequential 将模型结构组合起来 以逗号分割,按顺序执行,和compose使用方式类似。

模型结构

根据模型结构和数据的输入shape,计算用在模型中的超参数

箭头指向部分还需要一层flatten层,展开输入shape为一维

code

py 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter


class MySeq(nn.Module):
    def __init__(self):
        super(MySeq, self).__init__()
        self.conv1 = Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.maxp1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, kernel_size=5, stride=1, padding=2)
        self.maxp2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.maxp3 = MaxPool2d(2)
        self.flatten1 = Flatten()
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        x = self.conv3(x)
        x = self.maxp3(x)
        x = self.flatten1(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

class MySeq2(nn.Module):
    def __init__(self):
        super(MySeq2, self).__init__()
        self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10)
                                 )

    def forward(self, x):
        x = self.model1(x)
        return x


myseq = MySeq()
input = torch.ones(64, 3, 32, 32)
print(myseq)
print(input.shape)
output = myseq(input)
print(output.shape)

myseq2 = MySeq2()
print(myseq2)
output2 = myseq2(input)
print(output2.shape)

wirter = SummaryWriter('logs')
wirter.add_graph(myseq, input)
wirter.add_graph(myseq2, input)

running log

sh 复制代码
MySeq(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxp3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten1): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
torch.Size([64, 3, 32, 32])
torch.Size([64, 10])
MySeq2(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([64, 10])

网络结构可视化

py 复制代码
from torch.utils.tensorboard import SummaryWriter
wirter = SummaryWriter('logs')
wirter.add_graph(myseq, input)
sh 复制代码
tensorboard --logdir=logs

tensorboard 展示图文件, 双击每层网络,可查看层定义细节

相关推荐
无风听海11 分钟前
行向量和列向量在神经网络应用中的选择
人工智能·深度学习·神经网络·行向量·列向量
go&Python16 分钟前
检索模型与RAG
开发语言·python·llama
阿里云大数据AI技术44 分钟前
ODPS 十五周年实录 | Data + AI,MaxCompute 下一个15年的新增长引擎
大数据·python·sql
RainbowJie11 小时前
Gemini CLI 与 MCP 服务器:释放本地工具的强大潜力
java·服务器·spring boot·后端·python·单元测试·maven
能力越小责任越小YA1 小时前
服务器(Linux)新账户搭建Pytorch深度学习环境
人工智能·pytorch·深度学习·环境搭建
工作碎碎念1 小时前
NumPy------数值计算
python
工作碎碎念1 小时前
pandas
python
A7bert7772 小时前
【YOLOv5部署至RK3588】模型训练→转换RKNN→开发板部署
c++·人工智能·python·深度学习·yolo·目标检测·机器学习
冷月半明2 小时前
时间序列篇:Prophet负责优雅,LightGBM负责杀疯
python·算法
教练我想打篮球_基本功重塑版3 小时前
L angChain 加载大模型
python·langchain