「日拱一码」087 机器学习——SPARROW

目录

[SPARROW 介绍](#SPARROW 介绍)

核心思想:稀疏掩码训练

[与 Lottery Ticket Hypothesis (LTH) 的关系](#与 Lottery Ticket Hypothesis (LTH) 的关系)

代码示例

代码关键点解释:


在机器学习领域,"SPARROW" 并不是一个像 Scikit-learn、TensorFlow 或 PyTorch 那样广为人知的通用框架或算法名称。经过查询,最相关的 "SPARROW" 是指一篇重要的研究论文或其中提出的技术。

最著名的 "SPARROW" 来自 Google Research 在 2020年发表的一篇论文 《Rigging the Lottery: Making All Tickets Winners》

SPARROW 介绍

核心思想:稀疏掩码训练

传统的模型剪枝流程是:训练一个大模型 -> 剪枝(移除不重要的权重) -> 微调。这个过程通常非常耗时。

SPARROW(在论文中更常被称为 稀疏掩码训练Lottery Ticket Hypothesis 的扩展)提出了一种截然不同的方法:

在训练一开始就随机初始化一个网络,并立即应用一个预先定义好的稀疏性掩码(Sparsity Mask),使得网络从一开始就是稀疏的。然后,在整个训练过程中,这个掩码保持不变,只更新那些未被掩码掩盖的权重。

这种方法的核心优势在于:

  1. 效率高:模型从始至终都是稀疏的,训练和推理的计算开销、内存占用都显著降低。
  2. 性能好:论文表明,通过找到合适的初始化和固定掩码(即"中奖彩票"),这种稀疏网络可以达到甚至有时超过原始稠密模型的精度。
  3. 简单直接:无需复杂的剪枝调度或微调阶段。

与 Lottery Ticket Hypothesis (LTH) 的关系

LTH 假设指出:一个随机初始化的稠密网络中,包含一个子网络("中奖彩票"),当被单独训练时,其性能可以媲美原始网络。

SPARROW 可以看作是 LTH 的一个极其高效的实践版本 。它不是先训练再找彩票,而是假设一个随机初始化的掩码就是一张"潜在"的中奖彩票,并直接训练这个稀疏子网络,省去了寻找彩票的昂贵过程。

代码示例

下面实现一个最简单的 SPARROW 风格训练示例。流程如下:

  1. 创建一个全连接神经网络。
  2. 随机生成一个固定掩码,指定哪些权重参与训练(例如,50% 的稀疏度)。
  3. 在训练过程中,应用这个掩码:在前向传播时,权重会被掩码;在反向传播后,只有未被掩码的权重会被更新。
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# 1. 定义模型
class SimpleFC(nn.Module):
    def __init__(self):
        super(SimpleFC, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784) # 展平输入
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

# 2. 创建模型、优化器、损失函数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleFC().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# 3. 创建并应用固定稀疏掩码 (50% 稀疏度)
def create_sparsity_mask(model, sparsity=0.5):
    masks = {}
    for name, param in model.named_parameters():
        if 'weight' in name:
            # 为权重矩阵创建一个相同形状的随机掩码
            mask = torch.rand_like(param.data) > sparsity
            masks[name] = mask.to(device)
            # 初始应用掩码:将不参与训练的权重置零
            param.data *= mask
    return masks

sparsity_mask = create_sparsity_mask(model, sparsity=0.5)

# 4. 训练循环
def train(model, device, train_loader, optimizer, criterion, mask, epochs=5):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # SPARROW 关键步骤:在 optimizer.step() 之前,应用掩码到梯度上
            # 确保只有被掩码选中的权重才有非零梯度,从而被更新
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.grad *= mask[name]

            optimizer.step()

            # SPARROW 另一个关键步骤:在参数更新后,再次应用掩码到权重上
            # 确保被剪枝的权重始终保持为零
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.data *= mask[name]

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                      f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# 5. 加载数据并开始训练
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

print("开始训练 SPARROW 稀疏模型...")
train(model, device, train_loader, optimizer, criterion, sparsity_mask, epochs=3)
# 开始训练 SPARROW 稀疏模型...
# Epoch: 0 [0/60000 (0%)]	Loss: 2.300159
# Epoch: 0 [6400/60000 (11%)]	Loss: 0.135472
# Epoch: 0 [12800/60000 (21%)]	Loss: 0.146012
# Epoch: 0 [19200/60000 (32%)]	Loss: 0.177537
# Epoch: 0 [25600/60000 (43%)]	Loss: 0.034564
# Epoch: 0 [32000/60000 (53%)]	Loss: 0.165950
# Epoch: 0 [38400/60000 (64%)]	Loss: 0.214527
# Epoch: 0 [44800/60000 (75%)]	Loss: 0.239639
# Epoch: 0 [51200/60000 (85%)]	Loss: 0.173407
# Epoch: 0 [57600/60000 (96%)]	Loss: 0.087583
# Epoch: 1 [0/60000 (0%)]	Loss: 0.040576
# Epoch: 1 [6400/60000 (11%)]	Loss: 0.092811
# Epoch: 1 [12800/60000 (21%)]	Loss: 0.397150
# Epoch: 1 [19200/60000 (32%)]	Loss: 0.221431
# Epoch: 1 [25600/60000 (43%)]	Loss: 0.218968
# Epoch: 1 [32000/60000 (53%)]	Loss: 0.164273
# Epoch: 1 [38400/60000 (64%)]	Loss: 0.122340
# Epoch: 1 [44800/60000 (75%)]	Loss: 0.197523
# Epoch: 1 [51200/60000 (85%)]	Loss: 0.268147
# Epoch: 1 [57600/60000 (96%)]	Loss: 0.203193
# Epoch: 2 [0/60000 (0%)]	Loss: 0.115242
# Epoch: 2 [6400/60000 (11%)]	Loss: 0.276544
# Epoch: 2 [12800/60000 (21%)]	Loss: 0.515723
# Epoch: 2 [19200/60000 (32%)]	Loss: 0.202442
# Epoch: 2 [25600/60000 (43%)]	Loss: 0.092944
# Epoch: 2 [32000/60000 (53%)]	Loss: 0.090384
# Epoch: 2 [38400/60000 (64%)]	Loss: 0.145279
# Epoch: 2 [44800/60000 (75%)]	Loss: 0.155133
# Epoch: 2 [51200/60000 (85%)]	Loss: 0.091369
# Epoch: 2 [57600/60000 (96%)]	Loss: 0.216552

代码关键点解释:

  1. 创建掩码 (create_sparsity_mask): 为每个权重矩阵生成一个随机二进制掩码(1表示保留,0表示剪枝)
  2. 初始化应用掩码: 在训练开始前,将模型的权重与掩码相乘,使部分权重归零
  3. 梯度掩码 : 在反向传播计算出梯度后、优化器更新权重之前,将梯度与掩码相乘。这确保了只有被保留的权重才会被更新,被剪枝的权重的梯度始终为零
  4. 权重掩码: 在优化器更新完权重后,再次将权重与掩码相乘。这是一个保护步骤,防止由于优化器(如带有动量的SGD)的更新操作可能使本应为零的权重产生微小的数值变化
相关推荐
Uzuki2 小时前
目标检测 | 基于Weiler–Atherton算法的IoU求解
目标检测·机器学习·自动驾驶·图形学
minhuan2 小时前
构建AI智能体:三十一、AI医疗场景实践:医学知识精准问答+临床智能辅助决策CDSS
人工智能·医学知识问答·临床辅助决策·cdss·医学模型
大千AI助手3 小时前
线性预热机制(Linear Warmup):深度学习训练稳定性的关键策略
人工智能·深度学习·大模型·模型训练·学习率·warmup·线性预热机制
七牛云行业应用3 小时前
企业级AI大模型选型指南:从评估部署到安全实践
大数据·人工智能·安全
진영_3 小时前
深度学习打卡第N6周:中文文本分类-Pytorch实现
人工智能·深度学习
龙亘川3 小时前
智慧城市SaaS平台之智慧城管十大核心功能(六):业务指导系统
人工智能·智慧城市
龙亘川3 小时前
智慧城市SaaS平台之智慧城管十大核心功能(七):后台支撑系统
服务器·人工智能·系统架构·智慧城市·运维开发·智慧城市saas平台
cms小程序插件【官方】3 小时前
pbootcms版AI自动发文插件升级到2.0版,支持AI配图、自动提取关键词
人工智能
AI 嗯啦3 小时前
计算机视觉----图像投影(透视)变换(小案例)
人工智能·opencv·计算机视觉