Implementing ShuffleNet in PyTorch (using the classic ShuffleNet v1 as the example) comes down to implementing its two signature operations: grouped convolution and channel shuffle. We will start from the most basic building blocks and assemble the network step by step, so the purpose of every operation stays clear.
Part 1: The Core Structure of ShuffleNet v1
The structure of ShuffleNet v1 can be summarized as:
Input (224×224 color image) →
Initial convolution layer → initial pooling layer →
3 stages of ShuffleNet units (each stage contains one stride-2 downsampling unit followed by several stride-1 feature-extraction units) →
Global average pooling → fully connected layer (1000-class output)
The ShuffleNet unit is the core component and comes in two flavors: "stride = 1" (spatial size preserved) and "stride = 2" (downsampling).
Part 2: Implementing ShuffleNet v1 in PyTorch, Step by Step
Step 1: Import the required libraries
As with the other CNNs we have implemented, first get the tools ready:
python
import torch                                   # core library
import torch.nn as nn                          # neural network layers
import torch.optim as optim                    # optimizers
from torch.utils.data import DataLoader        # data loader
from torchvision import datasets, transforms   # image datasets and preprocessing
Step 2: Implement the core operation, channel shuffle
Channel shuffle is ShuffleNet's signature innovation; it solves the "information isolation" problem of grouped convolution. Let's implement it first:
python
def channel_shuffle(x, groups):
    """
    Channel shuffle: re-mix the channels produced by a grouped convolution.
    x: input feature map of shape (batch_size, channels, height, width)
    groups: number of groups
    """
    batch_size, channels, height, width = x.size()
    # 1. The channel count must be divisible by the number of groups (a ShuffleNet design requirement)
    assert channels % groups == 0, "channels must be a multiple of groups"
    channels_per_group = channels // groups  # channels in each group
    # 2. The core shuffle steps:
    # reshape to (batch_size, groups, channels_per_group, height, width)
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # swap the groups and channels_per_group dims → (batch_size, channels_per_group, groups, height, width)
    x = x.transpose(1, 2).contiguous()
    # flatten the channel dims back → (batch_size, channels, height, width)
    x = x.view(batch_size, -1, height, width)
    return x
Intuition:
Suppose the input has 8 groups (groups=8) of 32 channels each (256 channels in total). Viewing the channels as an 8×32 grid, transposing it to 32×8, and flattening interleaves the channels, so every block of consecutive channels now draws from all 8 original groups, breaking the isolation between groups.
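A quick way to see the reordering is to run channel_shuffle on a tiny tensor whose channel values equal their channel indices. This is just a minimal sanity check (the tensor and group count are arbitrary illustrative values), not part of the network:
python
# Minimal sanity check for channel_shuffle (uses the function defined above)
x = torch.arange(8).view(1, 8, 1, 1)       # channels carry their own index: 0..7
shuffled = channel_shuffle(x, groups=2)
print(shuffled.view(-1).tolist())           # [0, 4, 1, 5, 2, 6, 3, 7]: the two groups are interleaved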
Step 3: Implement the core ShuffleNet units
ShuffleNet has two kinds of units: the stride-1 unit (feature extraction, spatial size unchanged) and the stride-2 unit (downsampling, spatial size halved).
3.1 The stride-1 ShuffleNet unit (feature extraction)
python
class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups
        # The output channel count must be a multiple of the number of groups
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels (reduce dimensionality to cut computation)
        # Main branch: 1×1 group conv → 3×3 depthwise conv → 1×1 group conv
        self.main_branch = nn.Sequential(
            # 1×1 group convolution (dimensionality reduction)
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=1,
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # 3×3 depthwise convolution (each channel convolved separately; stride controls downsampling)
            nn.Conv2d(
                mid_channels, mid_channels, kernel_size=3,
                stride=stride, padding=1, groups=mid_channels, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            # 1×1 group convolution (dimensionality expansion)
            nn.Conv2d(
                mid_channels, out_channels, kernel_size=1,
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels),
        )
        # Activation applied at the end of the unit
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        # With stride=1 we use a residual connection (input added to output),
        # which requires in_channels == out_channels
        if self.stride == 1:
            # Channel shuffle first (the key step: lets information flow between groups)
            x = channel_shuffle(x, self.groups)
            # Main branch
            out = self.main_branch(x)
            # Residual connection: input + output
            out += x
            return self.relu(out)
        # stride=2 (downsampling) is handled by the dedicated unit implemented below;
        # this fallback simply skips the residual connection
        else:
            # Channel shuffle first
            x = channel_shuffle(x, self.groups)
            # Main branch
            out = self.main_branch(x)
            return out
3.2 The stride-2 ShuffleNet unit (downsampling)
The stride-2 unit halves the spatial size, so it uses a two-branch design:
python
class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups
        # The output channel count must be a multiple of the number of groups
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels
        # Main branch: 1×1 group conv → 3×3 depthwise conv (stride 2) → 1×1 group conv
        self.main_branch = nn.Sequential(
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=1,
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # stride=2 performs the downsampling (spatial size halved)
            nn.Conv2d(
                mid_channels, mid_channels, kernel_size=3,
                stride=2, padding=1, groups=mid_channels, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(
                mid_channels, out_channels - in_channels, kernel_size=1,  # output channels = total - shortcut-branch channels
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels - in_channels),
        )
        # Shortcut branch: average pooling (stride 2, downsampling)
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        # Activation
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        # Channel shuffle first
        x = channel_shuffle(x, self.groups)
        # Main branch
        main_out = self.main_branch(x)
        # Shortcut branch (downsampling)
        shortcut_out = self.shortcut_branch(x)
        # Concatenate the two branches (main + shortcut, total channels = out_channels)
        out = torch.cat([main_out, shortcut_out], dim=1)
        return self.relu(out)
Key differences (a quick shape check follows below):
- stride = 1: residual connection (input + output); the channel count stays the same;
- stride = 2: two-branch concatenation (main branch + shortcut branch); the channel count grows to out_channels while the spatial size is halved.
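If you want to verify the shape behavior of the two units, a minimal sketch like the following works; the channel and group numbers here are arbitrary illustrative values:
python
# Minimal shape check for the two units (uses the classes defined above)
x = torch.randn(2, 32, 56, 56)                  # dummy batch: 32 channels, 56×56
unit = ShuffleNetUnitV1(32, 32, groups=4)       # stride-1 unit: channels unchanged
down = ShuffleNetDownUnitV1(32, 64, groups=4)   # stride-2 unit: 32 → 64 channels, size halved
print(unit(x).shape)    # torch.Size([2, 32, 56, 56])
print(down(x).shape)    # torch.Size([2, 64, 28, 28])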
Step 4: Assemble the complete ShuffleNet v1 network
Using the units defined above, build the full network following the ShuffleNet v1 structure:
python
class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups
        # Base channel widths (scaled by scale_factor to adjust model size)
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]
        # 1. Initial convolution layer
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),
            # 2. Initial pooling layer
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # 3. Three stages of ShuffleNet units
        # Stage 1: one stride-2 downsampling unit + 3 stride-1 units
        # (the downsampling unit is required here: it raises the channel count from 24 to
        #  base_channels[1], so the stride-1 units' residual additions have matching shapes)
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        # Stage 2: one stride-2 downsampling unit + 7 stride-1 units
        self.stage2 = self._make_stage(
            in_channels=base_channels[1],
            out_channels=base_channels[2],
            groups=groups,
            num_units=7,
            is_downsample=True
        )
        # Stage 3: one stride-2 downsampling unit + 3 stride-1 units
        self.stage3 = self._make_stage(
            in_channels=base_channels[2],
            out_channels=base_channels[3],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        # 4. Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # 5. Fully connected layer (class scores)
        self.classifier = nn.Linear(base_channels[3], num_classes)
    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        """Build one stage (a stack of ShuffleNet units)"""
        stage = []
        # If the stage downsamples, start with a stride-2 unit
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels  # update the input channel count
        # Append num_units stride-1 units
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
        return nn.Sequential(*stage)
    def forward(self, x):
        x = self.features(x)            # initial convolution and pooling
        x = self.stage1(x)              # stage 1
        x = self.stage2(x)              # stage 2
        x = self.stage3(x)              # stage 3
        x = self.global_pool(x)         # global pooling
        x = x.view(x.size(0), -1)       # flatten to a vector
        x = self.classifier(x)          # fully connected output
        return x
Notes on the structure (a quick forward-pass check follows below):
- Number of groups (groups): controls how many groups the grouped convolutions use (typical values: 1, 2, 3, 4, 8). More groups means less computation, but accuracy has to be balanced against it;
- Scale factor (scale_factor): scales the channel widths (e.g., 0.5, 1.0, 1.5) to adjust model size (0.5 is the lightweight version, 1.5 the higher-accuracy version);
- Stage design: every stage first downsamples with a stride-2 unit (halving the spatial size; in stage 1 this also raises the channel count from 24 to base_channels[1]), then extracts features with several stride-1 units, gradually increasing the level of abstraction.
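To confirm the assembled network runs end to end, here is a minimal smoke test with a random 224×224 batch; the variable names and batch size are illustrative only, and the printed shape should be (batch, num_classes):
python
# Minimal smoke test for the assembled network (uses ShuffleNetV1 defined above)
net = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0)
dummy = torch.randn(2, 3, 224, 224)      # a fake batch of two RGB images
with torch.no_grad():
    logits = net(dummy)
print(logits.shape)                       # expected: torch.Size([2, 10])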
Step 5: Prepare the data (CIFAR-10 demo)
ShuffleNet targets mobile devices; here we demo with CIFAR-10 (10 classes), resizing the inputs to 224×224:
python
# Preprocessing: resize + crop + flip + normalize
transform = transforms.Compose([
    transforms.Resize(256),                 # resize to 256×256
    transforms.RandomCrop(224),             # random crop to 224×224
    transforms.RandomHorizontalFlip(),      # random horizontal flip (data augmentation)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])
# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
# Batch loading (ShuffleNet is light, so the batch size can be fairly large)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
Step 6: Initialize the model, loss function, and optimizer
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# groups=8, scale_factor=1.0, 10 output classes (CIFAR-10)
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)
criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# Optimizer: SGD with momentum is the usual choice
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
Step 7: Training and test functions
The training loop is the same as for the previous models:
python
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()              # clear gradients
        output = model(data)               # forward pass
        loss = criterion(output, target)   # compute the loss
        loss.backward()                    # backpropagation
        optimizer.step()                   # update parameters
        # Print progress
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
Step 8: Train and evaluate
ShuffleNet is very lightweight and trains quickly; here we train for 20 epochs:
python
for epoch in range(1, 21):
train(model, train_loader, criterion, optimizer, epoch)
test(model, test_loader)
On CIFAR-10, this ShuffleNet v1 (groups=8, scale_factor=1.0) typically reaches around 85% accuracy when trained long enough. For reference, the original ShuffleNet v1 1× has a parameter count of only a couple of million (roughly half of MobileNet v1's ~4.2M and under 2% of VGG-16's ~138M); the channel widths used in this tutorial are narrower than the paper's g=8 configuration, so the model here is lighter still.
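If you want to measure the size of the model you just built, rather than relying on quoted figures, summing over model.parameters() is enough; this is a quick measurement sketch, not part of training:
python
# Count the trainable parameters of the model built in Step 6
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {num_params / 1e6:.2f}M')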
Part 3: Complete Code Summary
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. Channel shuffle operation
def channel_shuffle(x, groups):
batch_size, channels, height, width = x.size()
    assert channels % groups == 0, "channels must be a multiple of groups"
channels_per_group = channels // groups
    # Core shuffle steps: group → transpose → flatten
x = x.view(batch_size, groups, channels_per_group, height, width)
x = x.transpose(1, 2).contiguous()
x = x.view(batch_size, -1, height, width)
return x
# 2. Stride-1 ShuffleNet unit (feature extraction)
class ShuffleNetUnitV1(nn.Module):
def __init__(self, in_channels, out_channels, groups, stride=1):
super(ShuffleNetUnitV1, self).__init__()
self.stride = stride
self.groups = groups
assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels (dimensionality reduction)
self.main_branch = nn.Sequential(
            # 1×1 group convolution (dimensionality reduction)
nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
            # 3×3 depthwise convolution
nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
padding=1, groups=mid_channels, bias=False),
nn.BatchNorm2d(mid_channels),
            # 1×1 group convolution (dimensionality expansion)
nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),
nn.BatchNorm2d(out_channels),
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
if self.stride == 1:
            x = channel_shuffle(x, self.groups)  # channel shuffle
out = self.main_branch(x)
            out += x  # residual connection (requires in_channels == out_channels)
return self.relu(out)
else:
x = channel_shuffle(x, self.groups)
out = self.main_branch(x)
return out
# 3. Stride-2 ShuffleNet unit (downsampling)
class ShuffleNetDownUnitV1(nn.Module):
def __init__(self, in_channels, out_channels, groups):
super(ShuffleNetDownUnitV1, self).__init__()
self.groups = groups
assert out_channels % groups == 0
mid_channels = out_channels // 4
self.main_branch = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
            # stride=2: downsampling
nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2,
padding=1, groups=mid_channels, bias=False),
nn.BatchNorm2d(mid_channels),
nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1,
groups=groups, bias=False),
nn.BatchNorm2d(out_channels - in_channels),
)
        # Shortcut branch: average-pooling downsampling
self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = channel_shuffle(x, self.groups)
main_out = self.main_branch(x)
shortcut_out = self.shortcut_branch(x)
        out = torch.cat([main_out, shortcut_out], dim=1)  # concatenate the branches
return self.relu(out)
# 4. The full ShuffleNet v1 network
class ShuffleNetV1(nn.Module):
def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
super(ShuffleNetV1, self).__init__()
self.groups = groups
        # Base channel widths (scaled by scale_factor)
base_channels = [24, 192, 384, 768]
base_channels = [int(c * scale_factor) for c in base_channels]
        # Initial convolution and pooling
self.features = nn.Sequential(
nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(base_channels[0]),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
        # Three stages of units (each begins with a stride-2 downsampling unit,
        # which also matches the channel counts for the following residual units)
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
self.stage2 = self._make_stage(
in_channels=base_channels[1],
out_channels=base_channels[2],
groups=groups,
num_units=7,
is_downsample=True
)
self.stage3 = self._make_stage(
in_channels=base_channels[2],
out_channels=base_channels[3],
groups=groups,
num_units=3,
is_downsample=True
)
        # Global pooling and classifier
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(base_channels[3], num_classes)
def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
stage = []
if is_downsample:
stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
in_channels = out_channels
for _ in range(num_units):
stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
return nn.Sequential(*stage)
def forward(self, x):
x = self.features(x)
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.global_pool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# 5. Prepare the CIFAR-10 data
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
# 6. Initialize the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# 7. Training function
def train(model, train_loader, criterion, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')
# 8. Test function
def test(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
print(f'Test Accuracy: {100 * correct / total:.2f}%')
# 9. Train and evaluate
for epoch in range(1, 21):
train(model, train_loader, criterion, optimizer, epoch)
test(model, test_loader)
Part 4: Key Takeaways
- Core operation: channel shuffle breaks the information isolation of grouped convolution via three steps (group → transpose → flatten); it is the soul of ShuffleNet;
- Unit design: the stride-1 unit uses a residual connection (input + output), while the stride-2 unit concatenates two branches (main + shortcut), downsampling while keeping features flowing;
- Lightweight advantage: the original ShuffleNet v1 1× runs at roughly 140 MFLOPs with only a few million parameters, making it one of the most resource-frugal models at its accuracy level;
- Flexible configuration: by adjusting groups (the number of groups) and scale_factor (the channel scale), you can trade accuracy against efficiency for different mobile scenarios; a small sketch of this comparison follows below.
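As a rough way to explore that trade-off, you can instantiate a few configurations and compare their parameter counts. The (groups, scale_factor) pairs below are illustrative choices that keep every channel width divisible by the group count (not every pair is valid), and the numbers depend on this tutorial's channel widths rather than the paper's:
python
# Compare a few (groups, scale_factor) configurations by parameter count.
# Note: scaled channel widths must stay divisible by groups, so not every pair works.
for g, s in [(4, 0.5), (8, 1.0), (8, 2.0)]:
    m = ShuffleNetV1(num_classes=10, groups=g, scale_factor=s)
    n = sum(p.numel() for p in m.parameters())
    print(f'groups={g}, scale_factor={s}: {n / 1e6:.2f}M parameters')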
With this code you can build this "channel-shuffle master" yourself and get a feel for how much potential lightweight CNNs have on resource-constrained devices!