CIFAR10 数据不平衡代码制作与理解

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) # 生成每个类别的样本数量列表。
        print("Generated Image Number List:", img_num_list)
        self.gen_imbalanced_data(img_num_list) # 调用 gen_imbalanced_data 方法,生成不平衡数据集。
        print("Imbalanced Data Generated.")

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0] # 找到所有属于当前类别的样本的索引,并存储在 idx 中。
            np.random.shuffle(idx) # 随机打乱当前类别的样本索引顺序,以确保样本的随机性。
            selec_idx = idx[:the_img_num]  # 从当前类别的样本索引中选择前 the_img_num 个索引,即根据不平衡设定的数量,选择少数类别的样本。
            new_data.append(self.data[selec_idx, ...]) # 
            new_targets.extend([the_class, ] * the_img_num) #将当前类别的标签复制 the_img_num 次,并将复制的标签添加到新标签列表 new_targets 中,确保标签与样本数据对应。
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets
        
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Example usage:
imbalance_cifar10 = IMBALANCECIFAR10(root='../data', imb_type='exp', imb_factor=0.01, rand_number=0, train=True, download=True)

print("Class-wise Sample Numbers:", imbalance_cifar10.get_cls_num_list())

val_dataset = datasets.CIFAR10(root='../data', train=False, download=True,transform=transform_val)
original_targets = np.array(val_dataset.targets)
original_class_counts = np.bincount(original_targets)
print("Original CIFAR-10 Class-wise Sample Numbers:", original_class_counts)
相关推荐
深度学习实战训练营1 小时前
基于keras的停车场车位识别
人工智能·深度学习·keras
菜就多练_08282 小时前
《深度学习》OpenCV 摄像头OCR 过程及案例解析
人工智能·深度学习·opencv·ocr
没有余地 EliasJie2 小时前
Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT
pytorch·windows·深度学习·ubuntu·pycharm·conda·tensorflow
技术无疆3 小时前
【Python】Streamlit:为数据科学与机器学习打造的简易应用框架
开发语言·人工智能·python·深度学习·神经网络·机器学习·数据挖掘
浊酒南街3 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)2.7-2.8
人工智能·深度学习·神经网络
被制作时长两年半的个人练习生4 小时前
【pytorch】权重为0的情况
人工智能·pytorch·深度学习
xiandong2011 小时前
240929-CGAN条件生成对抗网络
图像处理·人工智能·深度学习·神经网络·生成对抗网络·计算机视觉
innutritious12 小时前
车辆重识别(2020NIPS去噪扩散概率模型)论文阅读2024/9/27
人工智能·深度学习·计算机视觉
醒了就刷牙13 小时前
56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版
pytorch·深度学习·gru
橙子小哥的代码世界13 小时前
【深度学习】05-RNN循环神经网络-02- RNN循环神经网络的发展历史与演化趋势/LSTM/GRU/Transformer
人工智能·pytorch·rnn·深度学习·神经网络·lstm·transformer