【PyTorch】继承 nn.Module 创建简单神经网络

在面向对象编程(OOP)中,继承 是一种允许你创建一个新类的机制,新类可以继承已有类的特性(如方法和属性),并且可以对其进行修改或扩展。

nn.Module 是 PyTorch 所有神经网络模块的基类,几乎所有的神经网络层、模型和操作都要继承自它,这是为了确保模型能够正确地与 PyTorch 的自动求导机制、优化器和其他工具一起工作。

1. 方法作用

nn.Module 类提供了几个核心方法:

__init__(self)init () 方法是初始化模型的地方。在这个方法里,你定义所有网络的层 (如 nn.Linear、nn.Conv2d 等)、操作、权重等。也就是说,你在 init () 中定义了网络的结构。这个方法中,你需要调用 super().init() 来初始化父类的属性。

forward(self, *input):在这个方法里,你定义前向传播的过程,也就是输入数据如何经过不同的层处理得到输出,也就是前向传播过程。在 forward() 方法中,你指定了输入数据如何通过各个层和操作,最终得到输出。前向传播是模型的核心计算部分。

2. 简单例子

2.1 定义模型

假设你要自定义一个简单的全连接神经网络(MLP),你需要继承 nn.Module 来实现。这里是一个典型的结构:

复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()  # 继承父类 nn.Module 的初始化方法

        # 定义网络中的层
        self.fc1 = nn.Linear(2, 4)  # 第一层: 输入2维,输出4维
        self.fc2 = nn.Linear(4, 1)  # 第二层: 输入4维,输出1维

    def forward(self, x):
        # 定义前向传播过程
        x = torch.relu(self.fc1(x))  # 第一层通过ReLU激活函数
        x = self.fc2(x)  # 第二层没有激活函数,直接输出
        return x

# 实例化模型
model = SimpleNN()
# 查看模型结构
print(model)

解读一下代码:

SimpleNN 类继承了 nn.Module,这是所有 PyTorch 模型的基础类。继承 nn.Module 使得我们能够使用 PyTorch 自动求导、优化器以及模型参数管理等功能。

__init__(self)

init() 中,我们定义了网络的层次结构(即各层的类型和大小)。比如,self.fc1 = nn.Linear(2, 4) 表示一个全连接层(fully connected layer),它的输入维度是 2,输出维度是 4。

注意:你需要调用 super().init() 来初始化父类 nn.Module,这样才能让模型正确地处理参数和梯度。

forward(self, *input)

在 forward() 中,我们定义了数据的流动过程。数据通过 fc1 层(第一层)并通过 ReLU 激活函数,然后再经过 fc2 层(第二层)输出结果。

需要注意的是,forward() 方法是我们定义数据如何从输入经过各层计算到输出的地方。

model = SimpleNN()会创建一个 SimpleNN 类的实例,这时模型的层和参数都已经初始化好了。

因此,我们通过继承 nn.Module 来创建模型类,使得我们能够利用 PyTorch 提供的很多便利功能,如自动求导、模型参数管理等。

2.2 训练网络

一旦你定义了模型,就可以使用 PyTorch 的优化器(如 SGD、Adam 等)进行训练了。以下是一个训练的简单框架:

复制代码
# 定义损失函数和优化器
criterion = nn.MSELoss()  # 使用均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用SGD优化器,学习率为0.01

# 模拟输入数据和目标输出数据
inputs = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])  # 假设输入是2维数据
targets = torch.tensor([[1.0], [2.0], [3.0]])  # 假设目标输出是1维数据

# 训练过程
for epoch in range(100):  # 训练100个epoch
    optimizer.zero_grad()  # 每次迭代时,先将梯度清零
    
    # 前向传播
    outputs = model(inputs)  # 计算模型的预测值
    
    # 计算损失
    loss = criterion(outputs, targets)  # 计算损失值
    
    # 反向传播和优化
    loss.backward()  # 计算梯度
    optimizer.step()  # 更新模型参数

    if (epoch + 1) % 20 == 0:
        print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
相关推荐
Mintopia12 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮13 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬13 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia13 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区14 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两16 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪17 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat2325517 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
王鑫星17 小时前
SWE-bench 首次突破 80%:Claude Opus 4.5 发布,Anthropic 的野心不止于写代码
人工智能