CIFAR10 数据不平衡代码制作与理解

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) # 生成每个类别的样本数量列表。
        print("Generated Image Number List:", img_num_list)
        self.gen_imbalanced_data(img_num_list) # 调用 gen_imbalanced_data 方法,生成不平衡数据集。
        print("Imbalanced Data Generated.")

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0] # 找到所有属于当前类别的样本的索引,并存储在 idx 中。
            np.random.shuffle(idx) # 随机打乱当前类别的样本索引顺序,以确保样本的随机性。
            selec_idx = idx[:the_img_num]  # 从当前类别的样本索引中选择前 the_img_num 个索引,即根据不平衡设定的数量,选择少数类别的样本。
            new_data.append(self.data[selec_idx, ...]) # 
            new_targets.extend([the_class, ] * the_img_num) #将当前类别的标签复制 the_img_num 次,并将复制的标签添加到新标签列表 new_targets 中,确保标签与样本数据对应。
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets
        
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Example usage:
imbalance_cifar10 = IMBALANCECIFAR10(root='../data', imb_type='exp', imb_factor=0.01, rand_number=0, train=True, download=True)

print("Class-wise Sample Numbers:", imbalance_cifar10.get_cls_num_list())

val_dataset = datasets.CIFAR10(root='../data', train=False, download=True,transform=transform_val)
original_targets = np.array(val_dataset.targets)
original_class_counts = np.bincount(original_targets)
print("Original CIFAR-10 Class-wise Sample Numbers:", original_class_counts)
相关推荐
中医正骨葛大夫2 小时前
一文解决如何在Pycharm中创建cuda深度学习环境?
pytorch·深度学习·pycharm·软件安装·cuda·anaconda·配置环境
龙腾AI白云2 小时前
具身智能-高层任务规划(High-level Task Planning)
深度学习·数据挖掘
WWZZ20253 小时前
快速上手大模型:深度学习9(池化层、卷积神经网络1)
人工智能·深度学习·神经网络·算法·机器人·大模型·具身智能
AI即插即用5 小时前
即插即用系列 | 2025 SOTA Strip R-CNN 实战解析:用于遥感目标检测的大条带卷积
人工智能·pytorch·深度学习·目标检测·计算机视觉·cnn·智慧城市
IT油腻大叔6 小时前
DeepSeek-多层注意力计算机制理解
python·深度学习·机器学习
九年义务漏网鲨鱼6 小时前
【多模态大模型面经】现代大模型架构(一): 组注意力机制(GQA)和 RMSNorm
人工智能·深度学习·算法·架构·大模型·强化学习
O***p6046 小时前
机器学习挑战同时也带来了一系列亟待解决的问题。
人工智能·深度学习·机器学习
B站_计算机毕业设计之家7 小时前
python手写数字识别系统 CNN算法 卷积神经网络 OpenCV和Keras模型 计算机视觉 (建议收藏)✅
python·深度学习·opencv·机器学习·计算机视觉·cnn
Valueyou247 小时前
引入基于加权 IoU 的 WiseIoU 回归损失以提升 CT 图像检测鲁棒性
人工智能·python·深度学习·目标检测
这张生成的图像能检测吗7 小时前
(论文速读)SpiralMLP:一个轻量级的视觉MLP架构
图像处理·人工智能·深度学习·计算机视觉·mlp框架·分类、检测、分割