Deep Learning Series: The ShuffleNet Network Architecture

To implement ShuffleNet in PyTorch (using the classic ShuffleNet v1 as our example), the key is to implement its two signature moves: grouped convolution and channel shuffle. We will start from the most basic building blocks and assemble the network step by step, so the purpose of every operation stays clear.
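
To see why grouped convolution is worth the trouble, here is a tiny sketch (the channel counts are arbitrary assumptions, not values from the network below) comparing the weight count of a dense 1×1 convolution with its 8-group counterpart:

```python
import torch.nn as nn

dense   = nn.Conv2d(256, 256, kernel_size=1, bias=False)            # 256×256   = 65,536 weights
grouped = nn.Conv2d(256, 256, kernel_size=1, groups=8, bias=False)  # 8×(32×32) = 8,192 weights
print(sum(p.numel() for p in dense.parameters()),
      sum(p.numel() for p in grouped.parameters()))  # 65536 8192
```

The grouped version is 8× cheaper, but each group only sees its own 32 input channels; channel shuffle exists precisely to fix that isolation.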

1. The Core Structure of ShuffleNet v1

The architecture of ShuffleNet v1 can be summarized as:

```
Input (224×224 color image) →
Initial conv layer → Initial pooling layer →
3 stages of ShuffleNet units (each stage: one stride-2 downsampling unit + several stride-1 feature-extraction units) →
Global average pooling → Fully connected layer (1000-class output)
```

The ShuffleNet unit is the core building block, and it comes in two flavors: stride = 1 (size-preserving) and stride = 2 (downsampling).

2. Implementing ShuffleNet v1 in PyTorch, Step by Step

Step 1: Import the necessary libraries

As with the other CNNs we have implemented, first get the tools ready:

```python
import torch                                  # core library
import torch.nn as nn                         # neural network layers
import torch.optim as optim                   # optimizers
from torch.utils.data import DataLoader       # data loader
from torchvision import datasets, transforms  # image datasets and preprocessing
```

Step 2: Implement the core operation, channel shuffle

Channel shuffle is ShuffleNet's signature innovation. It solves the "information isolation" problem of grouped convolution, where each group only ever sees its own channels. Let's implement this operation first:

```python
def channel_shuffle(x, groups):
    """
    Channel shuffle: re-mix the channels produced by a grouped convolution.
    x: input feature map of shape (batch_size, channels, height, width)
    groups: number of groups
    """
    batch_size, channels, height, width = x.size()

    # 1. The channel count must be divisible by the group count (a ShuffleNet design requirement)
    assert channels % groups == 0, "channels must be a multiple of groups"
    channels_per_group = channels // groups  # channels in each group

    # 2. The core shuffle steps:
    # split into (batch_size, groups, channels_per_group, height, width)
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # swap the groups and channels_per_group dims → (batch_size, channels_per_group, groups, height, width)
    x = x.transpose(1, 2).contiguous()
    # flatten the channel dims back → (batch_size, channels, height, width)
    x = x.view(batch_size, -1, height, width)

    return x
```

Plain-language explanation

Suppose the input has 8 groups (groups=8) of 32 channels each (256 channels in total). Channel shuffle turns "8 groups × 32 channels" into "32 groups × 8 channels", so every new group contains channels drawn from all 8 original groups, breaking the isolation between groups. The tiny demo below makes the permutation concrete.
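
As a quick sanity check (a minimal sketch using the channel_shuffle defined above, not part of the network), tag each channel of a toy tensor with its own index and watch how the two groups get interleaved:

```python
x = torch.arange(8).view(1, 8, 1, 1)     # channel i holds the value i
print(x.flatten().tolist())              # [0, 1, 2, 3, 4, 5, 6, 7] -> two groups: (0-3) and (4-7)
print(channel_shuffle(x, groups=2).flatten().tolist())
# [0, 4, 1, 5, 2, 6, 3, 7] -> every adjacent pair now mixes both original groups
```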

Step 3: Implement the core ShuffleNet units

ShuffleNet has two kinds of units: the stride = 1 unit (feature extraction, size preserved) and the stride = 2 unit (downsampling, size halved).

3.1 The stride = 1 ShuffleNet unit (feature extraction)
```python
class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups

        # The output channel count must be divisible by the group count
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels (dimension reduction to cut compute)

        # Main branch: 1×1 grouped conv → 3×3 depthwise conv → 1×1 grouped conv
        self.main_branch = nn.Sequential(
            # 1×1 grouped conv (reduce dimensions)
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=1,
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            # 3×3 depthwise conv (each channel convolved separately; stride controls downsampling)
            nn.Conv2d(
                mid_channels, mid_channels, kernel_size=3,
                stride=stride, padding=1, groups=mid_channels, bias=False
            ),
            nn.BatchNorm2d(mid_channels),

            # 1×1 grouped conv (restore dimensions)
            nn.Conv2d(
                mid_channels, out_channels, kernel_size=1,
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels),
        )

        # Activation (applied at the very end of the unit)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Note: the original paper shuffles after the first 1×1 grouped conv;
        # this tutorial's variant shuffles the unit input instead, with the same mixing intent.
        if self.stride == 1:
            # stride = 1: residual connection (input added to output),
            # which requires in_channels == out_channels
            x = channel_shuffle(x, self.groups)  # shuffle first, so the grouped convs see mixed channels
            out = self.main_branch(x)
            out += x  # residual connection: input + output
            return self.relu(out)
        # stride = 2 (downsampling) is handled by the dedicated unit in 3.2
        else:
            x = channel_shuffle(x, self.groups)
            out = self.main_branch(x)
            return out
```

3.2 The stride = 2 ShuffleNet unit (downsampling)

The stride = 2 unit must downsample (halve the spatial size), so it uses a two-branch design:

```python
class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups

        # The output channel count must be divisible by the group count
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels

        # Main branch: 1×1 grouped conv → 3×3 depthwise conv (stride 2) → 1×1 grouped conv
        self.main_branch = nn.Sequential(
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=1,
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            # stride = 2 performs the downsampling (spatial size halved)
            nn.Conv2d(
                mid_channels, mid_channels, kernel_size=3,
                stride=2, padding=1, groups=mid_channels, bias=False
            ),
            nn.BatchNorm2d(mid_channels),

            nn.Conv2d(
                mid_channels, out_channels - in_channels, kernel_size=1,  # main-branch channels = total - shortcut channels
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels - in_channels),
        )

        # Shortcut branch: average pooling (stride 2, downsampling)
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        # Activation
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Shuffle first
        x = channel_shuffle(x, self.groups)
        # Main branch
        main_out = self.main_branch(x)
        # Shortcut branch (downsampled)
        shortcut_out = self.shortcut_branch(x)
        # Concatenate the two branches (total channels = out_channels)
        out = torch.cat([main_out, shortcut_out], dim=1)
        return self.relu(out)
```

Key differences

  • stride = 1: residual connection (input + output); channel count unchanged;
  • stride = 2: two-branch concatenation (main branch + shortcut branch); the channel count increases (the main branch supplies the out_channels − in_channels new channels) while the spatial size is halved. A quick shape check of both units follows below.
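
To make the difference concrete, here is a hypothetical shape check (the channel numbers are illustrative assumptions, chosen to satisfy the divisibility constraints):

```python
x = torch.randn(1, 192, 28, 28)
unit = ShuffleNetUnitV1(192, 192, groups=8)      # stride = 1: shape preserved
down = ShuffleNetDownUnitV1(192, 384, groups=8)  # stride = 2: channels up, size halved
print(unit(x).shape)  # torch.Size([1, 192, 28, 28])
print(down(x).shape)  # torch.Size([1, 384, 14, 14])
```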

Step 4: Assemble the full ShuffleNet v1 network

Using the units defined above, build the complete network following the ShuffleNet v1 structure:

```python
class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups

        # Base channel counts (scaled by scale_factor to resize the model)
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]

        # 1. Initial conv layer
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),

            # 2. Initial pooling layer
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 3. Three stages of ShuffleNet units
        # Stage 1: one stride-2 downsampling unit + 3 stride-1 units
        # (the downsampling unit also raises the channels from 24 to 192; without it,
        # the stride-1 units' residual addition would fail on mismatched shapes)
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )

        # Stage 2: one stride-2 downsampling unit + 7 stride-1 units
        self.stage2 = self._make_stage(
            in_channels=base_channels[1],
            out_channels=base_channels[2],
            groups=groups,
            num_units=7,
            is_downsample=True
        )

        # Stage 3: one stride-2 downsampling unit + 3 stride-1 units
        self.stage3 = self._make_stage(
            in_channels=base_channels[2],
            out_channels=base_channels[3],
            groups=groups,
            num_units=3,
            is_downsample=True
        )

        # 4. Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # 5. Fully connected layer (class scores)
        self.classifier = nn.Linear(base_channels[3], num_classes)

    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        """Build one stage (several ShuffleNet units)."""
        stage = []

        # If the stage downsamples, start with a stride-2 unit
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels  # update the input channel count

        # Append num_units stride-1 units
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.features(x)         # initial conv and pooling
        x = self.stage1(x)           # stage 1
        x = self.stage2(x)           # stage 2
        x = self.stage3(x)           # stage 3
        x = self.global_pool(x)      # global pooling
        x = x.view(x.size(0), -1)    # flatten to a vector
        x = self.classifier(x)       # fully connected output
        return x
```

Architecture notes

  • groups: the number of groups in the grouped convolutions (typically 1, 2, 3, 4, or 8); more groups means less computation, at some cost in accuracy;
  • scale_factor: scales the channel counts (e.g. 0.5, 1.0, 1.5) to resize the model (0.5 is the lightweight version, 1.5 the higher-accuracy version); note that the scaled channel counts must remain divisible by groups;
  • stage design: each stage first downsamples with a stride-2 unit (halving the spatial size), then extracts features with several stride-1 units, gradually raising the level of abstraction. A quick end-to-end check follows below.
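
Before moving on, a minimal sketch to confirm the network runs end to end (the 224×224 input size is the one assumed throughout this post; the parameter count printed depends on the chosen configuration):

```python
model = ShuffleNetV1(num_classes=1000, groups=8, scale_factor=1.0)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.2f}M parameters')
```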

Step 5: Prepare the data (CIFAR-10 demo)

ShuffleNet targets mobile devices; we demonstrate with CIFAR-10 (10 classes), resizing the inputs to 224×224:

```python
# Preprocessing: resize + crop + flip + normalize
transform = transforms.Compose([
    transforms.Resize(256),              # resize to 256×256
    transforms.RandomCrop(224),          # random-crop to 224×224
    transforms.RandomHorizontalFlip(),   # random flip (data augmentation)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Batch the data (ShuffleNet is light, so batch_size can be fairly large)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
```
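
To verify the pipeline produces what the network expects, you can peek at one batch (a quick check, assuming the loaders defined above):

```python
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 224, 224]) torch.Size([128])
```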

Step 6: Initialize the model, loss function, and optimizer

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# groups=8, scale_factor=1.0, 10 output classes (CIFAR-10)
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)

criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# Optimizer: SGD with momentum is recommended
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
```

Step 7: Training and test functions

The training logic is the same as for our previous models:

```python
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()              # clear gradients
        output = model(data)               # forward pass
        loss = criterion(output, target)   # compute loss
        loss.backward()                    # backward pass
        optimizer.step()                   # update parameters

        # Print progress
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')
```

Step 8: Train and evaluate

ShuffleNet is very light and trains quickly; here we train for 20 epochs:

```python
for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)
```

On CIFAR-10, a well-trained ShuffleNet v1 (groups=8, scale_factor=1.0) can reach roughly 85% accuracy with only about 2.3 million parameters (roughly half the size of MobileNet v1, and about 1.7% of VGG-16).

3. Complete Code

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. Channel shuffle
def channel_shuffle(x, groups):
    batch_size, channels, height, width = x.size()
    assert channels % groups == 0, "channels must be a multiple of groups"
    channels_per_group = channels // groups

    # Core shuffle steps: split into groups → transpose → flatten
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    x = x.view(batch_size, -1, height, width)

    return x

# 2. Stride-1 ShuffleNet unit (feature extraction)
class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups

        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels

        self.main_branch = nn.Sequential(
            # 1×1 grouped conv (reduce dimensions)
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            # 3×3 depthwise conv
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),

            # 1×1 grouped conv (restore dimensions)
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if self.stride == 1:
            x = channel_shuffle(x, self.groups)  # channel shuffle
            out = self.main_branch(x)
            out += x  # residual connection (requires in_channels == out_channels)
            return self.relu(out)
        else:
            x = channel_shuffle(x, self.groups)
            out = self.main_branch(x)
            return out

# 3. Stride-2 ShuffleNet unit (downsampling)
class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups

        assert out_channels % groups == 0
        mid_channels = out_channels // 4

        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            # stride = 2, downsampling
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2,
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),

            nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels - in_channels),
        )

        # Shortcut branch: average-pooling downsampling
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = channel_shuffle(x, self.groups)
        main_out = self.main_branch(x)
        shortcut_out = self.shortcut_branch(x)
        out = torch.cat([main_out, shortcut_out], dim=1)  # concatenate branches
        return self.relu(out)

# 4. The full ShuffleNet v1 network
class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups

        # Base channel counts (scaled by scale_factor)
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]

        # Initial conv and pooling
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # The three stages (each starts with a stride-2 downsampling unit)
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )

        self.stage2 = self._make_stage(
            in_channels=base_channels[1],
            out_channels=base_channels[2],
            groups=groups,
            num_units=7,
            is_downsample=True
        )

        self.stage3 = self._make_stage(
            in_channels=base_channels[2],
            out_channels=base_channels[3],
            groups=groups,
            num_units=3,
            is_downsample=True
        )

        # Global pooling and classifier
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(base_channels[3], num_classes)

    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        stage = []
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.features(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# 5. CIFAR-10 data
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

# 6. Model, loss function, optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 7. Training function
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

# 8. Test function
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# 9. Train and evaluate
for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)
```

4. Key Takeaways

  1. Core operation: channel shuffle breaks the information isolation of grouped convolution in three steps (split into groups → transpose → flatten); it is the soul of ShuffleNet;
  2. Unit design: the stride-1 unit uses a residual connection (input + output), while the stride-2 unit concatenates two branches (main + shortcut), downsampling while keeping features flowing;
  3. Lightweight advantage: ShuffleNet v1 (groups=8) has only about 2.3M parameters and roughly 140 MFLOPs of compute, among the most resource-frugal models at its accuracy level;
  4. Flexible configuration: by tuning groups and scale_factor you can trade accuracy against efficiency to fit different mobile scenarios, as the sketch below shows.
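
As a final illustration (the specific values are assumptions, not prescribed configurations), comparing two settings shows how these knobs change the model's size:

```python
# Note: groups must still divide every (scaled) channel count, so not all combinations are valid.
for name, g, s in [('light', 4, 0.5), ('base', 8, 1.0)]:
    m = ShuffleNetV1(num_classes=10, groups=g, scale_factor=s)
    print(f'{name}: {sum(p.numel() for p in m.parameters()) / 1e6:.2f}M params')
```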

With this code you can build the "channel shuffle master" with your own hands and get a feel for how much lightweight CNNs can do on resource-constrained devices!
