基于cifar数据集合成含开集、闭集噪声的数据集

前言

噪声标签学习下的一个任务是:训练集上存在开集噪声和闭集噪声;然后在测试集上对闭集样本进行分类。

训练集中被加入的开集样本,会被均匀得打上闭集样本的标签充当开集噪声;而闭集噪声的设置与一般的噪声标签学习一致,分为对称噪声:随机将闭集样本的标签替换为其他类别;和非对称噪声:将闭集样本的标签替换为特定的类别。

论文实验中,常用cifar数据集模拟这类任务。目前已知有两类方法:

  • 第一类基于cifar100,将100个类的一部分,通常是20个类作为开集样本,将它们标签替换了前80个类作为开集噪声;然后对于后续80个类,选择部分样本设置为对称/非对称闭集噪声。CVPR2022的PNP: Robust Learning From Noisy Labels by Probabilistic Noise Prediction提供的代码中,使用了这种方法。但是,如果要考虑非对称噪声,在cifar10上就很难实现,cifar10的类的顺序不像cifar100那样有规律,不好设置闭集噪声。

  • 第二类方法适用cifar10和cifar100,保持原始数据集的样本数不变,使用额外的数据集(通常是imagenet32、places365)代替部分样本作为开集噪声,对于剩下的非开集噪声样本再设置闭集噪声。ECCV2022的Embedding contrastive unsupervised features to cluster in-and out-of-distribution noise in corrupted image datasets提供的代码使用了这种方式。

places365可以使用torchvision.datasets.Places365下载,由于训练集较大,通常是用它的验证集作为辅助数据集。

imagenet32是imagnet的32x32版本,同样是1k类,但是类的具体含义的顺序与imagenet不同,imagenet32类的具体含义可见这里。image32下载地址在对应论文A downsampled variant of imagenet as an alternative to the cifar datasets提供的链接

接下来是用第二种方法,辅助数据集使用imagenet32,基于cifar构造含开集闭集噪声的训练集。

实验

设计imagenet32数据集

python 复制代码
import os
import pickle
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

_train_list = ['train_data_batch_1',
               'train_data_batch_2',
               'train_data_batch_3',
               'train_data_batch_4',
               'train_data_batch_5',
               'train_data_batch_6',
               'train_data_batch_7',
               'train_data_batch_8',
               'train_data_batch_9',
               'train_data_batch_10']
_val_list = ['val_data']


def get_dataset(transform_train, transform_test):
    # prepare datasets

    # Train set
    train = Imagenet32(train=True, transform=transform_train)  # Load all 1000 classes in memory

    # Test set
    test = Imagenet32(train=False, transform=transform_test)  # Load all 1000 test classes in memory

    return train, test


class Imagenet32(Dataset):
    def __init__(self, root='~/data/imagenet32', train=True, transform=None):
        if root[0] == '~':
            root = os.path.expanduser(root)
        self.transform = transform
        size = 32
        # Now load the picked numpy arrays

        if train:
            data, labels = [], []

            for f in _train_list:
                file = os.path.join(root, f)

                with open(file, 'rb') as fo:
                    entry = pickle.load(fo, encoding='latin1')
                    data.append(entry['data'])
                    labels += entry['labels']
            data = np.concatenate(data)

        else:
            f = _val_list[0]
            file = os.path.join(root, f)
            with open(file, 'rb') as fo:
                entry = pickle.load(fo, encoding='latin1')
                data = entry['data']
                labels = entry['labels']

        data = data.reshape((-1, 3, size, size))
        self.data = data.transpose((0, 2, 3, 1))  # Convert to HWC
        labels = np.array(labels) - 1
        self.labels = labels.tolist()

    def __getitem__(self, index):

        img, target = self.data[index], self.labels[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target, index

    def __len__(self):
        return len(self.data)

目录结构:

txt 复制代码
imagenet32
├─ train_data_batch_1
├─ train_data_batch_10
├─ train_data_batch_2
├─ train_data_batch_3
├─ train_data_batch_4
├─ train_data_batch_5
├─ train_data_batch_6
├─ train_data_batch_7
├─ train_data_batch_8
├─ train_data_batch_9
└─ val_data

设计cifar数据集

python 复制代码
import torchvision
import numpy as np
from dataset.imagenet32 import Imagenet32


class CIFAR10(torchvision.datasets.CIFAR10):

    def __init__(self, root='~/data', train=True, transform=None,
                 r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):
        nb_classes = 10
        self.nb_classes = nb_classes
        super().__init__(root, train=train, transform=transform)
        if train is False:
            return
        np.random.seed(seed)
        if r_ood > 0.:
            ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]
            if corruption == 'imagenet':
                imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)
                img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]
            else:
                raise ValueError(f'Unknown corruption: {corruption}')
            self.ids_ood = ids_ood
            self.data[ids_ood] = img_ood

        if r_id > 0.:
            ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]
            ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]
            for i, t in enumerate(self.targets):
                if i in ids_id:
                    self.targets[i] = int(np.random.random() * nb_classes)
            self.ids_id = ids_id


class CIFAR100(torchvision.datasets.CIFAR100):

    def __init__(self, root='~/data', train=True, transform=None,
                 r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):
        nb_classes = 100
        self.nb_classes = nb_classes
        super().__init__(root, train=train, transform=transform)
        if train is False:
            return
        np.random.seed(seed)
        if r_ood > 0.:
            ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]
            if corruption == 'imagenet':
                imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)
                img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]
            else:
                raise ValueError(f'Unknown corruption: {corruption}')
            self.ids_ood = ids_ood
            self.data[ids_ood] = img_ood

        if r_id > 0.:
            ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]
            ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]
            for i, t in enumerate(self.targets):
                if i in ids_id:
                    self.targets[i] = int(np.random.random() * nb_classes)
            self.ids_id = ids_id

查看统计结果

python 复制代码
import pandas as pd
import altair as alt
from dataset.cifar import CIFAR10, CIFAR100

# Initialize CIFAR10 dataset
cifar10 = CIFAR10()
cifar100 = CIFAR100()


def statistics_samples(dataset):
    ids_ood = dataset.ids_ood
    ids_id = dataset.ids_id

    # Collect statistics
    statistics = []
    for i in range(dataset.nb_classes):
        statistics.append({
            'class': i,
            'id': 0,
            'ood': 0,
            'clear': 0
        })

    for i, t in enumerate(dataset.targets):
        if i in ids_ood:
            statistics[t]['ood'] += 1
        elif i in ids_id:
            statistics[t]['id'] += 1
        else:
            statistics[t]['clear'] += 1

    df = pd.DataFrame(statistics)

    # Melt the DataFrame for Altair
    df_melt = df.melt(id_vars='class', var_name='type', value_name='count')

    # Create the bar chart
    chart = alt.Chart(df_melt).mark_bar().encode(
        x=alt.X('class:O', title='Classes'),
        y=alt.Y('count:Q', title='Sample Count'),
        color='type:N'
    )
    return chart


chart1 = statistics_samples(cifar10)
chart2 = statistics_samples(cifar100)
chart1 = chart1.properties(
    title='cifar10',
    width=100,  # Adjust width to fit both charts side by side
    height=400
)
chart2 = chart2.properties(
    title='cifar100',
    width=800,
    height=400
)
combined_chart = alt.hconcat(chart1, chart2).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_legend(
    titleFontSize=14,
    labelFontSize=12
)
combined_chart

运行环境

txt 复制代码
# Name                    Version                   Build  Channel
altair                    5.3.0                    pypi_0    pypi
pytorch                   2.3.1           py3.12_cuda12.1_cudnn8_0    pytorch
pandas                    2.2.2                    pypi_0    pypi
相关推荐
盛世隐者26 分钟前
【pytorch】循环神经网络
人工智能·pytorch
四口鲸鱼爱吃盐2 小时前
Pytorch | 利用AI-FGTM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
唐小旭16 小时前
python3.6搭建pytorch环境
人工智能·pytorch·python
四口鲸鱼爱吃盐18 小时前
Pytorch | 从零构建ParNet/Non-Deep Networks对CIFAR10进行分类
人工智能·pytorch·分类
四口鲸鱼爱吃盐1 天前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
leaf_leaves_leaf2 天前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零12 天前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
四口鲸鱼爱吃盐2 天前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗2 天前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
四口鲸鱼爱吃盐2 天前
Pytorch | 利用VMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python