PyTorch模型构造实战:从基础到复杂组合

本文通过多个示例演示如何使用PyTorch构建不同类型的神经网络模型,涵盖基础多层感知机、自定义块、顺序块以及复杂组合模型。所有代码均附带输出结果,帮助读者直观理解模型结构。


1. 多层感知机(MLP)

使用nn.Sequential快速构建一个包含隐藏层和ReLU激活函数的简单MLP。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

# 定义模型
net = nn.Sequential(
    nn.Linear(in_features=20, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10)
)

# 随机输入(2个样本,每个样本20维)
x = torch.rand(2, 20)
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0424, -0.0431,  0.0191, -0.0467, -0.1238,  0.0123,  0.0224, -0.0914,
         -0.0271, -0.0883],
        [ 0.1497,  0.0056,  0.1736, -0.0222, -0.1749,  0.0234,  0.1242, -0.1502,
         -0.0490, -0.1498]], grad_fn=<AddmmBackward0>)

2. 自定义块(Custom Block)

通过继承nn.Module实现自定义模型,灵活定义前向传播逻辑。

python 复制代码
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # 隐藏层
        self.out = nn.Linear(256, 10)     # 输出层
    
    def forward(self, x):
        return self.out(F.relu(self.hidden(x)))

# 实例化并推理
net = Model1()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[-0.0344,  0.0446,  0.1053,  0.0658,  0.2332, -0.0105,  0.1963,  0.0181,
          0.1822, -0.1304],
        [-0.1953, -0.0464,  0.1120,  0.0082,  0.1906,  0.0503,  0.2968,  0.0132,
          0.2769, -0.1390]], grad_fn=<AddmmBackward0>)

3. 顺序块(Sequential Block)

nn.Sequential封装在自定义类中,简化模型定义。

python 复制代码
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.se(x)

# 实例化并推理
net = Model2()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.2166, -0.0262, -0.0240, -0.0165,  0.0695, -0.2495,  0.0699, -0.2297,
          0.0436, -0.0792],
        [ 0.2417,  0.0458, -0.0206,  0.0546,  0.0468, -0.3599,  0.1273, -0.2373,
          0.0020, -0.1880]], grad_fn=<AddmmBackward0>)

4. 动态操作的正向传播

在正向传播中执行矩阵运算和条件判断,展示灵活的自定义逻辑。

python 复制代码
class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((20, 20), requires_grad=False)  # 固定权重
        self.linear = nn.Linear(20, 20)
    
    def forward(self, x):
        x = self.linear(x)
        x = F.relu(torch.mm(x, self.rand_weight) + 1
        x = self.linear(x)
        while x.abs().sum() > 1:  # 动态调整张量大小
            x /= 2
        return x.sum()

# 实例化并推理
net = Model3()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor(-0.1288, grad_fn=<SumBackward0>)

5. 混合组合模型

通过组合不同块构建复杂模型,实现层次化设计。

python 复制代码
class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU()
        )
        self.linear = nn.Linear(32, 16)
    
    def forward(self, x):
        return self.linear(self.se(x))

# 组合多个模型
net = nn.Sequential(
    Model4(),
    nn.Linear(16, 20),
    Model2()
)

# 推理
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0220,  0.0221, -0.0445,  0.0760, -0.0317,  0.1331,  0.0716, -0.0102,
          0.0294, -0.0422],
        [ 0.0170,  0.0196, -0.0564,  0.0732, -0.0360,  0.1253,  0.0783, -0.0079,
          0.0283, -0.0448]], grad_fn=<AddmmBackward0>)

总结

  • nn.Sequential:适合快速堆叠层,适用于简单模型。

  • 自定义类 :通过继承nn.Module实现更灵活的前向传播逻辑。

  • 动态操作:可在正向传播中嵌入矩阵运算、循环等复杂操作。

  • 组合模型:通过混合不同块构建复杂网络,提升代码复用性。

完整代码已通过测试,建议结合实际任务调整模型结构和参数。欢迎在评论区讨论更多PyTorch技巧!


希望这篇文章能帮助你掌握PyTorch模型构造的核心方法!如果有其他问题,欢迎留言交流。

相关推荐
xwz小王子3 分钟前
Taccel:一个高性能的GPU加速视触觉机器人模拟平台
人工智能·机器人
黑匣子~3 分钟前
java集成telegram机器人
java·python·机器人·telegram
漫谈网络38 分钟前
Telnetlib三种异常处理方案
python·异常处理·telnet·telnetlib
Xudde.1 小时前
加速pip下载:永久解决网络慢问题
网络·python·学习·pip
兆。1 小时前
电子商城后台管理平台-Flask Vue项目开发
前端·vue.js·后端·python·flask
深空数字孪生1 小时前
AI时代的数据可视化:未来已来
人工智能·信息可视化
Icoolkj1 小时前
探秘 Canva AI 图像生成器:重塑设计创作新范式
人工智能
未名编程1 小时前
LeetCode 88. 合并两个有序数组 | Python 最简写法 + 实战注释
python·算法·leetcode
魔障阿Q1 小时前
windows使用bat脚本激活conda环境
人工智能·windows·python·深度学习·conda
Wnq100721 小时前
巡检机器人数据处理技术的创新与实践
网络·数据库·人工智能·机器人·巡检机器人