PyTorch模型构造实战:从基础到复杂组合

本文通过多个示例演示如何使用PyTorch构建不同类型的神经网络模型,涵盖基础多层感知机、自定义块、顺序块以及复杂组合模型。所有代码均附带输出结果,帮助读者直观理解模型结构。


1. 多层感知机(MLP)

使用nn.Sequential快速构建一个包含隐藏层和ReLU激活函数的简单MLP。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

# 定义模型
net = nn.Sequential(
    nn.Linear(in_features=20, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10)
)

# 随机输入(2个样本,每个样本20维)
x = torch.rand(2, 20)
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0424, -0.0431,  0.0191, -0.0467, -0.1238,  0.0123,  0.0224, -0.0914,
         -0.0271, -0.0883],
        [ 0.1497,  0.0056,  0.1736, -0.0222, -0.1749,  0.0234,  0.1242, -0.1502,
         -0.0490, -0.1498]], grad_fn=<AddmmBackward0>)

2. 自定义块(Custom Block)

通过继承nn.Module实现自定义模型,灵活定义前向传播逻辑。

python 复制代码
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # 隐藏层
        self.out = nn.Linear(256, 10)     # 输出层
    
    def forward(self, x):
        return self.out(F.relu(self.hidden(x)))

# 实例化并推理
net = Model1()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[-0.0344,  0.0446,  0.1053,  0.0658,  0.2332, -0.0105,  0.1963,  0.0181,
          0.1822, -0.1304],
        [-0.1953, -0.0464,  0.1120,  0.0082,  0.1906,  0.0503,  0.2968,  0.0132,
          0.2769, -0.1390]], grad_fn=<AddmmBackward0>)

3. 顺序块(Sequential Block)

nn.Sequential封装在自定义类中,简化模型定义。

python 复制代码
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.se(x)

# 实例化并推理
net = Model2()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.2166, -0.0262, -0.0240, -0.0165,  0.0695, -0.2495,  0.0699, -0.2297,
          0.0436, -0.0792],
        [ 0.2417,  0.0458, -0.0206,  0.0546,  0.0468, -0.3599,  0.1273, -0.2373,
          0.0020, -0.1880]], grad_fn=<AddmmBackward0>)

4. 动态操作的正向传播

在正向传播中执行矩阵运算和条件判断,展示灵活的自定义逻辑。

python 复制代码
class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((20, 20), requires_grad=False)  # 固定权重
        self.linear = nn.Linear(20, 20)
    
    def forward(self, x):
        x = self.linear(x)
        x = F.relu(torch.mm(x, self.rand_weight) + 1
        x = self.linear(x)
        while x.abs().sum() > 1:  # 动态调整张量大小
            x /= 2
        return x.sum()

# 实例化并推理
net = Model3()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor(-0.1288, grad_fn=<SumBackward0>)

5. 混合组合模型

通过组合不同块构建复杂模型,实现层次化设计。

python 复制代码
class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU()
        )
        self.linear = nn.Linear(32, 16)
    
    def forward(self, x):
        return self.linear(self.se(x))

# 组合多个模型
net = nn.Sequential(
    Model4(),
    nn.Linear(16, 20),
    Model2()
)

# 推理
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0220,  0.0221, -0.0445,  0.0760, -0.0317,  0.1331,  0.0716, -0.0102,
          0.0294, -0.0422],
        [ 0.0170,  0.0196, -0.0564,  0.0732, -0.0360,  0.1253,  0.0783, -0.0079,
          0.0283, -0.0448]], grad_fn=<AddmmBackward0>)

总结

  • nn.Sequential:适合快速堆叠层,适用于简单模型。

  • 自定义类 :通过继承nn.Module实现更灵活的前向传播逻辑。

  • 动态操作:可在正向传播中嵌入矩阵运算、循环等复杂操作。

  • 组合模型:通过混合不同块构建复杂网络,提升代码复用性。

完整代码已通过测试,建议结合实际任务调整模型结构和参数。欢迎在评论区讨论更多PyTorch技巧!


希望这篇文章能帮助你掌握PyTorch模型构造的核心方法!如果有其他问题,欢迎留言交流。

相关推荐
用户5191495848453 分钟前
30条顶级APT与蓝队攻防单行命令:网络战场终极对决
人工智能·aigc
双向333 分钟前
AI 辅助文档生成:从接口注释到自动化 API 文档上线
人工智能
CoovallyAIHub19 分钟前
SBP-YOLO:面向嵌入式悬架的轻量实时模型,实现减速带与坑洼高精度检测
深度学习·算法·计算机视觉
算法打盹中24 分钟前
基于树莓派与Jetson Nano集群的实验边缘设备上视觉语言模型(VLMs)的性能评估与实践探索
人工智能·计算机视觉·语言模型·自然语言处理·树莓派·多模态·jetson nano
卿·静29 分钟前
Node.js对接即梦AI实现“千军万马”视频
前端·javascript·人工智能·后端·node.js
YangYang9YangYan30 分钟前
2025年金融专业人士职业认证发展路径分析
大数据·人工智能·金融
AIbase202431 分钟前
GEO优化服务:技术演进如何重塑搜索优化行业新范式
大数据·人工智能
HuggingFace36 分钟前
ZeroGPU Spaces 加速实践:PyTorch 提前编译全解析
pytorch·zerogpu
摆烂z40 分钟前
ollama笔记
人工智能