PyTorch模型构造实战:从基础到复杂组合

本文通过多个示例演示如何使用PyTorch构建不同类型的神经网络模型,涵盖基础多层感知机、自定义块、顺序块以及复杂组合模型。所有代码均附带输出结果,帮助读者直观理解模型结构。


1. 多层感知机(MLP)

使用nn.Sequential快速构建一个包含隐藏层和ReLU激活函数的简单MLP。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

# 定义模型
net = nn.Sequential(
    nn.Linear(in_features=20, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10)
)

# 随机输入(2个样本,每个样本20维)
x = torch.rand(2, 20)
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0424, -0.0431,  0.0191, -0.0467, -0.1238,  0.0123,  0.0224, -0.0914,
         -0.0271, -0.0883],
        [ 0.1497,  0.0056,  0.1736, -0.0222, -0.1749,  0.0234,  0.1242, -0.1502,
         -0.0490, -0.1498]], grad_fn=<AddmmBackward0>)

2. 自定义块(Custom Block)

通过继承nn.Module实现自定义模型,灵活定义前向传播逻辑。

python 复制代码
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # 隐藏层
        self.out = nn.Linear(256, 10)     # 输出层
    
    def forward(self, x):
        return self.out(F.relu(self.hidden(x)))

# 实例化并推理
net = Model1()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[-0.0344,  0.0446,  0.1053,  0.0658,  0.2332, -0.0105,  0.1963,  0.0181,
          0.1822, -0.1304],
        [-0.1953, -0.0464,  0.1120,  0.0082,  0.1906,  0.0503,  0.2968,  0.0132,
          0.2769, -0.1390]], grad_fn=<AddmmBackward0>)

3. 顺序块(Sequential Block)

nn.Sequential封装在自定义类中,简化模型定义。

python 复制代码
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.se(x)

# 实例化并推理
net = Model2()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.2166, -0.0262, -0.0240, -0.0165,  0.0695, -0.2495,  0.0699, -0.2297,
          0.0436, -0.0792],
        [ 0.2417,  0.0458, -0.0206,  0.0546,  0.0468, -0.3599,  0.1273, -0.2373,
          0.0020, -0.1880]], grad_fn=<AddmmBackward0>)

4. 动态操作的正向传播

在正向传播中执行矩阵运算和条件判断,展示灵活的自定义逻辑。

python 复制代码
class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((20, 20), requires_grad=False)  # 固定权重
        self.linear = nn.Linear(20, 20)
    
    def forward(self, x):
        x = self.linear(x)
        x = F.relu(torch.mm(x, self.rand_weight) + 1
        x = self.linear(x)
        while x.abs().sum() > 1:  # 动态调整张量大小
            x /= 2
        return x.sum()

# 实例化并推理
net = Model3()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor(-0.1288, grad_fn=<SumBackward0>)

5. 混合组合模型

通过组合不同块构建复杂模型,实现层次化设计。

python 复制代码
class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU()
        )
        self.linear = nn.Linear(32, 16)
    
    def forward(self, x):
        return self.linear(self.se(x))

# 组合多个模型
net = nn.Sequential(
    Model4(),
    nn.Linear(16, 20),
    Model2()
)

# 推理
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0220,  0.0221, -0.0445,  0.0760, -0.0317,  0.1331,  0.0716, -0.0102,
          0.0294, -0.0422],
        [ 0.0170,  0.0196, -0.0564,  0.0732, -0.0360,  0.1253,  0.0783, -0.0079,
          0.0283, -0.0448]], grad_fn=<AddmmBackward0>)

总结

  • nn.Sequential:适合快速堆叠层,适用于简单模型。

  • 自定义类 :通过继承nn.Module实现更灵活的前向传播逻辑。

  • 动态操作:可在正向传播中嵌入矩阵运算、循环等复杂操作。

  • 组合模型:通过混合不同块构建复杂网络,提升代码复用性。

完整代码已通过测试,建议结合实际任务调整模型结构和参数。欢迎在评论区讨论更多PyTorch技巧!


希望这篇文章能帮助你掌握PyTorch模型构造的核心方法!如果有其他问题,欢迎留言交流。

相关推荐
AngelPP1 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年1 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
AI探索者1 小时前
LangGraph StateGraph 实战:状态机聊天机器人构建指南
python
AI探索者1 小时前
LangGraph 入门:构建带记忆功能的天气查询 Agent
python
九狼1 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS2 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区3 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈3 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
FishCoderh3 小时前
Python自动化办公实战:批量重命名文件,告别手动操作
python
躺平大鹅3 小时前
Python函数入门详解(定义+调用+参数)
python