CIFAR10 数据不平衡代码制作与理解

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        np.random.seed(rand_number)
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) # 生成每个类别的样本数量列表。
        print("Generated Image Number List:", img_num_list)
        self.gen_imbalanced_data(img_num_list) # 调用 gen_imbalanced_data 方法,生成不平衡数据集。
        print("Imbalanced Data Generated.")

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0] # 找到所有属于当前类别的样本的索引,并存储在 idx 中。
            np.random.shuffle(idx) # 随机打乱当前类别的样本索引顺序,以确保样本的随机性。
            selec_idx = idx[:the_img_num]  # 从当前类别的样本索引中选择前 the_img_num 个索引,即根据不平衡设定的数量,选择少数类别的样本。
            new_data.append(self.data[selec_idx, ...]) # 
            new_targets.extend([the_class, ] * the_img_num) #将当前类别的标签复制 the_img_num 次,并将复制的标签添加到新标签列表 new_targets 中,确保标签与样本数据对应。
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets
        
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Example usage:
imbalance_cifar10 = IMBALANCECIFAR10(root='../data', imb_type='exp', imb_factor=0.01, rand_number=0, train=True, download=True)

print("Class-wise Sample Numbers:", imbalance_cifar10.get_cls_num_list())

val_dataset = datasets.CIFAR10(root='../data', train=False, download=True,transform=transform_val)
original_targets = np.array(val_dataset.targets)
original_class_counts = np.bincount(original_targets)
print("Original CIFAR-10 Class-wise Sample Numbers:", original_class_counts)
相关推荐
HackTorjan11 分钟前
AI图像处理的核心原理:深度学习驱动的视觉特征提取与重构
图像处理·人工智能·深度学习·django·sqlite
AI机器学习算法6 小时前
深度学习模型演进:6个里程碑式CNN架构
人工智能·深度学习·cnn·大模型·ai学习路线
AI医影跨模态组学7 小时前
如何将深度学习MTSR与膀胱癌ITGB8/TGF-β/WNT机制建立关联,并进一步解释其与患者预后及肿瘤侵袭、免疫抑制的生物学联系
人工智能·深度学习·论文·医学影像
SomeB1oody9 小时前
【Python深度学习】3.4. 循环神经网络(RNN)实战:预测股价
开发语言·人工智能·python·rnn·深度学习·机器学习
ACCELERATOR_LLC9 小时前
【DataWhale组队学习】DIY-LLM Task2 PyTorch 与资源核算
人工智能·pytorch·深度学习·大模型
Theodore_102211 小时前
深度学习(15):倾斜数据集 & 精确率-召回率权衡
人工智能·笔记·深度学习·机器学习·知识图谱
li星野12 小时前
词嵌入技术、注意力机制、MoE架构、主流Transformer架构
深度学习·架构·transformer
Omics Pro15 小时前
华大等NC|微生物多样性与抗菌物质发现
大数据·人工智能·深度学习·语言模型·excel
在秃头的路上啊15 小时前
Cascade R50 + PointRend
深度学习
数智工坊15 小时前
R-CNN目标检测算法精读全解
网络·人工智能·深度学习·算法·目标检测·r语言·cnn