深入浅出对抗学习:概念、攻击、防御与代码实践


深入浅出对抗学习:概念、攻击、防御与代码实践

近年来,深度学习在图像识别、自然语言处理等领域取得了巨大成功。然而,研究表明,这些看似强大的模型却异常脆弱,容易受到**对抗样本(Adversarial Examples)**的攻击。对抗学习(Adversarial Learning)应运而生,它研究如何生成对抗样本以揭示模型弱点,以及如何构建更鲁棒的模型来抵御这些攻击。

1. 什么是对抗样本?

对抗样本是指在原始输入数据上添加人眼难以察觉的微小扰动,从而使得深度学习模型以高置信度给出错误分类的样本。

想象一下,一个先进的图像识别模型能准确识别一只熊猫。攻击者通过精心设计的微小扰动(比如修改几个像素值),使得模型将这只"略有不同"的熊猫图片错误地识别为长臂猿,而且模型对此判断非常自信。

(图片来源: Goodfellow et al., "Explaining and Harnessing Adversarial Examples")

这种现象揭示了深度学习模型在理解数据本质方面的局限性,并对安全攸关的应用(如自动驾驶、医疗诊断)构成了严重威胁。

2. 为什么会对抗样本有效?

目前主流的解释之一是神经网络的线性特性。尽管神经网络具有非线性激活函数,但在高维空间中,即使每个输入特征的微小改变(线性叠加)也可能导致输出的显著变化。

对于一个线性模型 f ( x ) = w T x f(x) = w^T x f(x)=wTx,如果我们对输入 x x x 添加一个扰动 η \eta η,新的输入为 x ~ = x + η \tilde{x} = x + \eta x~=x+η。模型的输出变为 f ( x ~ ) = w T ( x + η ) = w T x + w T η f(\tilde{x}) = w^T (x + \eta) = w^T x + w^T \eta f(x~)=wT(x+η)=wTx+wTη。为了使输出变化最大,我们可以让扰动 η \eta η 与权重 w w w 的方向一致或相反。对于深度模型,虽然更复杂,但局部线性假设仍然在一定程度上成立。

3. 常见的对抗攻击方法

对抗攻击的目标是找到一个最小的扰动 η \eta η,使得 x + η x + \eta x+η 被错误分类,同时 ∥ η ∥ \|\eta\| ∥η∥ 尽可能小(例如使用 L 0 , L 2 , L ∞ L_0, L_2, L_\infty L0,L2,L∞ 范数约束)。

3.1 Fast Gradient Sign Method (FGSM) - 快速梯度符号法

FGSM 是由 Goodfellow 等人提出的一种简单而高效的白盒攻击方法(攻击者知道模型结构和参数)。它利用损失函数关于输入数据的梯度方向来生成扰动。

扰动的计算公式为:
η = ϵ ⋅ sign ( ∇ x J ( θ , x , y ) ) \eta = \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)) η=ϵ⋅sign(∇xJ(θ,x,y))

其中:

  • ϵ \epsilon ϵ 是一个很小的正数,控制扰动的大小。
  • x x x 是原始输入。
  • y y y 是原始输入的真实标签。
  • J ( θ , x , y ) J(\theta, x, y) J(θ,x,y) 是模型的损失函数(如交叉熵损失)。
  • ∇ x J ( θ , x , y ) \nabla_x J(\theta, x, y) ∇xJ(θ,x,y) 是损失函数关于输入 x x x 的梯度。
  • sign ( ⋅ ) \text{sign}(\cdot) sign(⋅) 是符号函数。

对抗样本则为: x a d v = x + η x_{adv} = x + \eta xadv=x+η。通常还需要将 x a d v x_{adv} xadv 裁剪到原始数据有效范围内(例如图像像素值在 [0, 1] 或 [0, 255] 之间)。

代码示例 (PyTorch - MNIST数据集):

首先,我们需要一个预训练好的模型。这里我们快速训练一个简单的CNN模型用于MNIST分类。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.1307,), (0.3081,)) # 暂时不进行归一化,方便可视化和扰动计算
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 定义一个简单的CNN模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(4*4*64, 1024) # MNIST images are 28x28
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 4*4*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 训练模型 (如果已有预训练模型,可以跳过)
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss() # Negative Log Likelihood Loss, as output is log_softmax

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# 简单训练几轮
# num_epochs = 3
# for epoch in range(1, num_epochs + 1):
#     train(epoch)
# torch.save(model.state_dict(), "mnist_cnn.pt")

# 加载预训练模型 (假设已训练并保存)
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
model.eval() # 设置为评估模式

# FGSM 攻击函数
def fgsm_attack(image, epsilon, data_grad):
    # 收集梯度的元素符号
    sign_data_grad = data_grad.sign()
    # 创建扰动图像
    perturbed_image = image + epsilon * sign_data_grad
    # 添加剪切以保持在 [0,1] 范围内
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

# 测试 FGSM 攻击
def test_fgsm(model, device, test_loader, epsilon):
    correct = 0
    adv_examples = [] # 保存一些对抗样本用于可视化

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True # 重要: 设置输入需要梯度

        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # 得到初始预测

        # 如果初始预测就错了,则不进行攻击 (或者可以攻击到特定目标)
        # 这里我们只攻击那些最初分类正确的样本
        correct_mask = (init_pred.squeeze() == target)
        
        if not correct_mask.any(): # 如果这个batch里没有一个预测正确,则跳过
            continue

        # 只选择预测正确的样本进行攻击
        # data_correct = data[correct_mask]
        # target_correct = target[correct_mask]
        # init_pred_correct = init_pred[correct_mask]

        # 计算损失
        loss = F.nll_loss(output, target) # 使用NLLLoss因为模型输出是log_softmax

        # 清零所有现有梯度
        model.zero_grad()

        # 计算反向传播的梯度
        loss.backward()

        # 收集数据梯度
        data_grad = data.grad.data

        # 调用FGSM攻击
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # 重新分类受扰动的图像
        output_adv = model(perturbed_data)
        final_pred = output_adv.max(1, keepdim=True)[1] # 得到对抗样本的预测

        # 检查是否仍然正确
        correct += (final_pred.squeeze() == target).sum().item()
        
        # 保存一些成功的对抗样本 (初始正确,攻击后错误)
        if len(adv_examples) < 5:
            # 找到那些初始预测正确,但攻击后预测错误的样本
            successful_attacks_mask = (init_pred.squeeze() == target) & (final_pred.squeeze() != target)
            if successful_attacks_mask.any():
                # adv_ex = perturbed_data[successful_attacks_mask][0].squeeze().detach().cpu().numpy()
                # orig_ex = data[successful_attacks_mask][0].squeeze().detach().cpu().numpy()
                # adv_examples.append((init_pred[successful_attacks_mask][0].item(), final_pred[successful_attacks_mask][0].item(), orig_ex, adv_ex))
                
                # 取第一个成功攻击的样本
                idx_attacked = torch.where(successful_attacks_mask)[0][0]
                adv_ex = perturbed_data[idx_attacked].squeeze().detach().cpu().numpy()
                orig_ex = data[idx_attacked].squeeze().detach().cpu().numpy()
                orig_lbl = target[idx_attacked].item()
                adv_lbl = final_pred[idx_attacked].item()
                adv_examples.append((orig_lbl, adv_lbl, orig_ex, adv_ex))


    final_acc = correct / float(len(test_loader.dataset))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct}/{len(test_loader.dataset)} = {final_acc}")
    return final_acc, adv_examples

# 运行测试
epsilons = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
accuracies = []
example_lists = []

for eps in epsilons:
    acc, ex = test_fgsm(model, device, test_loader, eps)
    accuracies.append(acc)
    example_lists.append(ex)

代码解释:

  1. Net: 定义了一个简单的卷积神经网络。
  2. 训练模型 : (注释掉的部分) 标准的PyTorch训练流程。你需要先运行这部分来得到 mnist_cnn.pt 文件,或者使用你自己的预训练模型。
  3. 加载模型 : 加载预训练的权重,并设置为 model.eval() 模式,这很重要,因为它会关闭 Dropout 和 BatchNorm 的更新等。
  4. fgsm_attack 函数 :
    • data_grad.sign(): 获取梯度的符号。
    • perturbed_image = image + epsilon * sign_data_grad: 根据FGSM公式生成扰动图像。
    • torch.clamp(perturbed_image, 0, 1): 将像素值裁剪到有效的[0,1]范围(因为我们使用了ToTensor(),它将图像像素值缩放到[0,1])。
  5. test_fgsm 函数 :
    • data.requires_grad = True: 核心步骤!告诉PyTorch我们需要计算输入data的梯度。
    • loss = F.nll_loss(output, target): 计算损失。注意这里的损失函数选择要和模型输出层以及FGSM的目标一致。对于非目标攻击,我们希望最大化当前标签的损失。
    • loss.backward(): 反向传播计算梯度。
    • data_grad = data.grad.data: 获取输入数据的梯度。
    • 调用 fgsm_attack 生成对抗样本。
    • 记录在不同 epsilon 下的准确率,并保存一些对抗样本用于可视化。
  6. 运行测试 : 使用不同的 epsilon 值测试攻击效果。epsilon=0 相当于没有攻击的原始准确率。

可视化结果:

python 复制代码
# 可视化准确率随epsilon的变化
plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon (FGSM)")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()

# 可视化一些对抗样本
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    if not example_lists[i]: # 如果这个epsilon下没有收集到样本
        continue
    for j in range(min(1, len(example_lists[i]))): # 每个epsilon只显示一个例子
        cnt += 1
        plt.subplot(len(epsilons)//2 +1 , 2 if len(epsilons)>1 else 1, cnt) # 动态调整子图布局
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0: # 仅在第一列显示 Epsilon 值
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
        orig_lbl, adv_lbl, orig_img, adv_img = example_lists[i][j]
        
        # 显示原始图像和对抗图像
        if cnt % 2 != 0: # 奇数位置显示原始图像
            plt.imshow(orig_img, cmap="gray")
            plt.title(f"Original: {orig_lbl}")
        else: # 偶数位置显示对抗图像
            plt.imshow(adv_img, cmap="gray")
            plt.title(f"Adversarial: {adv_lbl}")
        
        # 如果我们想并排显示原始和对抗
        if cnt % 2 == 0 : # 每生成一个对抗样本就画图
            # 原图
            plt.subplot(len(epsilons),2,cnt-1)
            plt.xticks([], []); plt.yticks([], [])
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
            plt.imshow(orig_img, cmap="gray")
            plt.title(f"Original: {orig_lbl}")
            # 对抗图
            plt.subplot(len(epsilons),2,cnt)
            plt.xticks([], []); plt.yticks([], [])
            plt.imshow(adv_img, cmap="gray")
            plt.title(f"Adversarial: {adv_lbl}")

plt.tight_layout()
plt.show()

# 更好的可视化:每个epsilon一行,左右分别是原始和对抗
num_examples_to_show = 5 # 最多显示5个epsilon的例子
fig, axes = plt.subplots(min(num_examples_to_show, len(epsilons)), 2, figsize=(6, 2 * min(num_examples_to_show, len(epsilons))))
if min(num_examples_to_show, len(epsilons)) == 1: # 如果只有一个epsilon,axes不是数组
    axes = np.array([axes])


vis_count = 0
for i in range(len(epsilons)):
    if vis_count >= num_examples_to_show:
        break
    if not example_lists[i]:
        continue
    
    # 取第一个样本
    orig_lbl, adv_lbl, orig_img, adv_img = example_lists[i][0]

    axes[vis_count, 0].imshow(orig_img, cmap="gray")
    axes[vis_count, 0].set_title(f"Eps: {epsilons[i]}\nOriginal: {orig_lbl}")
    axes[vis_count, 0].axis('off')

    axes[vis_count, 1].imshow(adv_img, cmap="gray")
    axes[vis_count, 1].set_title(f"Adversarial: {adv_lbl}")
    axes[vis_count, 1].axis('off')
    
    vis_count +=1

# 如果没有收集到足够的样本,隐藏多余的子图
for i in range(vis_count, min(num_examples_to_show, len(epsilons))):
    fig.delaxes(axes[i,0])
    fig.delaxes(axes[i,1])

plt.tight_layout()
plt.show()

这段可视化代码会展示:

  1. 模型准确率随着 epsilon 增大的下降曲线。
  2. 一些原始图像和对应的、被成功攻击的对抗图像。你会看到,随着 epsilon 增大,图像上的噪声(扰动)会更明显,但模型更容易被欺骗。
3.2 Projected Gradient Descent (PGD) - 投影梯度下降

PGD 攻击可以看作是 FGSM 的迭代版本,通常也更强大。它在每一步都进行一次小的梯度上升,然后将扰动投影回一个允许的 ϵ \epsilon ϵ-球内,以确保扰动不会过大。

迭代更新规则:
x 0 a d v = x x_0^{adv} = x x0adv=x (或加入小的随机扰动)
x t + 1 a d v = Proj x , ϵ ( x t a d v + α ⋅ sign ( ∇ x J ( θ , x t a d v , y ) ) ) x_{t+1}^{adv} = \text{Proj}_{x, \epsilon} (x_t^{adv} + \alpha \cdot \text{sign}(\nabla_x J(\theta, x_t^{adv}, y))) xt+1adv=Projx,ϵ(xtadv+α⋅sign(∇xJ(θ,xtadv,y)))

其中:

  • α \alpha α 是每一步的步长,通常设为 ϵ / N s t e p s \epsilon / N_{steps} ϵ/Nsteps 或更小。
  • Proj x , ϵ ( ⋅ ) \text{Proj}{x, \epsilon}(\cdot) Projx,ϵ(⋅) 是一个投影操作,确保 x t + 1 a d v x{t+1}^{adv} xt+1adv 仍在原始输入 x x x 的 ϵ \epsilon ϵ-邻域内(例如,对于 L ∞ L_\infty L∞ 范数,将 x t + 1 a d v x_{t+1}^{adv} xt+1adv裁剪到 [ x − ϵ , x + ϵ ] [x-\epsilon, x+\epsilon] [x−ϵ,x+ϵ],并同时裁剪到有效数据范围如 [ 0 , 1 ] [0,1] [0,1])。

代码示例 (PyTorch - PGD):

python 复制代码
def pgd_attack(model, image, target, epsilon, alpha, num_iter):
    """ PGD攻击实现 """
    perturbed_image = image.clone().detach() # 从原始图像开始

    for _ in range(num_iter):
        perturbed_image.requires_grad = True
        output = model(perturbed_image)
        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()

        # FGSM步进
        adv_image_update = perturbed_image + alpha * perturbed_image.grad.sign()
        
        # 投影操作 (L_infinity norm)
        eta = torch.clamp(adv_image_update - image, min=-epsilon, max=epsilon) # 限制扰动大小
        perturbed_image = torch.clamp(image + eta, min=0, max=1).detach() # 限制图像范围并分离计算图
        
    return perturbed_image

# 测试 PGD 攻击
def test_pgd(model, device, test_loader, epsilon, alpha, num_iter):
    correct = 0
    adv_examples_pgd = []

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        
        # PGD不需要data.requires_grad在循环外,因为它在pgd_attack内部处理
        
        output_orig = model(data) # 获取原始预测,用于比较
        init_pred = output_orig.max(1, keepdim=True)[1]

        perturbed_data = pgd_attack(model, data, target, epsilon, alpha, num_iter)
        output_adv = model(perturbed_data)
        final_pred = output_adv.max(1, keepdim=True)[1]

        correct += (final_pred.squeeze() == target).sum().item()
        
        if len(adv_examples_pgd) < 5:
            successful_attacks_mask = (init_pred.squeeze() == target) & (final_pred.squeeze() != target)
            if successful_attacks_mask.any():
                idx_attacked = torch.where(successful_attacks_mask)[0][0]
                adv_ex = perturbed_data[idx_attacked].squeeze().detach().cpu().numpy()
                orig_ex = data[idx_attacked].squeeze().detach().cpu().numpy()
                orig_lbl = target[idx_attacked].item()
                adv_lbl = final_pred[idx_attacked].item()
                adv_examples_pgd.append((orig_lbl, adv_lbl, orig_ex, adv_ex))

    final_acc = correct / float(len(test_loader.dataset))
    print(f"PGD: Epsilon: {epsilon}, Alpha: {alpha}, Iter: {num_iter}\tTest Accuracy = {final_acc}")
    return final_acc, adv_examples_pgd

# PGD 测试参数
epsilon_pgd = 0.1 # 扰动幅度
alpha_pgd = 0.01  # 迭代步长
num_iter_pgd = 20 # 迭代次数 (可以设为7-40)

acc_pgd, ex_pgd = test_pgd(model, device, test_loader, epsilon_pgd, alpha_pgd, num_iter_pgd)

# 可视化PGD结果 (与FGSM类似)
if ex_pgd:
    print(f"\nPGD (eps={epsilon_pgd}, alpha={alpha_pgd}, iter={num_iter_pgd}) examples:")
    fig, axes = plt.subplots(1, 2, figsize=(6, 3)) # 只显示一个PGD例子
    orig_lbl, adv_lbl, orig_img, adv_img = ex_pgd[0]
    axes[0].imshow(orig_img, cmap="gray")
    axes[0].set_title(f"Original: {orig_lbl}")
    axes[0].axis('off')
    axes[1].imshow(adv_img, cmap="gray")
    axes[1].set_title(f"PGD Adversarial: {adv_lbl}")
    axes[1].axis('off')
    plt.tight_layout()
    plt.show()

代码解释 (pgd_attack):

  1. perturbed_image = image.clone().detach(): 创建一个与原始图像相同的副本,并将其从计算图中分离出来,作为迭代的起点。
  2. 循环 : 迭代 num_iter 次。
    • perturbed_image.requires_grad = True: 在每次迭代开始时,设置当前扰动图像需要梯度。
    • 计算损失和梯度,与FGSM类似。
    • adv_image_update = perturbed_image + alpha * perturbed_image.grad.sign(): 进行一步梯度上升。
    • 投影 :
      • eta = torch.clamp(adv_image_update - image, min=-epsilon, max=epsilon): 计算扰动 eta,并将其限制在 [-epsilon, epsilon] 范围内,这是 L ∞ L_\infty L∞ 范数约束。
      • perturbed_image = torch.clamp(image + eta, min=0, max=1).detach(): 将扰动加回原始图像,然后将结果裁剪到有效像素范围 [0,1],并 .detach() 以便下一次迭代。

PGD 通常比 FGSM 更难防御,因为它会尝试在 ϵ \epsilon ϵ-球内找到一个更优的攻击点。

4. 对抗防御方法

4.1 对抗训练 (Adversarial Training)

这是目前最有效且研究最广泛的防御方法之一。核心思想是将对抗样本加入到训练数据中,让模型在训练过程中学习识别这些对抗样本。

训练目标函数变为:
min ⁡ θ E ( x , y ) ∼ D [ max ⁡ ∥ δ ∥ ≤ ϵ L ( θ , x + δ , y ) ] \min_{\theta} \mathbb{E}{(x,y) \sim D} \left[ \max{\|\delta\| \le \epsilon} L(\theta, x+\delta, y) \right] minθE(x,y)∼D[max∥δ∥≤ϵL(θ,x+δ,y)]

这意味着我们希望找到模型参数 θ \theta θ,使得在最坏情况下的扰动 δ \delta δ(在 ϵ \epsilon ϵ 范围内最大化损失 L L L)下,期望损失最小。

在实践中,通常在每个训练批次中:

  1. 对当前批次的干净样本生成对抗样本(例如使用FGSM或PGD)。
  2. 将这些对抗样本(有时也包括原始干净样本)用于更新模型参数。

代码示例 (PyTorch - 对抗训练,使用FGSM生成对抗样本):

python 复制代码
model_adv_trained = Net().to(device) # 新建一个模型用于对抗训练
optimizer_adv = optim.Adam(model_adv_trained.parameters(), lr=0.001)
criterion_adv = nn.NLLLoss()

# 对抗训练参数
adv_train_epsilon = 0.15 # 用于生成对抗样本的epsilon

def adv_train(epoch):
    model_adv_trained.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # --- 生成对抗样本 ---
        data.requires_grad = True # 允许计算data的梯度
        output_for_adv = model_adv_trained(data)
        loss_for_adv = criterion_adv(output_for_adv, target)
        
        model_adv_trained.zero_grad() # 清除旧梯度
        loss_for_adv.backward() # 计算梯度
        data_grad = data.grad.data
        
        # 使用FGSM生成对抗样本
        perturbed_data = fgsm_attack(data, adv_train_epsilon, data_grad)
        # --- 对抗样本生成完毕 ---

        optimizer_adv.zero_grad() # 清除用于模型更新的梯度
        
        # 在对抗样本上进行训练
        # 可以选择只用对抗样本,或者混合原始样本和对抗样本
        # 这里我们只用对抗样本(也可以将perturbed_data和data拼接起来)
        output_adv_train = model_adv_trained(perturbed_data) 
        loss_adv_train = criterion_adv(output_adv_train, target)
        
        loss_adv_train.backward() # 计算模型参数的梯度
        optimizer_adv.step() # 更新模型参数

        if batch_idx % 100 == 0:
            print(f'Adv Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss_adv_train.item():.6f}')

# 进行对抗训练 (示例1-2轮,实际需要更多)
num_adv_epochs = 2
for epoch in range(1, num_adv_epochs + 1):
    adv_train(epoch)
torch.save(model_adv_trained.state_dict(), "mnist_cnn_adv_trained_fgsm.pt")

# 加载对抗训练后的模型
model_adv_trained.load_state_dict(torch.load("mnist_cnn_adv_trained_fgsm.pt", map_location=device))
model_adv_trained.eval()

print("\n--- Evaluating Standard Model (Original Accuracy) ---")
test_fgsm(model, device, test_loader, 0) # Epsilon=0 -> 原始准确率

print("\n--- Evaluating Standard Model against FGSM (eps=0.15) ---")
test_fgsm(model, device, test_loader, 0.15)

print("\n--- Evaluating Adversarially Trained Model (Original Accuracy) ---")
test_fgsm(model_adv_trained, device, test_loader, 0) # Epsilon=0

print("\n--- Evaluating Adversarially Trained Model against FGSM (eps=0.15) ---")
test_fgsm(model_adv_trained, device, test_loader, 0.15)

print("\n--- Evaluating Standard Model against PGD (eps=0.1, alpha=0.01, iter=20) ---")
test_pgd(model, device, test_loader, epsilon_pgd, alpha_pgd, num_iter_pgd)

print("\n--- Evaluating Adversarially Trained Model against PGD (eps=0.1, alpha=0.01, iter=20) ---")
# 注意: 对抗训练时用的FGSM,这里用PGD测试,PGD通常更强
test_pgd(model_adv_trained, device, test_loader, epsilon_pgd, alpha_pgd, num_iter_pgd)

代码解释 (adv_train):

  1. 对于每个批次的数据,首先用当前模型状态生成对应的对抗样本(这里使用FGSM)。
    • data.requires_grad = True
    • 前向传播,计算损失,反向传播得到 data.grad.data
    • 调用 fgsm_attack 得到 perturbed_data
  2. 然后,使用这些 perturbed_data(而不是原始 data)来训练模型:
    • output_adv_train = model_adv_trained(perturbed_data)
    • loss_adv_train = criterion_adv(output_adv_train, target)
    • loss_adv_train.backward()
    • optimizer_adv.step()

预期结果:

  • 标准模型在干净数据上准确率高,但在对抗样本(FGSM或PGD)上准确率会显著下降。
  • 对抗训练后的模型:
    • 在干净数据上的准确率可能会略微下降(trade-off)。
    • 但在对抗样本上的准确率会显著高于标准模型,表明其鲁棒性增强。
    • 如果对抗训练时使用的是较弱的攻击(如FGSM),那么模型对更强的攻击(如PGD)的防御能力可能仍然有限。用PGD进行对抗训练通常能获得更好的鲁棒性,但训练成本也更高。
4.2 其他防御方法 (概念性)
  • 防御性蒸馏 (Defensive Distillation): 利用知识蒸馏的思想,训练一个模型使其输出更平滑的概率分布,从而增加对抗攻击的难度。但后续研究表明,这种方法对某些攻击的防御效果有限。
  • 梯度掩蔽/混淆 (Gradient Masking/Obfuscation): 一些防御方法通过使得梯度难以计算或梯度信息不准确来防御基于梯度的攻击。例如,随机化输入、随机化网络层。然而,这些方法通常会被适应性攻击(专门针对该防御设计的攻击)攻破。
  • 输入变换 (Input Transformation): 在将输入送入模型前对其进行变换,如JPEG压缩、总变差最小化、图像超分辨率等,试图消除对抗扰动。
  • 认证防御 (Certified Defenses): 这类方法旨在提供可证明的鲁棒性保证,即在某个 ϵ \epsilon ϵ-邻域内,模型的预测不会改变。通常计算成本较高,且能保证的 ϵ \epsilon ϵ 范围有限。

5. 对抗学习的应用与影响

  • 模型鲁棒性评估: 对抗攻击是评估模型在最坏情况下表现的重要工具。
  • 提升模型泛化能力: 有研究表明对抗训练有时也能提升模型在干净数据上的泛化能力。
  • 数据增强: 对抗样本可以看作是一种特殊的数据增强方式。
  • 安全性: 对于自动驾驶、人脸识别支付、医疗AI等安全攸关领域,对抗鲁棒性至关重要。
  • 可解释性: 对抗样本的研究有助于我们理解模型学到了什么,以及它们的决策边界。

6. 挑战与未来方向

  • 强大的自适应攻击: 许多防御方法最终被更强的、专门设计的自适应攻击攻破。
  • 鲁棒性与准确性的权衡: 提高鲁棒性往往以牺牲在干净数据上的准确性为代价。
  • 计算成本: 强大的对抗训练(如基于PGD的)计算成本很高。
  • 可迁移性: 在一个模型上生成的对抗样本,有时也能欺骗其他不同结构的模型(黑盒攻击的基础)。
  • 大规模、复杂数据集的鲁棒性: 目前多数研究集中在MNIST、CIFAR等小数据集上,对ImageNet等大规模数据集的鲁棒性研究仍具挑战。
  • 超越像素空间的扰动: 研究物理世界的对抗攻击(如打印对抗贴纸)、语义层面的对抗攻击等。

7. 总结

对抗学习是一个充满活力且快速发展的研究领域。它不仅揭示了当前深度学习模型的脆弱性,也推动了对更安全、更可靠、更可信AI系统的探索。通过理解对抗攻击的原理和实践防御方法,我们可以逐步构建出更经得起考验的智能系统。本文提供的代码示例希望能帮助你踏出探索对抗学习的第一步。


希望这篇文章对您有所帮助!它涵盖了对抗学习的核心概念、FGSM和PGD攻击方法(含代码)、对抗训练防御方法(含代码),以及相关的讨论。您可以直接运行这些PyTorch代码来体验对抗攻防的过程。

相关推荐
西岸行者5 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意5 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码5 天前
嵌入式学习路线
学习
毛小茛5 天前
计算机系统概论——校验码
学习
babe小鑫5 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms5 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下5 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。5 天前
2026.2.25监控学习
学习
im_AMBER5 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J5 天前
从“Hello World“ 开始 C++
c语言·c++·学习