【对抗算法复现】CW

首先进行数据的预处理

python 复制代码
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为Tensor,自动将[0,255]映射到[0,1]
    transforms.Normalize((0.491,0.482 ,0.446), (0.247 ,0.243 ,0.261))  # 对张量进行标准化,使其范围为[-1,1]
])

CW实现

python 复制代码
def cw_l2_attack(model,
                 images,
                 labels,
                 targeted=True,
                 c=0.1,
                 kappa=0,
                 max_iter=1000,
                 learning_rate=0.01):

    # 计算损失函数,根据模型输出和目标标签计算一个分数,衡量模型输出的误导程度
    def f(x):
        # 论文中的 Z(X) 输出 batchsize, num_classes
        outputs = model(x)
        #将标签转换为one-hot编码形式
        one_hot_labels = torch.eye(len(outputs[0]),device=device)[labels].to(device)

        # 水平方向最大的取值,忽略索引。意思是,除去真实标签,看看每个 batchsize 中哪个标签的概率最大,取出概率
        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)
        # 选择真实标签的概率
        j = torch.masked_select(outputs, one_hot_labels.bool())

        # 如果有攻击目标,虚假概率减去真实概率,
        if targeted:#使模型对目标错误类别的置信度至少比真实类别高 kappa。
            return torch.clamp(i - j, min=-kappa)
        # 没有攻击目标,就让真实的概率小于虚假的概率,逐步降低,也就是最小化这个损失
        else:#降低模型对真实类别的置信度,使其至少低于虚假类别的概率 -kappa。
            return torch.clamp(j - i, min=-kappa)

    w = torch.zeros_like(images, requires_grad=True).to(device)#一个与输入图像相同大小的张量,用于存储对抗扰动,并设置True以便后续梯度下降
    optimizer = optim.Adam([w], lr=learning_rate)#定义优化器
    #
    prev = 1e10

    for step in range(max_iter):
        a = 1 / 2 * (nn.Tanh()(w) + 1)#扰动应用到图像上,对抗图像
        # 第一个目标,对抗样本与原始样本足够接近
        loss1 = nn.MSELoss(reduction='sum')(a, images)
        # 第二个目标,误导模型输出
        loss2 = torch.sum(c * f(a))

        cost = loss1 + loss2
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # 早停策略 如果连续迭代中损失没有改善,则提前停止攻击
        if step % (max_iter // 10) == 0:
            if cost > prev:
                print('Attack Stopped due to CONVERGENCE....')
                return a
            prev = cost

    attack_images = 1 / 2 * (nn.Tanh()(w) + 1)#最终的对抗图像

    return attack_images

3.导入模型

python 复制代码
print('load model')
model = ResNet50()
pth_file = '../checkpoint/resnet50_ckpt.pth'
d = torch.load(pth_file)['net']
d = OrderedDict([(k[7:], v) for (k, v) in d.items()])
model.load_state_dict(d)
model.to(device)
model.eval()

完整代码

python 复制代码
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import os
from tqdm import tqdm
from collections import OrderedDict
from resnet import ResNet50

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
device = torch.device("cuda")
#数据预处理

# 图像预处理操作定义
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为Tensor,自动将[0,255]映射到[0,1]
    transforms.Normalize((0.491,0.482 ,0.446), (0.247 ,0.243 ,0.261))  # 对张量进行标准化,使其范围为[-1,1]
])
class CIFAR10Dataset(Dataset):
    """CIFAR-10数据集加载类,支持图像转换操作"""

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """返回数据集中的图像总数"""
        return len(self.data)

    def __getitem__(self, idx):
        """获取单个图像及其标签,并应用预定义的转换"""
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

#CW aTTACK
targeted=3
def cw_l2_attack(model,
                 images,
                 labels,
                 targeted=True,
                 c=0.1,
                 kappa=0,
                 max_iter=1000,
                 learning_rate=0.01):

    # 计算损失函数,根据模型输出和目标标签计算一个分数,衡量模型输出的误导程度
    def f(x):
        # 论文中的 Z(X) 输出 batchsize, num_classes
        outputs = model(x)
        # batchszie,根据labels 的取值,确定每一行哪一个为 1
        # >>> a = torch.eye(10)[[2, 3]]
        # >>> a
        # tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        # [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])
        #将标签转换为one-hot编码形式
        one_hot_labels = torch.eye(len(outputs[0]),device=device)[labels].to(device)

        # 水平方向最大的取值,忽略索引。意思是,除去真实标签,看看每个 batchsize 中哪个标签的概率最大,取出概率
        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)
        # 选择真实标签的概率
        j = torch.masked_select(outputs, one_hot_labels.bool())

        # 如果有攻击目标,虚假概率减去真实概率,
        if targeted:#使模型对目标错误类别的置信度至少比真实类别高 kappa。
            return torch.clamp(i - j, min=-kappa)
        # 没有攻击目标,就让真实的概率小于虚假的概率,逐步降低,也就是最小化这个损失
        else:#降低模型对真实类别的置信度,使其至少低于虚假类别的概率 -kappa。
            return torch.clamp(j - i, min=-kappa)

    w = torch.zeros_like(images, requires_grad=True).to(device)#一个与输入图像相同大小的张量,用于存储对抗扰动,并设置True以便后续梯度下降
    optimizer = optim.Adam([w], lr=learning_rate)#定义优化器
    #
    prev = 1e10

    for step in range(max_iter):
        a = 1 / 2 * (nn.Tanh()(w) + 1)#扰动应用到图像上,对抗图像
        # 第一个目标,对抗样本与原始样本足够接近
        loss1 = nn.MSELoss(reduction='sum')(a, images)
        # 第二个目标,误导模型输出
        loss2 = torch.sum(c * f(a))

        cost = loss1 + loss2
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # 早停策略 如果连续迭代中损失没有改善,则提前停止攻击
        if step % (max_iter // 10) == 0:
            if cost > prev:
                print('Attack Stopped due to CONVERGENCE....')
                return a
            prev = cost

    attack_images = 1 / 2 * (nn.Tanh()(w) + 1)#最终的对抗图像

    return attack_images


print('load model')
model = ResNet50()
pth_file = '../checkpoint/resnet50_ckpt.pth'
d = torch.load(pth_file)['net']
d = OrderedDict([(k[7:], v) for (k, v) in d.items()])
model.load_state_dict(d)
model.to(device)
model.eval()

# 加载处理好的数据
test_data = np.load('../data/test_data.npy')
test_label = np.load('../data/test_labels.npy')

# 实例化数据集
testset = CIFAR10Dataset(data=test_data, labels=test_label, transform=transform)

# 创建数据加载器
testloader = DataLoader(testset, batch_size=200, shuffle=False)

# CIFAR-10的类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

save_data=None
print("cw attack...")
for data, label in tqdm(testloader):
    data, label = data.to(device), label.to(device)

    adv_data = cw_l2_attack(model=model, images=data, labels=label)

    if save_data is None:
        save_data = adv_data.detach_().cpu().numpy()
    else:
        save_data = np.concatenate(
            (save_data,adv_data.detach_().cpu().numpy()), axis=0)

# 定义保存文件的路径
file_path = 'result/cw2_cifar10_2.npy'

# 确保目录存在
if not os.path.exists(os.path.dirname(file_path)):
    os.makedirs(os.path.dirname(file_path))

# 现在可以安全地保存文件
np.save(file_path, save_data)
print('cw2_cifar10_2 has been saved')
相关推荐
Aczone2812 小时前
硬件(六)arm指令
开发语言·汇编·arm开发·嵌入式硬件·算法
luckys.one16 小时前
第9篇:Freqtrade量化交易之config.json 基础入门与初始化
javascript·数据库·python·mysql·算法·json·区块链
~|Bernard|18 小时前
在 PyCharm 里怎么“点鼠标”完成指令同样的运行操作
算法·conda
战术摸鱼大师18 小时前
电机控制(四)-级联PID控制器与参数整定(MATLAB&Simulink)
算法·matlab·运动控制·电机控制
Christo318 小时前
TFS-2018《On the convergence of the sparse possibilistic c-means algorithm》
人工智能·算法·机器学习·数据挖掘
好家伙VCC19 小时前
数学建模模型 全网最全 数学建模常见算法汇总 含代码分析讲解
大数据·嵌入式硬件·算法·数学建模
liulilittle20 小时前
IP校验和算法:从网络协议到SIMD深度优化
网络·c++·网络协议·tcp/ip·算法·ip·通信
bkspiderx1 天前
C++经典的数据结构与算法之经典算法思想:贪心算法(Greedy)
数据结构·c++·算法·贪心算法
中华小当家呐1 天前
算法之常见八大排序
数据结构·算法·排序算法
沐怡旸1 天前
【算法--链表】114.二叉树展开为链表--通俗讲解
算法·面试