CIFAR10 数据不平衡代码制作与理解

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) # 生成每个类别的样本数量列表。
        print("Generated Image Number List:", img_num_list)
        self.gen_imbalanced_data(img_num_list) # 调用 gen_imbalanced_data 方法,生成不平衡数据集。
        print("Imbalanced Data Generated.")

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0] # 找到所有属于当前类别的样本的索引,并存储在 idx 中。
            np.random.shuffle(idx) # 随机打乱当前类别的样本索引顺序,以确保样本的随机性。
            selec_idx = idx[:the_img_num]  # 从当前类别的样本索引中选择前 the_img_num 个索引,即根据不平衡设定的数量,选择少数类别的样本。
            new_data.append(self.data[selec_idx, ...]) # 
            new_targets.extend([the_class, ] * the_img_num) #将当前类别的标签复制 the_img_num 次,并将复制的标签添加到新标签列表 new_targets 中,确保标签与样本数据对应。
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets
        
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Example usage:
imbalance_cifar10 = IMBALANCECIFAR10(root='../data', imb_type='exp', imb_factor=0.01, rand_number=0, train=True, download=True)

print("Class-wise Sample Numbers:", imbalance_cifar10.get_cls_num_list())

val_dataset = datasets.CIFAR10(root='../data', train=False, download=True,transform=transform_val)
original_targets = np.array(val_dataset.targets)
original_class_counts = np.bincount(original_targets)
print("Original CIFAR-10 Class-wise Sample Numbers:", original_class_counts)
相关推荐
Blossom.1183 小时前
基于深度学习的图像分割:使用DeepLabv3实现高效分割
人工智能·python·深度学习·机器学习·分类·机器人·transformer
zzywxc7878 小时前
AI 驱动的软件测试革新:框架、检测与优化实践
人工智能·深度学习·机器学习·数据挖掘·数据分析
Ronin-Lotus9 小时前
深度学习篇---PaddleDetection模型选择
人工智能·深度学习
Blossom.1189 小时前
基于深度学习的医学图像分析:使用CycleGAN实现图像到图像的转换
人工智能·深度学习·目标检测·机器学习·分类·数据挖掘·语音识别
CoovallyAIHub13 小时前
无人机图像+深度学习:湖南农大团队实现稻瘟病分级检测84%准确率
深度学习·算法·计算机视觉
TiAmo zhang13 小时前
深度学习与图像处理案例 │ 图像分类(智能垃圾分拣器)
图像处理·深度学习·分类
zzywxc78715 小时前
随着人工智能技术的飞速发展,大语言模型(Large Language Models, LLMs)已经成为当前AI领域最引人注目的技术突破。
人工智能·深度学习·算法·低代码·机器学习·自动化·排序算法
网安INF16 小时前
【论文阅读】-《RayS: A Ray Searching Method for Hard-label Adversarial Attack》
论文阅读·人工智能·深度学习·计算机视觉·网络安全·对抗攻击
F_D_Z16 小时前
数据集相关类代码回顾理解 | DataLoader\datasets.xxx
python·深度学习
盼小辉丶17 小时前
生成模型实战 | GLOW详解与实现
深度学习·aigc·生成模型