从零构建神经网络:PyTorch的nn.Module详解

**一、为什么需要nn.Module?**‌

nn.Module 是 PyTorch 中所有神经网络模型的基类,它提供了以下核心功能:

  • 模块化设计‌:将网络拆分为可重用的层(如卷积层、全连接层)
  • 自动参数管理 ‌:自动跟踪所有可训练参数(parameters()方法)
  • GPU加速支持‌:一键将模型迁移到GPU
  • 模型保存与加载‌:支持序列化模型结构和参数

二、定义你的第一个神经网络

1. 继承nn.Module构建网络

以下代码实现一个3层全连接网络,用于图像分类(以FashionMNIST为例):

python 复制代码
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 设置随机种子保证可重复性
torch.manual_seed(42)

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入层 -> 隐藏层
        self.relu = nn.ReLU()                          # 激活函数
        self.fc2 = nn.Linear(hidden_size, num_classes) # 隐藏层 -> 输出层
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 初始化模型
model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)
print(model)

输出结果:

python 复制代码
SimpleNN(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

三、数据准备与预处理

1. 加载FashionMNIST数据集

python 复制代码
# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),                      # 转换为Tensor
    transforms.Normalize((0.5,), (0.5,))        # 归一化到[-1, 1]
])

# 下载数据集
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data', 
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    transform=transform
)

# 创建数据加载器
batch_size = 100
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False
)

# 查看数据集信息
print("训练集大小:", len(train_dataset))
print("测试集大小:", len(test_dataset))
classes = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

四、训练神经网络

1. 配置损失函数与优化器

python 复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()  # 交叉熵损失(已包含Softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

2. 训练循环实现

python 复制代码
num_epochs = 10
total_step = len(train_loader)
loss_history = []
acc_history = []

for epoch in range(num_epochs):
    model.train()  # 设置为训练模式(启用Dropout/BatchNorm)
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, (images, labels) in enumerate(train_loader):
        # 将数据移动到设备
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计指标
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    # 计算本epoch指标
    epoch_loss = running_loss / total_step
    epoch_acc = 100 * correct / total
    loss_history.append(epoch_loss)
    acc_history.append(epoch_acc)
    
    # 打印训练进度
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Loss: {epoch_loss:.4f}, '
          f'Accuracy: {epoch_acc:.2f}%')

# 可视化训练过程
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(loss_history, label='Training Loss')
plt.title('Loss Curve')
plt.subplot(1,2,2)
plt.plot(acc_history, label='Training Accuracy')
plt.title('Accuracy Curve')
plt.show()

五、模型评估与预测

1. 测试集评估

python 复制代码
model.eval()  # 设置为评估模式(关闭Dropout/BatchNorm)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'测试集准确率: {100 * correct / total:.2f}%')

2. 可视化预测结果

python 复制代码
# 获取测试集样本
dataiter = iter(test_loader)
images, labels = next(dataiter)
images = images.reshape(-1, 28*28).to(device)

# 预测结果
outputs = model(images)
_, preds = torch.max(outputs, 1)
preds = preds.cpu().numpy()
images = images.cpu().reshape(-1, 28, 28).numpy()

# 绘制预测结果
plt.figure(figsize=(10,8))
for i in range(20):
    plt.subplot(4,5,i+1)
    plt.imshow(images[i], cmap='gray')
    plt.title(f"Pred: {classes[preds[i]]}\nTrue: {classes[labels[i]]}")
    plt.axis('off')
plt.tight_layout()
plt.show()

六、模型保存与加载

1. 保存整个模型

python 复制代码
torch.save(model, 'fashion_mnist_model.pth')
loaded_model = torch.load('fashion_mnist_model.pth')

2. 只保存参数(推荐方式)

python 复制代码
torch.save(model.state_dict(), 'model_weights.pth')

# 加载时需要先创建相同结构的模型
new_model = SimpleNN(784, 128, 10).to(device)
new_model.load_state_dict(torch.load('model_weights.pth'))

七、常见问题与调试

Q1:输入形状不匹配报错

  • 错误信息:RuntimeError: mat1 and mat2 shapes cannot be multiplied
  • 解决方法:检查输入是否被正确展平,使用x = x.view(-1, input_size)

Q2:模型准确率始终不变

  • 检查是否忘记调用optimizer.zero_grad()
  • 确认参数requires_grad=True(使用nn.Module时会自动处理)

Q3:过拟合问题

  • 添加正则化:在优化器中设置weight_decay=0.01
  • 添加Dropout层:
python 复制代码
self.dropout = nn.Dropout(p=0.5)  # 在__init__中添加
out = self.dropout(out)           # 在forward中添加

八、小结与下篇预告

  • 关键知识点‌:

    1. nn.Module 提供标准化的模型构建方式
    2. 训练流程四要素:数据加载、前向传播、损失计算、反向传播
    3. 模型评估必须使用model.eval()模式
  • 下篇预告 ‌:

    第四篇将实战MNIST手写数字识别,并深入解析数据增强与模型调优技巧!

相关推荐
文心快码BaiduComate16 小时前
百度云与光本位签署战略合作:用AI Agent 重构芯片研发流程
前端·人工智能·架构
风象南17 小时前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
Mintopia18 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮18 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬18 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia19 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区19 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两1 天前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪1 天前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain