1. ShuffleNet 网络简介
ShuffleNet 是旷视科技提出的一种高效卷积神经网络模型,主要面向移动端等计算资源受限的设备,旨在以有限的计算量实现较高的模型精度。其核心思想包括:
- Pointwise Group Convolution(逐点分组卷积):通过对通道进行分组,每组卷积核仅处理输入特征图的一部分通道,从而降低计算量。
- Channel Shuffle(通道重排):堆叠多个分组卷积时,每组输出只依赖于对应的输入分组,组间缺乏信息交流;通道重排将各组的通道打乱后重新组合,使信息能够在组间流动。
2. 网络架构
2.1 Pointwise Group Convolution
- 分组卷积(Group Convolution):将输入通道和卷积核划分为若干组,每组卷积核只处理对应的那部分输入通道。与标准卷积相比,分组数为 g 时参数量和计算量均降为原来的 1/g。
- 深度卷积(Depthwise Convolution,又称逐通道卷积):每个卷积核只处理一个输入通道,相当于分组数等于通道数的分组卷积,计算量大幅降低。它与 1×1 逐点卷积组合即构成深度可分离卷积(Depthwise Separable Convolution)。
- 逐点分组卷积(Pointwise Group Convolution):对 1×1 卷积进行分组。在使用深度卷积的瓶颈结构中,1×1 逐点卷积占据了大部分计算量,对其分组可以进一步减少计算量。
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor


class GroupConv(nn.Cell):
    """Grouped convolution implemented as one ``nn.Conv2d`` per channel group.

    The input is split along axis 1 into ``groups`` equal chunks, chunk ``i``
    is convolved by ``self.convs[i]``, and the per-group results are
    concatenated back along axis 1.

    Args:
        in_channels (int): Input channels; must be divisible by ``groups``.
        out_channels (int): Output channels; must be divisible by ``groups``.
        kernel_size: Kernel size of every per-group ``Conv2d``.
        stride: Stride of every per-group ``Conv2d``.
        pad_mode (str): ``pad_mode`` of every per-group ``Conv2d``.
        pad (int): ``padding`` of every per-group ``Conv2d``.
        groups (int): Number of channel groups; 1 gives an ordinary convolution.
        has_bias (bool): Whether each per-group ``Conv2d`` has a bias term.

    Raises:
        ValueError: If ``groups`` is smaller than 1 or does not divide both
            ``in_channels`` and ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        if groups < 1:
            raise ValueError(f"groups must be >= 1, got {groups}")
        # Without exact divisibility the channel split yields a ragged trailing
        # chunk that is silently ignored, and the output has the wrong width.
        if in_channels % groups or out_channels % groups:
            raise ValueError(
                f"in_channels ({in_channels}) and out_channels ({out_channels}) "
                f"must both be divisible by groups ({groups})")
        self.groups = groups
        # Number of input channels handled by each per-group convolution.
        self.group_in_channels = in_channels // groups
        group_out_channels = out_channels // groups
        self.convs = nn.CellList()
        for _ in range(groups):
            self.convs.append(nn.Conv2d(self.group_in_channels, group_out_channels,
                                        kernel_size=kernel_size, stride=stride,
                                        has_bias=has_bias, padding=pad,
                                        pad_mode=pad_mode, group=1,
                                        weight_init='xavier_uniform'))

    def construct(self, x):
        # Assumes x is NCHW with in_channels channels on axis 1 (TODO confirm
        # if a non-default data format is ever used).
        features = ops.split(x, split_size_or_sections=self.group_in_channels, axis=1)
        outputs = ()
        for i in range(self.groups):
            # Each chunk is cast to float32 before its convolution.
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        return ops.cat(outputs, axis=1)
2.2 Channel Shuffle
- Channel Shuffle(通道重排):将各组输出通道均匀打散后重新组合,使下一层的每个分组都能接收到来自不同组的通道信息,从而提高网络的特征提取能力。
class ShuffleV1Block(nn.Cell):
    """ShuffleNetV1 unit.

    Main branch: 1x1 (group) conv -> BN -> ReLU -> channel shuffle (only when
    ``group > 1``) -> ``ksize`` x ``ksize`` depthwise conv -> BN -> 1x1 group
    conv -> BN.

    * ``stride == 1``: the main branch is added to the input (residual), so
      ``inp`` must equal ``oup``.
    * ``stride == 2``: the input is downsampled by a 3x3, stride-2 average pool
      and concatenated (axis 1) with the main branch, which therefore only
      produces ``oup - inp`` channels.

    In both cases the combined result goes through ReLU.

    Args:
        inp (int): Input channels.
        oup (int): Output channels of the block.
        group (int): Number of groups of the 1x1 group convolutions.
        first_group (bool): If True, the first 1x1 conv is not grouped.
        mid_channels (int): Bottleneck (depthwise) channels.
        ksize (int): Depthwise kernel size; padding is ``ksize // 2``.
        stride (int): 1 or 2.

    Raises:
        ValueError: If ``stride`` is not 1 or 2, if ``stride == 1`` and
            ``inp != oup``, or if ``stride == 2`` and ``oup <= inp``.
    """

    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        # Any other stride would make construct() skip both combine paths.
        if stride not in (1, 2):
            raise ValueError(f"stride must be 1 or 2, got {stride}")
        if stride == 1 and inp != oup:
            raise ValueError(
                f"stride-1 blocks need inp == oup for the residual add, got inp={inp}, oup={oup}")
        if stride == 2 and oup <= inp:
            raise ValueError(
                f"stride-2 blocks need oup > inp, got inp={inp}, oup={oup}")
        self.stride = stride
        self.group = group
        pad = ksize // 2
        # With stride 2 the pooled shortcut already supplies `inp` channels.
        outputs = oup - inp if stride == 2 else oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            # group == channels makes this a depthwise convolution.
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        right = self.branch_main_1(old_x)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            # Residual connection.
            out = old_x + right
        else:
            # stride == 2 (enforced in __init__): pooled shortcut ++ main branch.
            out = ops.cat((self.branch_proj(old_x), right), 1)
        return self.relu(out)

    def channel_shuffle(self, x):
        # Reshape channels to (C // group, group), swap those two axes and
        # flatten back: each contiguous block of C // group output channels then
        # gathers every group-th input channel, mixing channels that came from
        # different groups of the preceding group convolution.
        # NOTE(review): assumes C is divisible by self.group.
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x
2.3 ShuffleNetV1 架构
- ShuffleNetV1:依次由输入卷积层(3×3,步长 2)和最大池化层、三个阶段的 ShuffleNet 模块(分别重复 4、8、4 次,每个阶段的第一个模块步长为 2)、全局平均池化层以及全连接层组成。
# Channel widths per configuration, keyed by group count, then width multiplier.
# Index 1 is the stem (first conv) width, indices 2-4 are the widths of the
# three ShuffleNet stages; index 0 is unused.
_SHUFFLENET_V1_STAGE_OUT_CHANNELS = {
    3: {
        '0.5x': [-1, 12, 120, 240, 480],
        '1.0x': [-1, 24, 240, 480, 960],
        '1.5x': [-1, 24, 360, 720, 1440],
        '2.0x': [-1, 48, 480, 960, 1920],
    },
    8: {
        '0.5x': [-1, 16, 192, 384, 768],
        '1.0x': [-1, 24, 384, 768, 1536],
        '1.5x': [-1, 24, 576, 1152, 2304],
        '2.0x': [-1, 48, 768, 1536, 3072],
    },
}


class ShuffleNetV1(nn.Cell):
    """ShuffleNetV1 image classifier.

    Stem (3x3 stride-2 conv + BN + ReLU, then 3x3 stride-2 max pool), three
    stages of ``ShuffleV1Block`` repeated 4, 8 and 4 times (the first block of
    each stage has stride 2), global average pooling and a dense classifier.

    Args:
        n_class (int): Number of output classes.
        model_size (str): Width multiplier: '0.5x', '1.0x', '1.5x' or '2.0x'.
        group (int): Groups of the 1x1 group convolutions: 3 or 8.

    Raises:
        NotImplementedError: If the (group, model_size) pair is not supported.
    """

    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        try:
            channels = _SHUFFLENET_V1_STAGE_OUT_CHANNELS[group][model_size]
        except KeyError:
            raise NotImplementedError(
                f"unsupported configuration: group={group}, model_size={model_size!r}") from None
        # Copy so that instances never share the module-level table.
        self.stage_out_channels = list(channels)

        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        features = []
        for idxstage, numrepeat in enumerate(self.stage_repeats):
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                # Only the very first block of the network uses an ungrouped
                # first 1x1 conv (its input is the narrow stem output).
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4,
                                               ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        # Mean over H and W (keep_dims) gives an (N, C, 1, 1) map for any input
        # resolution; a fixed AvgPool2d(7) only does so for a 7x7 final map
        # (224x224 input) and is numerically identical in that case.
        self.globalpool = ops.ReduceMean(keep_dims=True)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x, (2, 3))
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x
3. 模型训练与评估
3.1 数据集准备
- CIFAR-10 数据集:包含 60000 张 32×32 彩色图像,共 10 个类别,其中训练集 50000 张,测试集 10000 张。
from download import download

# CIFAR-10 in binary format, hosted on the MindSpore website.
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"

# `download(url, path, kind=...)`: `path` is a required argument (there is no
# `cache_dir` keyword); kind="tar.gz" extracts the archive into ./datasets.
# NOTE(review): the archive is expected to unpack to
# ./datasets/cifar-10-batches-bin -- verify after the first run.
download(url, "./datasets", kind="tar.gz", replace=True)
3.2 训练配置
- 训练超参数:
  - 学习率:0.1
  - 批量大小:128
  - 优化器:SGD(动量 0.9)
import mindspore as ms
from mindspore import context
from mindspore.train import Model, Accuracy
from mindspore.nn import SGD, SoftmaxCrossEntropyWithLogits
from mindspore.dataset import Cifar10Dataset
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import nn

# Training hyper-parameters.
BATCH_SIZE = 128
LEARNING_RATE = 0.1
MOMENTUM = 0.9
EPOCHS = 10
NUM_CLASSES = 10  # CIFAR-10
# NOTE(review): presumed extraction directory of cifar-10-binary.tar.gz when
# it is downloaded into ./datasets -- verify.
DATA_DIR = './datasets/cifar-10-batches-bin'


def create_dataset(data_dir, usage, batch_size):
    """Build a batched CIFAR-10 pipeline of float32 CHW images and int32 labels.

    Images are resized to 224x224, the resolution ShuffleNetV1 is configured
    for: its stem and three stages downsample by 32 in total, leaving a 7x7
    map before global pooling. Training data is additionally augmented with
    padded random crops and horizontal flips, shuffled, and the last partial
    batch is dropped.

    Args:
        data_dir (str): Directory holding the CIFAR-10 binary files.
        usage (str): 'train' or 'test'.
        batch_size (int): Samples per batch.

    Returns:
        The batched ``mindspore.dataset`` pipeline.
    """
    is_train = usage == 'train'
    image_ops = []
    if is_train:
        image_ops += [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
        ]
    image_ops += [
        vision.Resize((224, 224)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        # Raw images are HWC; the network expects channel-first input.
        vision.HWC2CHW(),
    ]
    dataset = Cifar10Dataset(dataset_dir=data_dir, usage=usage, shuffle=is_train)
    dataset = dataset.map(operations=image_ops, input_columns='image')
    # Sparse softmax cross-entropy requires int32 labels.
    dataset = dataset.map(operations=transforms.TypeCast(ms.int32), input_columns='label')
    return dataset.batch(batch_size, drop_remainder=is_train)


context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# The network defaults to a 1000-class head; CIFAR-10 has 10 classes.
net = ShuffleNetV1(n_class=NUM_CLASSES)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = SGD(params=net.trainable_params(), learning_rate=LEARNING_RATE, momentum=MOMENTUM)
metrics = {"accuracy": Accuracy()}
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=metrics)

train_dataset = create_dataset(DATA_DIR, 'train', BATCH_SIZE)
test_dataset = create_dataset(DATA_DIR, 'test', BATCH_SIZE)

model.train(EPOCHS, train_dataset)
eval_result = model.eval(test_dataset)
print("Evaluation result:", eval_result)
4. 总结
- ShuffleNet 的优势:通过分组卷积和通道重排显著减少计算量,提高效率。
- 模型应用:适用于资源受限的设备,如移动端和嵌入式系统。
在学习 ShuffleNet 时,可以通过代码实践来深入理解其优化原理和应用场景,并通过比较不同网络模型的性能来评估 ShuffleNet 的实际效果。