| Component | Role |
| --- | --- |
| Layer | Basic network unit, e.g. convolutional (Conv2d) and linear (Linear) layers; transforms tensor data |
| Model | The whole assembled from multiple layers composed logically; implements the mapping from input to output |
| Loss function | Measures the gap between predictions and ground-truth values, e.g. cross-entropy (CrossEntropyLoss); the objective that parameter optimization targets |
| Optimizer | Updates model parameters via backpropagation to minimize the loss, e.g. Adam, SGD |
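A minimal sketch of how these four components interact in a single training step; the layer sizes and random data below are purely illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                                    # model: a single linear layer
loss_fn = nn.CrossEntropyLoss()                            # loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # optimizer

x = torch.randn(8, 4)              # dummy batch: 8 samples, 4 features each
y = torch.randint(0, 2, (8,))      # dummy integer class labels

pred = model(x)                    # forward pass through the model (layers)
loss = loss_fn(pred, y)            # measure prediction error
optimizer.zero_grad()              # clear stale gradients
loss.backward()                    # backpropagate
optimizer.step()                   # update parameters to reduce the loss
```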
PyTorch Model Building
- nn.Module: the "manager" of trainable parameters

Traits: every layer with learnable parameters (e.g. Conv2d, Linear) inherits from nn.Module, which tracks those parameters automatically and can be combined with the model containers.
Usage: a custom model subclasses nn.Module, defines its layers in `__init__`, and implements the forward-pass logic in `forward`.
Example: defining a simple linear-layer module
```python
import torch.nn as nn

class SimpleLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)  # learnable parameters managed by nn.Module
    def forward(self, x):
        return self.linear(x)
```
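Because nn.Module registers the layer's weight and bias automatically, they show up in parameters() and named_parameters() with no extra bookkeeping; a quick check (the 4-in, 2-out sizes are arbitrary):

```python
model = SimpleLinear(4, 2)
for name, p in model.named_parameters():
    print(name, tuple(p.shape))   # linear.weight (2, 4); linear.bias (2,)
```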
nn.functional: pure-function utilities
Traits: a collection of parameter-free "pure functions", such as activations (relu) and pooling (max_pool2d); any parameters must be passed in manually (if there are any), and they cannot be combined directly with the model containers.
Note: if dropout is implemented via nn.functional, you must distinguish training from evaluation mode yourself, whereas nn.Dropout (which inherits from nn.Module) switches state automatically via model.eval().
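To make the dropout caveat concrete, a minimal sketch contrasting the two styles (the class names here are just illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FunctionalDropout(nn.Module):
    def forward(self, x):
        # Must pass self.training by hand, or dropout stays active at eval time
        return F.dropout(x, p=0.5, training=self.training)

class ModuleDropout(nn.Module):
    def __init__(self):
        super().__init__()
        self.drop = nn.Dropout(p=0.5)   # tracks train/eval mode on its own
    def forward(self, x):
        return self.drop(x)

m = ModuleDropout()
m.eval()                        # nn.Dropout automatically becomes a no-op
print(m(torch.randn(2, 4)))
```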
Three Ways to Build a Model
1. Subclassing nn.Module directly: the most flexible
Suited to complex network structures where the connection logic of each layer is defined by hand. For example, a fully connected network with batch normalization:
```python
import torch.nn.functional as F

class FCModel(nn.Module):
    def __init__(self, in_dim=784, n_hidden=300, out_dim=10):
        super().__init__()
        self.flatten = nn.Flatten()                   # flatten the 28*28 image
        self.linear1 = nn.Linear(in_dim, n_hidden)
        self.bn1 = nn.BatchNorm1d(n_hidden)           # batch normalization
        self.linear2 = nn.Linear(n_hidden, out_dim)   # output layer, so out_dim is actually used
    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.bn1(self.linear1(x)))         # forward-pass logic
        return self.linear2(x)
```
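A quick forward pass to check the wiring, using a dummy batch of 28×28 inputs (note BatchNorm1d needs a batch size greater than 1 in training mode):

```python
import torch

model = FCModel()
out = model(torch.randn(32, 28, 28))   # flattened internally to 784 features
print(out.shape)                       # torch.Size([32, 10])
```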
2. nn.Sequential: stacked in order, quick and efficient
Suited to simple networks whose layers connect strictly in sequence. It supports three definition styles (a usage note follows the list):

- Positional arguments: pass the layer instances directly, no names needed

```python
seq = nn.Sequential(nn.Flatten(), nn.Linear(784, 300), nn.ReLU())
```

- add_module: give each layer a name, convenient for later inspection

```python
seq = nn.Sequential()
seq.add_module("flatten", nn.Flatten())
seq.add_module("linear1", nn.Linear(784, 300))
```

- OrderedDict: define with an ordered dict, combining order with naming

```python
from collections import OrderedDict
seq = nn.Sequential(OrderedDict([
    ("flatten", nn.Flatten()),
    ("linear1", nn.Linear(784, 300)),
]))
```
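Naming pays off when you reach back into the container: named layers are accessible as attributes, and any nn.Sequential can also be indexed by position (using the `seq` defined just above):

```python
print(seq.linear1)   # access by name:     Linear(in_features=784, out_features=300, bias=True)
print(seq[0])        # access by position: Flatten(start_dim=1, end_dim=-1)
```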
3. Hybrid: nn.Sequential inside a custom nn.Module
The two approaches combine naturally, e.g. using nn.Sequential to wrap the layers inside a residual block:

```python
class ResBlockWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # Wrap the layers inside the residual block with nn.Sequential
        self.res_block = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )
    def forward(self, x):
        return F.relu(x + self.res_block(x))
```
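Because the wrapped convolution keeps both the channel count (64→64) and, thanks to padding=1, the spatial size unchanged, the skip connection x + self.res_block(x) lines up; a quick check with a dummy feature map:

```python
import torch

block = ResBlockWrapper()
x = torch.randn(2, 64, 32, 32)   # batch 2, 64 channels, 32x32 feature map
print(block(x).shape)            # torch.Size([2, 64, 32, 32]): shape preserved
```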
From Custom Modules to ResNet18
1. Define the two residual block types
ResNet18 uses two kinds of residual blocks, one for the "dimensions unchanged" case and one for the "downsampled dimensions" case:
```python
class ResNetBasicBlock(nn.Module):
    # Basic residual block: input and output dimensions match, no extra adjustment needed
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(x + out)  # residual connection

class ResNetDownBlock(nn.Module):
    # Downsampling residual block: a 1x1 convolution adjusts the input to match the residual connection
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 convolution to adjust input channels and resolution
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride[0]),
            nn.BatchNorm2d(out_channels)
        )
    def forward(self, x):
        extra_x = self.extra(x)  # dimension adjustment
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(extra_x + out)
```
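A quick shape check on the downsampling block, with the stride pattern [2, 1] that ResNet18 uses below: channels go 64→128 and the spatial size halves on both the main path and the 1×1 shortcut, so the residual addition lines up:

```python
import torch

down = ResNetDownBlock(64, 128, [2, 1])
x = torch.randn(2, 64, 56, 56)
print(down(x).shape)   # torch.Size([2, 128, 28, 28])
```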
2. Assemble the ResNet18 Architecture
Using the two residual blocks, build ResNet18 in the order "initial convolution → residual layers → global pooling → fully connected layer", sized for 3-channel face images:
```python
class ResNet18(nn.Module):
    def __init__(self, num_classes):  # num_classes: number of face identities
        super().__init__()
        # Stem: initial convolution + downsampling
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        # 4 residual layers: combinations of basic blocks and downsampling blocks
        self.layer1 = nn.Sequential(ResNetBasicBlock(64, 64, 1), ResNetBasicBlock(64, 64, 1))
        self.layer2 = nn.Sequential(ResNetDownBlock(64, 128, [2, 1]), ResNetBasicBlock(128, 128, 1))
        self.layer3 = nn.Sequential(ResNetDownBlock(128, 256, [2, 1]), ResNetBasicBlock(256, 256, 1))
        self.layer4 = nn.Sequential(ResNetDownBlock(256, 512, [2, 1]), ResNetBasicBlock(512, 512, 1))
        # Classification head: global average pooling + fully connected layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
    def forward(self, x):
        # Forward pass: run the layers in order
        x = F.relu(self.bn1(self.conv1(x)))  # conv + BN + ReLU, the standard ResNet stem
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)  # flatten to a 1-D vector per sample
        return self.fc(x)
```
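An end-to-end sanity check with a dummy batch of 3-channel 224×224 images (the input size and the 10 classes are just illustrative):

```python
import torch

model = ResNet18(num_classes=10)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)   # torch.Size([2, 10])
```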
Summary
The heart of PyTorch model building is flexible composition: nn.Module manages the trainable parameters, containers such as nn.Sequential simplify layer wiring, and custom modules (such as residual blocks) combine the two into complex architectures. From a basic fully connected network up to ResNet18, the same few building blocks simply compose at a larger scale.