PyTorch模型构造实战:从基础到复杂组合

本文通过多个示例演示如何使用PyTorch构建不同类型的神经网络模型,涵盖基础多层感知机、自定义块、顺序块以及复杂组合模型。所有代码均附带输出结果,帮助读者直观理解模型结构。


1. 多层感知机(MLP)

使用nn.Sequential快速构建一个包含隐藏层和ReLU激活函数的简单MLP。

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F

# 定义模型
net = nn.Sequential(
    nn.Linear(in_features=20, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10)
)

# 随机输入(2个样本,每个样本20维)
x = torch.rand(2, 20)
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0424, -0.0431,  0.0191, -0.0467, -0.1238,  0.0123,  0.0224, -0.0914,
         -0.0271, -0.0883],
        [ 0.1497,  0.0056,  0.1736, -0.0222, -0.1749,  0.0234,  0.1242, -0.1502,
         -0.0490, -0.1498]], grad_fn=<AddmmBackward0>)

2. 自定义块(Custom Block)

通过继承nn.Module实现自定义模型,灵活定义前向传播逻辑。

python 复制代码
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # 隐藏层
        self.out = nn.Linear(256, 10)     # 输出层
    
    def forward(self, x):
        return self.out(F.relu(self.hidden(x)))

# 实例化并推理
net = Model1()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[-0.0344,  0.0446,  0.1053,  0.0658,  0.2332, -0.0105,  0.1963,  0.0181,
          0.1822, -0.1304],
        [-0.1953, -0.0464,  0.1120,  0.0082,  0.1906,  0.0503,  0.2968,  0.0132,
          0.2769, -0.1390]], grad_fn=<AddmmBackward0>)

3. 顺序块(Sequential Block)

nn.Sequential封装在自定义类中,简化模型定义。

python 复制代码
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.se(x)

# 实例化并推理
net = Model2()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.2166, -0.0262, -0.0240, -0.0165,  0.0695, -0.2495,  0.0699, -0.2297,
          0.0436, -0.0792],
        [ 0.2417,  0.0458, -0.0206,  0.0546,  0.0468, -0.3599,  0.1273, -0.2373,
          0.0020, -0.1880]], grad_fn=<AddmmBackward0>)

4. 动态操作的正向传播

在正向传播中执行矩阵运算和条件判断,展示灵活的自定义逻辑。

python 复制代码
class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((20, 20), requires_grad=False)  # 固定权重
        self.linear = nn.Linear(20, 20)
    
    def forward(self, x):
        x = self.linear(x)
        x = F.relu(torch.mm(x, self.rand_weight) + 1
        x = self.linear(x)
        while x.abs().sum() > 1:  # 动态调整张量大小
            x /= 2
        return x.sum()

# 实例化并推理
net = Model3()
output = net(x)
print(output)

输出结果

bash 复制代码
tensor(-0.1288, grad_fn=<SumBackward0>)

5. 混合组合模型

通过组合不同块构建复杂模型,实现层次化设计。

python 复制代码
class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(20, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU()
        )
        self.linear = nn.Linear(32, 16)
    
    def forward(self, x):
        return self.linear(self.se(x))

# 组合多个模型
net = nn.Sequential(
    Model4(),
    nn.Linear(16, 20),
    Model2()
)

# 推理
output = net(x)
print(output)

输出结果

bash 复制代码
tensor([[ 0.0220,  0.0221, -0.0445,  0.0760, -0.0317,  0.1331,  0.0716, -0.0102,
          0.0294, -0.0422],
        [ 0.0170,  0.0196, -0.0564,  0.0732, -0.0360,  0.1253,  0.0783, -0.0079,
          0.0283, -0.0448]], grad_fn=<AddmmBackward0>)

总结

  • nn.Sequential:适合快速堆叠层,适用于简单模型。

  • 自定义类 :通过继承nn.Module实现更灵活的前向传播逻辑。

  • 动态操作:可在正向传播中嵌入矩阵运算、循环等复杂操作。

  • 组合模型:通过混合不同块构建复杂网络,提升代码复用性。

完整代码已通过测试,建议结合实际任务调整模型结构和参数。欢迎在评论区讨论更多PyTorch技巧!


希望这篇文章能帮助你掌握PyTorch模型构造的核心方法!如果有其他问题,欢迎留言交流。

相关推荐
Raink老师4 小时前
【AI面试临阵磨枪】Harness 的环境隔离(沙箱)如何设计?文件、网络、命令、权限四层隔离?
人工智能·ai 面试
人工智能AI技术5 小时前
Python 断言 assert 基础用法
人工智能
清水白石0085 小时前
Python 编程实战全景:从基础语法到插件架构、异步性能与工程最佳实践
开发语言·python·架构
我是发哥哈5 小时前
横向评测:五款主流AI培训课程效果与选型分析
人工智能
GetcharZp5 小时前
告别昂贵显卡!llama.cpp 终极指南:在你的电脑上满速运行大模型!
人工智能
AI木马人6 小时前
3.【Prompt工程实战】如何设计一个可复用的Prompt系统?(避免每次手写提示词)
linux·服务器·人工智能·深度学习·prompt
lwf0061646 小时前
导数学习日记
学习·算法·机器学习
Agent产品评测局6 小时前
临床前同源性反应种属筛选:利用AI Agent加速筛选的实操方案 —— 2026企业级智能体选型与技术落地指南
人工智能·ai·chatgpt
yaoxin5211236 小时前
390. Java IO API - WatchDir 示例
java·前端·python
ting94520006 小时前
HunyuanOCR 全方位深度解析
人工智能·架构