「日拱一码」087 机器学习——SPARROW

目录

[SPARROW 介绍](#SPARROW 介绍)

核心思想:稀疏掩码训练

[与 Lottery Ticket Hypothesis (LTH) 的关系](#与 Lottery Ticket Hypothesis (LTH) 的关系)

代码示例

代码关键点解释:


在机器学习领域,"SPARROW" 并不是一个像 Scikit-learn、TensorFlow 或 PyTorch 那样广为人知的通用框架或算法名称。经过查询,最相关的 "SPARROW" 是指一篇重要的研究论文或其中提出的技术。

最著名的 "SPARROW" 来自 Google Research 在 2020年发表的一篇论文 《Rigging the Lottery: Making All Tickets Winners》

SPARROW 介绍

核心思想:稀疏掩码训练

传统的模型剪枝流程是:训练一个大模型 -> 剪枝(移除不重要的权重) -> 微调。这个过程通常非常耗时。

SPARROW(在论文中更常被称为 稀疏掩码训练Lottery Ticket Hypothesis 的扩展)提出了一种截然不同的方法:

在训练一开始就随机初始化一个网络,并立即应用一个预先定义好的稀疏性掩码(Sparsity Mask),使得网络从一开始就是稀疏的。然后,在整个训练过程中,这个掩码保持不变,只更新那些未被掩码掩盖的权重。

这种方法的核心优势在于:

  1. 效率高:模型从始至终都是稀疏的,训练和推理的计算开销、内存占用都显著降低。
  2. 性能好:论文表明,通过找到合适的初始化和固定掩码(即"中奖彩票"),这种稀疏网络可以达到甚至有时超过原始稠密模型的精度。
  3. 简单直接:无需复杂的剪枝调度或微调阶段。

与 Lottery Ticket Hypothesis (LTH) 的关系

LTH 假设指出:一个随机初始化的稠密网络中,包含一个子网络("中奖彩票"),当被单独训练时,其性能可以媲美原始网络。

SPARROW 可以看作是 LTH 的一个极其高效的实践版本 。它不是先训练再找彩票,而是假设一个随机初始化的掩码就是一张"潜在"的中奖彩票,并直接训练这个稀疏子网络,省去了寻找彩票的昂贵过程。

代码示例

下面实现一个最简单的 SPARROW 风格训练示例。流程如下:

  1. 创建一个全连接神经网络。
  2. 随机生成一个固定掩码,指定哪些权重参与训练(例如,50% 的稀疏度)。
  3. 在训练过程中,应用这个掩码:在前向传播时,权重会被掩码;在反向传播后,只有未被掩码的权重会被更新。
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# 1. 定义模型
class SimpleFC(nn.Module):
    def __init__(self):
        super(SimpleFC, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784) # 展平输入
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

# 2. 创建模型、优化器、损失函数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleFC().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# 3. 创建并应用固定稀疏掩码 (50% 稀疏度)
def create_sparsity_mask(model, sparsity=0.5):
    masks = {}
    for name, param in model.named_parameters():
        if 'weight' in name:
            # 为权重矩阵创建一个相同形状的随机掩码
            mask = torch.rand_like(param.data) > sparsity
            masks[name] = mask.to(device)
            # 初始应用掩码:将不参与训练的权重置零
            param.data *= mask
    return masks

sparsity_mask = create_sparsity_mask(model, sparsity=0.5)

# 4. 训练循环
def train(model, device, train_loader, optimizer, criterion, mask, epochs=5):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # SPARROW 关键步骤:在 optimizer.step() 之前,应用掩码到梯度上
            # 确保只有被掩码选中的权重才有非零梯度,从而被更新
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.grad *= mask[name]

            optimizer.step()

            # SPARROW 另一个关键步骤:在参数更新后,再次应用掩码到权重上
            # 确保被剪枝的权重始终保持为零
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.data *= mask[name]

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                      f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# 5. 加载数据并开始训练
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

print("开始训练 SPARROW 稀疏模型...")
train(model, device, train_loader, optimizer, criterion, sparsity_mask, epochs=3)
# 开始训练 SPARROW 稀疏模型...
# Epoch: 0 [0/60000 (0%)]	Loss: 2.300159
# Epoch: 0 [6400/60000 (11%)]	Loss: 0.135472
# Epoch: 0 [12800/60000 (21%)]	Loss: 0.146012
# Epoch: 0 [19200/60000 (32%)]	Loss: 0.177537
# Epoch: 0 [25600/60000 (43%)]	Loss: 0.034564
# Epoch: 0 [32000/60000 (53%)]	Loss: 0.165950
# Epoch: 0 [38400/60000 (64%)]	Loss: 0.214527
# Epoch: 0 [44800/60000 (75%)]	Loss: 0.239639
# Epoch: 0 [51200/60000 (85%)]	Loss: 0.173407
# Epoch: 0 [57600/60000 (96%)]	Loss: 0.087583
# Epoch: 1 [0/60000 (0%)]	Loss: 0.040576
# Epoch: 1 [6400/60000 (11%)]	Loss: 0.092811
# Epoch: 1 [12800/60000 (21%)]	Loss: 0.397150
# Epoch: 1 [19200/60000 (32%)]	Loss: 0.221431
# Epoch: 1 [25600/60000 (43%)]	Loss: 0.218968
# Epoch: 1 [32000/60000 (53%)]	Loss: 0.164273
# Epoch: 1 [38400/60000 (64%)]	Loss: 0.122340
# Epoch: 1 [44800/60000 (75%)]	Loss: 0.197523
# Epoch: 1 [51200/60000 (85%)]	Loss: 0.268147
# Epoch: 1 [57600/60000 (96%)]	Loss: 0.203193
# Epoch: 2 [0/60000 (0%)]	Loss: 0.115242
# Epoch: 2 [6400/60000 (11%)]	Loss: 0.276544
# Epoch: 2 [12800/60000 (21%)]	Loss: 0.515723
# Epoch: 2 [19200/60000 (32%)]	Loss: 0.202442
# Epoch: 2 [25600/60000 (43%)]	Loss: 0.092944
# Epoch: 2 [32000/60000 (53%)]	Loss: 0.090384
# Epoch: 2 [38400/60000 (64%)]	Loss: 0.145279
# Epoch: 2 [44800/60000 (75%)]	Loss: 0.155133
# Epoch: 2 [51200/60000 (85%)]	Loss: 0.091369
# Epoch: 2 [57600/60000 (96%)]	Loss: 0.216552

代码关键点解释:

  1. 创建掩码 (create_sparsity_mask): 为每个权重矩阵生成一个随机二进制掩码(1表示保留,0表示剪枝)
  2. 初始化应用掩码: 在训练开始前,将模型的权重与掩码相乘,使部分权重归零
  3. 梯度掩码 : 在反向传播计算出梯度后、优化器更新权重之前,将梯度与掩码相乘。这确保了只有被保留的权重才会被更新,被剪枝的权重的梯度始终为零
  4. 权重掩码: 在优化器更新完权重后,再次将权重与掩码相乘。这是一个保护步骤,防止由于优化器(如带有动量的SGD)的更新操作可能使本应为零的权重产生微小的数值变化
相关推荐
有代理ip几秒前
成功请求的密码:HTTP 2 开头响应码深度解析
java·大数据·python·算法·php
0思必得01 分钟前
[Web自动化] Selenium截图
前端·爬虫·python·selenium·自动化
jl48638212 分钟前
打造医疗设备的“可靠视窗”:医用控温仪专用屏从抗菌设计到EMC兼容的全链路解析
大数据·运维·人工智能·物联网·人机交互
kiro_10237 分钟前
BGRtoNV12与NV12toBGR互转函数
人工智能·opencv·计算机视觉
码农三叔7 分钟前
(9-1)电源管理与能源系统:电池选择与安全
人工智能·嵌入式硬件·安全·机器人·能源·人形机器人
司沐_Simuoss9 分钟前
Text to SQL系统的千层套路~
数据库·人工智能·sql·语言模型·系统架构
北京阿法龙科技有限公司10 分钟前
工业场景下AR+AI图像识别:精准选型赋能运维与质检
运维·人工智能·ar
才兄说31 分钟前
机器人租售怎么嵌?按流程节点
人工智能
logic_534 分钟前
关于VIT为啥可以用卷积代替第一层嵌入层
人工智能·神经网络·cnn