CIFAR10 数据不平衡代码制作与理解

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) # 生成每个类别的样本数量列表。
        print("Generated Image Number List:", img_num_list)
        self.gen_imbalanced_data(img_num_list) # 调用 gen_imbalanced_data 方法,生成不平衡数据集。
        print("Imbalanced Data Generated.")

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0] # 找到所有属于当前类别的样本的索引,并存储在 idx 中。
            np.random.shuffle(idx) # 随机打乱当前类别的样本索引顺序,以确保样本的随机性。
            selec_idx = idx[:the_img_num]  # 从当前类别的样本索引中选择前 the_img_num 个索引,即根据不平衡设定的数量,选择少数类别的样本。
            new_data.append(self.data[selec_idx, ...]) # 
            new_targets.extend([the_class, ] * the_img_num) #将当前类别的标签复制 the_img_num 次,并将复制的标签添加到新标签列表 new_targets 中,确保标签与样本数据对应。
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets
        
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Example usage:
imbalance_cifar10 = IMBALANCECIFAR10(root='../data', imb_type='exp', imb_factor=0.01, rand_number=0, train=True, download=True)

print("Class-wise Sample Numbers:", imbalance_cifar10.get_cls_num_list())

val_dataset = datasets.CIFAR10(root='../data', train=False, download=True,transform=transform_val)
original_targets = np.array(val_dataset.targets)
original_class_counts = np.bincount(original_targets)
print("Original CIFAR-10 Class-wise Sample Numbers:", original_class_counts)
相关推荐
Fansv5871 小时前
深度学习-6.用于计算机视觉的深度学习
人工智能·深度学习·计算机视觉
deephub2 小时前
LLM高效推理:KV缓存与分页注意力机制深度解析
人工智能·深度学习·语言模型
奋斗的袍子0072 小时前
Spring AI + Ollama 实现调用DeepSeek-R1模型API
人工智能·spring boot·深度学习·spring·springai·deepseek
青衫弦语2 小时前
【论文精读】VLM-AD:通过视觉-语言模型监督实现端到端自动驾驶
人工智能·深度学习·语言模型·自然语言处理·自动驾驶
美狐美颜sdk2 小时前
直播美颜SDK的底层技术解析:图像处理与深度学习的结合
图像处理·人工智能·深度学习·直播美颜sdk·视频美颜sdk·美颜api·滤镜sdk
WHATEVER_LEO2 小时前
【每日论文】Text-guided Sparse Voxel Pruning for Efficient 3D Visual Grounding
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理
Binary Oracle3 小时前
RNN中远距离时间步梯度消失问题及解决办法
人工智能·rnn·深度学习
阿_旭3 小时前
基于YOLO11深度学习的糖尿病视网膜病变检测与诊断系统【python源码+Pyqt5界面+数据集+训练代码】
人工智能·python·深度学习·视网膜病变检测
小宇爱3 小时前
38、深度学习-自学之路-自己搭建深度学习框架-3、自动梯度计算改进
人工智能·深度学习·自然语言处理
小白狮ww6 小时前
国产超强开源大语言模型 DeepSeek-R1-70B 一键部署教程
人工智能·深度学习·机器学习·语言模型·自然语言处理·开源·deepseek