Deep Learning Series: The ShuffleNet Architecture

To implement ShuffleNet in PyTorch (using the classic ShuffleNet v1 as the example), the core task is implementing its two signature moves: grouped convolution and channel shuffle. We'll start from the most basic building blocks and assemble the network step by step, so you understand what each operation does.

I. The Core Structure of ShuffleNet v1

The structure of ShuffleNet v1 can be summarized as:

Input (224×224 RGB image) →
initial conv layer → initial pooling layer →
3 stages of ShuffleNet units (each stage: one stride-2 downsampling unit + several stride-1 feature-extraction units) →
global average pooling → fully connected layer (1000-class output)

The ShuffleNet unit is the heart of the network and comes in two flavors: stride = 1 (keeps the spatial size) and stride = 2 (downsamples).

II. Implementing ShuffleNet v1 in PyTorch, Step by Step

Step 1: Import the necessary libraries

As with the other CNNs we've implemented, first get the tools ready:

python
import torch  # core library
import torch.nn as nn  # neural network layers
import torch.optim as optim  # optimizers
from torch.utils.data import DataLoader  # data loader
from torchvision import datasets, transforms  # image data processing

Step 2: Implement the core operation, channel shuffle

Channel shuffle is ShuffleNet's signature innovation; it solves the "information isolation" problem of grouped convolution. Let's implement it first:

python
def channel_shuffle(x, groups):
    """
    Channel shuffle: re-interleave the channels produced by a grouped convolution.
    x: input feature map of shape (batch_size, channels, height, width)
    groups: number of groups
    """
    batch_size, channels, height, width = x.size()
    
    # 1. The channel count must be divisible by the group count (a ShuffleNet design requirement)
    assert channels % groups == 0, "channels must be a multiple of groups"
    channels_per_group = channels // groups  # channels in each group
    
    # 2. The core shuffle steps:
    # reshape to (batch_size, groups, channels_per_group, height, width)
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # swap the groups and channels_per_group dims → (batch_size, channels_per_group, groups, height, width)
    x = x.transpose(1, 2).contiguous()
    # flatten the channel dims back → (batch_size, channels, height, width)
    x = x.view(batch_size, -1, height, width)
    
    return x

Plain-language explanation

Suppose the input has 8 groups (groups=8) of 32 channels each (256 channels in total). Channel shuffle turns "8 groups × 32 channels" into "32 groups × 8 channels": every new group contains one channel from each of the original 8 groups, breaking the isolation between groups.
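
A quick sanity check (a throwaway snippet, not part of the network): label each channel with its own index and watch the shuffle reorder them.

python
# 1 sample, 6 channels, 1×1 spatial, 3 groups; channel c holds the constant value c
x = torch.arange(6, dtype=torch.float32).view(1, 6, 1, 1)
print(x.flatten().tolist())                              # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(channel_shuffle(x, groups=3).flatten().tolist())   # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]

The output interleaves one channel from each original group, which is exactly the mixing described above.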

Step 3: Implement the core ShuffleNet units

ShuffleNet has two kinds of units: the stride = 1 unit (feature extraction, size unchanged) and the stride = 2 unit (downsampling, size halved).

3.1 The stride = 1 ShuffleNet unit (feature extraction)
python
class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups
        
        # The output channel count must be divisible by the group count
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels (reduce dimensionality to cut compute)
        
        # Main branch: 1×1 grouped conv → 3×3 depthwise conv → 1×1 grouped conv
        self.main_branch = nn.Sequential(
            # 1×1 grouped conv (reduce channels)
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=1, 
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            
            # 3×3 depthwise conv (each channel convolved separately; stride controls downsampling)
            nn.Conv2d(
                mid_channels, mid_channels, kernel_size=3, 
                stride=stride, padding=1, groups=mid_channels, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            
            # 1×1 grouped conv (restore channels)
            nn.Conv2d(
                mid_channels, out_channels, kernel_size=1, 
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels),
        )
        
        # Activation (applied at the end of the unit)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        # With stride=1, use a residual connection (input added to output).
        # Note: this implementation shuffles the unit's input; the original paper
        # places the shuffle between the first 1×1 grouped conv and the depthwise
        # conv. Both variants mix information across groups.
        if self.stride == 1:
            # Channel shuffle first (the key step: lets information flow across groups)
            x = channel_shuffle(x, self.groups)
            # Main branch
            out = self.main_branch(x)
            # Residual connection: input + output (requires in_channels == out_channels)
            out += x
            return self.relu(out)
        # stride=2 fallback (no shortcut); downsampling is normally handled by
        # the dedicated unit implemented below
        else:
            x = channel_shuffle(x, self.groups)
            out = self.main_branch(x)
            return out
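
A quick shape check (the numbers here are illustrative, not from the original text) confirms the unit behaves as claimed. Note the residual add only works when in_channels == out_channels, which is why each stage keeps the channel count fixed across its stride-1 units.

python
unit = ShuffleNetUnitV1(in_channels=192, out_channels=192, groups=8)
x = torch.randn(2, 192, 28, 28)   # (batch, channels, height, width)
print(unit(x).shape)              # torch.Size([2, 192, 28, 28]) -- spatial size preserved
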
3.2 The stride = 2 ShuffleNet unit (downsampling)

The stride = 2 unit must downsample (halve the spatial size), so it uses a two-branch design:

python
class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups
        
        # The output channel count must be divisible by the group count
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels
        
        # Main branch: 1×1 grouped conv → 3×3 depthwise conv (stride 2) → 1×1 grouped conv
        self.main_branch = nn.Sequential(
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=1, 
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            
            # stride=2 performs the downsampling (halves the size)
            nn.Conv2d(
                mid_channels, mid_channels, kernel_size=3, 
                stride=2, padding=1, groups=mid_channels, bias=False
            ),
            nn.BatchNorm2d(mid_channels),
            
            nn.Conv2d(
                mid_channels, out_channels - in_channels, kernel_size=1,  # main-branch channels = total - shortcut-branch channels
                groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels - in_channels),
        )
        
        # Shortcut branch: average pooling (stride 2, downsampling)
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        
        # Activation
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        # Channel shuffle first
        x = channel_shuffle(x, self.groups)
        # Main branch
        main_out = self.main_branch(x)
        # Shortcut branch (downsampled)
        shortcut_out = self.shortcut_branch(x)
        # Concatenate the two branches (main + shortcut = out_channels in total)
        out = torch.cat([main_out, shortcut_out], dim=1)
        return self.relu(out)

Key differences

  • stride = 1: residual connection (input + output), channel count unchanged;
  • stride = 2: two-branch concatenation (main branch + shortcut branch), channel count grows (here it doubles from stage to stage) while the spatial size halves; see the shape check below.
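
The same kind of throwaway check (illustrative numbers again) confirms the downsampling behavior: the spatial size halves and the two branches concatenate to out_channels.

python
down = ShuffleNetDownUnitV1(in_channels=192, out_channels=384, groups=8)
x = torch.randn(2, 192, 28, 28)
print(down(x).shape)   # torch.Size([2, 384, 14, 14]) -- main 192 + shortcut 192 channels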

Step 4: Assemble the complete ShuffleNet v1 network

Using the units defined above, build the full network following the ShuffleNet v1 structure:

python
class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups
        
        # Base channel counts (scaled by scale_factor to resize the model).
        # These are simplified demo values; the paper's g=8 configuration
        # uses 24/384/768/1536.
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]
        
        # 1. Initial conv layer
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),
            
            # 2. Initial pooling layer
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # 3. Three stages of ShuffleNet units. Every stage leads with a stride-2
        # downsampling unit, so the stride-1 units that follow see
        # in_channels == out_channels (which their residual add requires).
        # Stage 1: 1 stride-2 downsampling unit + 3 stride-1 units
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        
        # Stage 2: 1 stride-2 downsampling unit + 7 stride-1 units
        self.stage2 = self._make_stage(
            in_channels=base_channels[1],
            out_channels=base_channels[2],
            groups=groups,
            num_units=7,
            is_downsample=True
        )
        
        # Stage 3: 1 stride-2 downsampling unit + 3 stride-1 units
        self.stage3 = self._make_stage(
            in_channels=base_channels[2],
            out_channels=base_channels[3],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        
        # 4. Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # 5. Fully connected layer (class scores)
        self.classifier = nn.Linear(base_channels[3], num_classes)
    
    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        """Build one stage (several ShuffleNet units)."""
        stage = []
        
        # If downsampling is requested, lead with a stride-2 unit
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels  # update the input channel count
        
        # Append num_units stride-1 units
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
        
        return nn.Sequential(*stage)
    
    def forward(self, x):
        x = self.features(x)         # initial conv and pooling
        x = self.stage1(x)           # stage 1
        x = self.stage2(x)           # stage 2
        x = self.stage3(x)           # stage 3
        x = self.global_pool(x)      # global pooling
        x = x.view(x.size(0), -1)    # flatten to a vector
        x = self.classifier(x)       # fully connected output
        return x

Structure notes

  • Group count (groups): the number of groups in the grouped convolutions (1, 2, 3, 4, or 8); more groups means less computation, balanced against some accuracy loss;
  • Scale factor (scale_factor): scales the channel counts (e.g. 0.5, 1.0, 1.5) to resize the model (0.5 is the lightweight version, 1.5 the higher-accuracy one);
  • Stage design: each stage first downsamples through a stride = 2 unit (halving the size), then extracts features through several stride = 1 units, progressively raising the level of abstraction. A quick dry run (below) confirms the wiring.
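
Before training, a dry run is a cheap way to verify the shapes and see the model size (a quick sketch; the exact parameter count depends on groups and scale_factor):

python
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0)
dummy = torch.randn(1, 3, 224, 224)
print(model(dummy).shape)   # torch.Size([1, 10])
num_params = sum(p.numel() for p in model.parameters())
print(f'{num_params / 1e6:.2f}M parameters')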

Step 5: Prepare the data (CIFAR-10 demo)

ShuffleNet targets mobile devices; we'll demonstrate on CIFAR-10 (10 classes), resizing the inputs to 224×224:

python
# Preprocessing: resize + crop + flip + normalize
transform = transforms.Compose([
    transforms.Resize(256),  # resize to 256×256
    transforms.RandomCrop(224),  # random-crop to 224×224
    transforms.RandomHorizontalFlip(),  # random flip (data augmentation)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Load the CIFAR-10 dataset
# (Sharing the augmented transform with the test set keeps the demo short;
# in practice you'd evaluate with a deterministic Resize + CenterCrop pipeline.)
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Batch loading (ShuffleNet is light, so batch_size can be fairly large)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
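
An optional one-off check that the loaders produce what the model expects:

python
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # torch.Size([128, 3, 224, 224]) torch.Size([128])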

Step 6: Initialize the model, loss function, and optimizer

python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# groups=8, scale_factor=1.0, 10 output classes (CIFAR-10)
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)

criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# Optimizer: SGD with momentum is recommended
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
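
For longer runs, a learning-rate schedule usually helps. One common (entirely optional) choice is step decay; the step size and factor below are placeholders to tune, not part of the original recipe:

python
# Optional: decay the LR by 10× every 10 epochs (hypothetical schedule)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# call scheduler.step() once per epoch, after that epoch's training loop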

Step 7: Training and test functions

The training logic mirrors our earlier models:

python
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()  # clear gradients
        output = model(data)   # forward pass
        loss = criterion(output, target)  # compute the loss
        loss.backward()        # backpropagate
        optimizer.step()       # update parameters
        
        # Print progress
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

Step 8: Train and test

ShuffleNet is very light and trains quickly; here we train for 20 epochs:

python
for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)

On CIFAR-10, ShuffleNet v1 (groups=8, scale_factor=1.0) can reach roughly 85% accuracy when trained to convergence, with only about 2.3 million parameters (roughly 50% of MobileNet v1 and 1.7% of VGG-16). Note that the parameter figure refers to the full-size ImageNet model; the simplified channel configuration used here is smaller still.
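
If you want to keep the best weights from the run, a common pattern (not part of the loop above) is to have test() return the accuracy and save a checkpoint when it improves; the file name is arbitrary:

python
best_acc = 0.0
for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    acc = test(model, test_loader)   # assumes test() is modified to return the accuracy
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), 'shufflenet_v1_best.pth')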

III. Complete Code

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. The channel shuffle operation
def channel_shuffle(x, groups):
    batch_size, channels, height, width = x.size()
    assert channels % groups == 0, "channels must be a multiple of groups"
    channels_per_group = channels // groups
    
    # Core shuffle steps: group → transpose → flatten
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    x = x.view(batch_size, -1, height, width)
    
    return x

# 2. The stride=1 ShuffleNet unit (feature extraction)
class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups
        
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels
        
        self.main_branch = nn.Sequential(
            # 1×1 grouped conv (reduce channels)
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            
            # 3×3 depthwise conv
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, 
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            
            # 1×1 grouped conv (restore channels)
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        if self.stride == 1:
            x = channel_shuffle(x, self.groups)  # channel shuffle
            out = self.main_branch(x)
            out += x  # residual connection (requires in_channels == out_channels)
            return self.relu(out)
        else:
            x = channel_shuffle(x, self.groups)
            out = self.main_branch(x)
            return out

# 3. The stride=2 ShuffleNet unit (downsampling)
class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups
        
        assert out_channels % groups == 0
        mid_channels = out_channels // 4
        
        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            
            # stride=2 performs the downsampling
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2, 
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            
            nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1, 
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels - in_channels),
        )
        
        # Shortcut branch: average-pooling downsampling
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = channel_shuffle(x, self.groups)
        main_out = self.main_branch(x)
        shortcut_out = self.shortcut_branch(x)
        out = torch.cat([main_out, shortcut_out], dim=1)  # concatenate branches
        return self.relu(out)

# 4. The complete ShuffleNet v1 network
class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups
        
        # Base channel counts (scaled by scale_factor)
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]
        
        # Initial conv and pooling
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # Three stages; each leads with a stride-2 downsampling unit so the
        # stride-1 units see in_channels == out_channels
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        
        self.stage2 = self._make_stage(
            in_channels=base_channels[1],
            out_channels=base_channels[2],
            groups=groups,
            num_units=7,
            is_downsample=True
        )
        
        self.stage3 = self._make_stage(
            in_channels=base_channels[2],
            out_channels=base_channels[3],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        
        # Global pooling and classifier
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(base_channels[3], num_classes)
    
    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        stage = []
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
        return nn.Sequential(*stage)
    
    def forward(self, x):
        x = self.features(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# 5. Prepare the CIFAR-10 data
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

# 6. Initialize the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 7. Training function
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

# 8. Test function
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# 9. Train and test
for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)
    

IV. Key Takeaways

  1. Core operation: channel shuffle breaks the information isolation of grouped convolution in three steps (group → transpose → flatten); it is the soul of ShuffleNet;
  2. Unit design: the stride = 1 unit uses a residual connection (input + output), while the stride = 2 unit concatenates two branches (main + shortcut), downsampling while keeping features flowing;
  3. Lightweight advantage: ShuffleNet v1 (groups=8) needs only about 2.3M parameters and 140 MFLOPs, making it one of the most resource-frugal models at comparable accuracy;
  4. Flexible configuration: tuning groups (group count) and scale_factor (channel scaling) trades accuracy against efficiency to fit different mobile scenarios.
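
As a final sketch, here is how you might run single-image inference with the trained weights. The checkpoint and image file names are hypothetical, and the deterministic eval transform (CenterCrop instead of random augmentation) is a standard substitution rather than part of the original recipe:

python
from PIL import Image

eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),   # deterministic crop for evaluation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model.load_state_dict(torch.load('shufflenet_v1_best.pth', map_location=device))
model.eval()
img = eval_transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    pred = model(img).argmax(dim=1).item()
print(f'Predicted CIFAR-10 class index: {pred}')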

With this code in hand, you can build this "channel shuffle master" yourself and get a feel for how much lightweight CNNs can do on resource-constrained devices!
