PyTorch——搭建小实战和Sequential的使用(7)


python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class TY(nn.Module):
    def __init__(self):
        """
        初始化TY卷积神经网络模型
        模型结构:3层卷积+池化,2层全连接
        设计目标:处理32x32像素的RGB图像分类任务
        """
        # 调用父类构造函数
        super(TY, self).__init__()
        
        # 卷积层1: 输入3通道(RGB),输出32通道
        # 5x5卷积核,padding=2保持特征图尺寸不变
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        # 最大池化层1: 2x2窗口,步长2,尺寸减半
        self.maxpool1 = MaxPool2d(2)
        
        # 卷积层2: 输入32通道,输出32通道
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        # 最大池化层2
        self.maxpool2 = MaxPool2d(2)
        
        # 卷积层3: 输入32通道,输出64通道
        # 增加通道数提取更复杂特征
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        # 最大池化层3
        self.maxpool3 = MaxPool2d(2)
        
        # 展平多维张量为一维向量
        self.flatten = Flatten()
        
        # 全连接层1: 输入1024维,输出64维
        # 1024 = 64通道 x 4x4特征图(经过3次池化后尺寸为32→16→8→4)
        self.Linear1 = Linear(1024, 64)
        # 全连接层2: 输入64维,输出10维(对应10个分类类别)
        self.Linear2 = Linear(64, 10)

    def forward(self, x):
        """
        定义模型前向传播过程
        参数:
            x: 输入张量,形状为[batch_size, 3, 32, 32]
        返回:
            x: 输出张量,形状为[batch_size, 10]
        """
        # 第一层卷积+ReLU激活+池化
        # 输入: [batch, 3, 32, 32] → 输出: [batch, 32, 16, 16]
        x = self.conv1(x)
        x = self.maxpool1(x)
        
        # 第二层卷积+ReLU激活+池化
        # 输入: [batch, 32, 16, 16] → 输出: [batch, 32, 8, 8]
        x = self.conv2(x)
        x = self.maxpool2(x)
        
        # 第三层卷积+ReLU激活+池化
        # 输入: [batch, 32, 8, 8] → 输出: [batch, 64, 4, 4]
        x = self.conv3(x)
        x = self.maxpool3(x)
        
        # 展平操作
        # 输入: [batch, 64, 4, 4] → 输出: [batch, 64*4*4=1024]
        x = self.flatten(x)
        
        # 全连接层1 + ReLU激活
        # 输入: [batch, 1024] → 输出: [batch, 64]
        x = self.Linear1(x)
        
        # 全连接层2 (分类层)
        # 输入: [batch, 64] → 输出: [batch, 10]
        x = self.Linear2(x)
        
        return x

# 创建模型实例
ty = TY()
# 打印模型结构
print(ty)

# 创建测试输入:64张32x32的RGB图像(全1值)
input = torch.ones((64, 3, 32, 32))
# 执行前向传播
output = ty(input)
# 打印输出形状,应为[64, 10]
print(f"输出形状: {output.shape}")

torch.ones用法


python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter

# 定义TY卷积神经网络模型,继承自PyTorch的nn.Module
class TY(nn.Module):
    def __init__(self):
        # 调用父类构造函数
        super(TY,self).__init__()
        # 使用Sequential容器构建网络,按顺序堆叠各层
        self.model1 = Sequential(
            # 第一个卷积层:3通道输入,32通道输出,5x5卷积核,padding=2保持尺寸
            Conv2d(3,32,5,padding=2),
            # 第一个池化层:2x2窗口,下采样至16x16
            MaxPool2d(2),
            # 第二个卷积层:32通道输入,32通道输出
            Conv2d(32,32,5,padding=2),
            # 第二个池化层:下采样至8x8
            MaxPool2d(2),
            # 第三个卷积层:32通道输入,64通道输出
            Conv2d(32,64,5,padding=2),
            # 第三个池化层:下采样至4x4
            MaxPool2d(2),
            # 展平多维张量为一维向量:64x4x4=1024
            Flatten(),
            # 第一个全连接层:1024维输入,64维输出
            Linear(1024,64),
            # 第二个全连接层:64维输入,10维输出(对应10个分类)
            Linear(64, 10),
        )

    def forward(self, x):
        # 定义前向传播路径
        x = self.model1(x)
        return x

# 创建模型实例
ty = TY()
# 打印模型结构
print(ty)
# 创建测试输入:64个样本,3通道,32x32尺寸
input = torch.ones((64,3,32,32))
# 执行前向传播
output = ty(input)
# 打印输出形状,验证网络结构正确性
print(output.shape)

# 创建TensorBoard日志写入器,保存日志到'./logs_seq'目录
writer = SummaryWriter("./logs_seq")
# 将模型结构写入TensorBoard,便于可视化分析
writer.add_graph(ty,input)
# 关闭写入器,释放资源
writer.close()

相关推荐
产业家1 分钟前
Sora 后思考:从 AI 工具到 AI 平台,产业 AGI 又近了一步
人工智能·chatgpt·agi
量化交易曾小健(金融号)5 分钟前
人大计算金融课程名称:《机器学习》(题库)/《大数据与机器学习》(非题库) 姜昊教授
人工智能
一晌小贪欢9 分钟前
Python爬虫第7课:多线程与异步爬虫技术
开发语言·爬虫·python·网络爬虫·python爬虫·python3
IT_陈寒12 分钟前
Redis 性能翻倍的 5 个隐藏技巧,99% 的开发者都不知道第3点!
前端·人工智能·后端
W_chuanqi15 分钟前
RDEx:一种效果驱动的混合单目标优化器,自适应选择与融合多种算子与策略
人工智能·算法·机器学习·性能优化
好奇龙猫17 分钟前
[AI学习:SPIN -win-安装SPIN-工具过程 SPIN win 电脑安装=accoda 环境-第四篇:代码修复]
人工智能·学习
Pocker_Spades_A25 分钟前
AI搜索自由:Perplexica+cpolar构建你的私人知识引擎
人工智能
~kiss~26 分钟前
图像的脉冲噪声和中值滤波
图像处理·人工智能·计算机视觉
居7然29 分钟前
DeepSeek-7B-chat 4bits量化 QLora 微调
人工智能·分布式·架构·大模型·transformer
卡奥斯开源社区官方31 分钟前
OpenAI万亿美元计划技术拆解:AI智能体的架构演进与商业化实践
人工智能