【图像分类实用脚本】数据可视化以及高数量类别截断

图像分类时,如果某个类别或者某些类别的数量远大于其他类别的话,模型在计算的时候,更倾向于拟合数量更多的类别;因此,观察类别数量以及对数据量多的类别进行截断是很有必要的。

1.准备数据

数据的格式为图像分类数据集格式,根目录下分为train和val文件夹,每个文件夹下以类别名命名的子文件夹:

bash 复制代码
.
├── ./datasets
│ ├── ./datasets/train/A
│ │ ├── ./datasets/train/A/1.jpg
│ │ ├── ./datasets/train/A/2.jpg
│ │ ├── ./datasets/train/A/3.jpg
│ │ ├── ...
│ ├── ./datasets/train/B
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ...
│ ├── ./datasets/val/A
│ │ ├── ./datasets/val/A/1.jpg
│ │ ├── ./datasets/val/A/2.jpg
│ │ ├── ./datasets/val/A/3.jpg
│ │ ├── ...
│ ├── ./datasets/val/B
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ...

2.查看数据分布

python 复制代码
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def count_images(directory, image_extensions):
    """
    统计每个子文件夹中的图像数量。

    :param directory: 主目录路径(train或val)
    :param image_extensions: 允许的图像文件扩展名元组
    :return: 一个字典,键为类别名,值为图像数量
    """
    counts = {}
    if not os.path.exists(directory):
        print(f"目录不存在: {directory}")
        return counts

    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if os.path.isdir(class_path):
            # 统计符合扩展名的文件数量
            image_count = sum(
                1 for file in os.listdir(class_path)
                if file.lower().endswith(image_extensions)
            )
            counts[class_name] = image_count
    return counts

def count_images_in_single_directory(directory, image_extensions):
    """
    统计单个目录下每个类别的图像数量。

    :param directory: 主目录路径
    :param image_extensions: 允许的图像文件扩展名元组
    :return: 一个字典,键为类别名,值为图像数量
    """
    counts = {}
    if not os.path.exists(directory):
        print(f"目录不存在: {directory}")
        return counts

    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if os.path.isdir(class_path):
            image_count = sum(
                1 for file in os.listdir(class_path)
                if file.lower().endswith(image_extensions)
            )
            counts[class_name] = image_count
    return counts

def autolabel(ax, rects):
    """
    在每个柱状图上方添加数值标签。

    :param ax: Matplotlib 的轴对象
    :param rects: 柱状图对象
    """
    for rect in rects:
        height = rect.get_height()
        ax.annotate(f'{height}',
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')

def plot_distribution(all_classes, train_values, val_values, output_path, has_val=False):
    """
    绘制并保存训练集和验证集中每个类别的图像数量分布柱状图。
    如果没有验证集数据,则只绘制训练集数据。

    :param all_classes: 所有类别名称列表
    :param train_values: 训练集中每个类别的图像数量列表
    :param val_values: 验证集中每个类别的图像数量列表(如果有的话)
    :param output_path: 保存图表的文件路径
    :param has_val: 是否包含验证集数据
    """
    x = np.arange(len(all_classes))  # 类别位置
    width = 0.35  # 柱状图的宽度

    fig, ax = plt.subplots(figsize=(12, 8))

    if has_val:
        rects1 = ax.bar(x - width/2, train_values, width, label='Train')
        rects2 = ax.bar(x + width/2, val_values, width, label='Validation')
    else:
        rects1 = ax.bar(x, train_values, width, label='Count')

    # 添加一些文本标签
    ax.set_xlabel('Category')
    ax.set_ylabel('Number of Images')
    title = 'Number of Images in Each Category for Train and Validation' if has_val else 'Number of Images in Each Category'
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(all_classes, rotation=45, ha='right')
    ax.legend() if has_val else ax.legend(['Count'])

    # 自动标注柱状图上的数值
    autolabel(ax, rects1)
    if has_val:
        autolabel(ax, rects2)

    fig.tight_layout()

    # 保存图表为图片文件
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"图表已保存到 {output_path}")

def compute_and_display_statistics(counts_dict, dataset_name, save_csv=False):
    """
    计算并展示统计数据,包括总图像数量、类别数量、平均每个类别的图像数量和类别占比。

    :param counts_dict: 类别名称与图像数量的字典
    :param dataset_name: 数据集名称(例如 'Train', 'Validation', 'Dataset')
    :param save_csv: 是否保存统计结果为 CSV 文件
    """
    total_images = sum(counts_dict.values())
    num_classes = len(counts_dict)
    avg_per_class = total_images / num_classes if num_classes > 0 else 0

    # 计算每个类别的占比
    category_proportions = {cls: (count / total_images * 100) if total_images > 0 else 0 
                            for cls, count in counts_dict.items()}

    # 创建 DataFrame
    df = pd.DataFrame({
        '类别名称': list(counts_dict.keys()),
        '图像数量': list(counts_dict.values()),
        '占比 (%)': [f"{prop:.2f}" for prop in category_proportions.values()]
    })

    # 排序 DataFrame 按图像数量降序
    df = df.sort_values(by='图像数量', ascending=False)

    print(f"\n===== {dataset_name} 数据统计 =====")
    print(df.to_string(index=False))

    print(f"总图像数量: {total_images}")
    print(f"类别数量: {num_classes}")
    print(f"平均每个类别的图像数量: {avg_per_class:.2f}")

    # 根据 save_csv 参数决定是否保存为 CSV 文件
    if save_csv:
        # 将数据集名称转换为小写并去除空格,以作为文件名的一部分
        sanitized_name = dataset_name.lower().replace(" ", "_").replace("(", "").replace(")", "")
        csv_filename = f"{sanitized_name}_statistics.csv"
        df.to_csv(csv_filename, index=False, encoding='utf-8-sig')
        print(f"统计表已保存为 {csv_filename}\n")

def main():
    # ================== 配置参数 ==================
    # 设置数据集的根目录路径
    dataset_root = 'datasets/device_cls_merge_manual_with_21w_1218'  # 替换为你的数据集路径

    # 定义train和val目录
    train_dir = os.path.join(dataset_root, 'train')
    val_dir = os.path.join(dataset_root, 'val')

    # 定义允许的图像文件扩展名
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

    # 输出图表的路径
    output_path = 'dataset_distribution.png'  # 你可以更改为你想要的文件名和路径

    # 是否保存统计结果为 CSV 文件(默认不保存)
    SAVE_CSV = False  # 设置为 True 以启用保存 CSV

    # ================== 统计图像数量 ==================
    has_train = os.path.exists(train_dir) and os.path.isdir(train_dir)
    has_val = os.path.exists(val_dir) and os.path.isdir(val_dir)

    if has_train and has_val:
        print("检测到 'train' 和 'val' 目录。统计训练集和验证集中的图像数量...")
        train_counts = count_images(train_dir, image_extensions)
        val_counts = count_images(val_dir, image_extensions)

        # 获取所有类别的名称(确保train和val中的类别一致)
        all_classes = sorted(list(set(train_counts.keys()) | set(val_counts.keys())))

        # 准备绘图数据
        train_values = [train_counts.get(cls, 0) for cls in all_classes]
        val_values = [val_counts.get(cls, 0) for cls in all_classes]

        # ================== 计算并展示统计数据 ==================
        compute_and_display_statistics(train_counts, '训练集 (Train)', save_csv=SAVE_CSV)
        compute_and_display_statistics(val_counts, '验证集 (Validation)', save_csv=SAVE_CSV)

        # ================== 绘制并保存图表 ==================
        print("绘制并保存训练集和验证集的图表...")
        plot_distribution(all_classes, train_values, val_values, output_path, has_val=True)

    else:
        print("未检测到 'train' 和 'val' 目录。将统计主目录下的图像数量...")
        # 如果没有train和val目录,则统计主目录下的图像分布
        main_counts = count_images_in_single_directory(dataset_root, image_extensions)

        # 获取所有类别的名称
        all_classes = sorted(main_counts.keys())

        # 准备绘图数据
        main_values = [main_counts.get(cls, 0) for cls in all_classes]

        # 定义输出图表路径(可以区分不同的输出文件名)
        output_path_single = 'dataset_distribution_single.png'  # 或者使用与train_val相同的output_path

        # ================== 计算并展示统计数据 ==================
        compute_and_display_statistics(main_counts, '数据集 (Dataset)', save_csv=SAVE_CSV)

        # ================== 绘制并保存图表 ==================
        print("绘制并保存主目录的图表...")
        plot_distribution(all_classes, main_values, [], output_path_single, has_val=False)

if __name__ == "__main__":
    main()

下图为原始数据集运行结果,可以看到数据存在严重不均衡问题

3.数据截断

python 复制代码
import os
import shutil
import random


def count_images(directory, image_extensions):
    """
    统计每个子文件夹中的图像文件路径列表。

    :param directory: 主目录路径(train或val)
    :param image_extensions: 允许的图像文件扩展名列表
    :return: 一个字典,键为类别名,值为图像文件路径列表
    """
    counts = {}
    if not os.path.exists(directory):
        print(f"目录不存在: {directory}")
        return counts

    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if os.path.isdir(class_path):
            # 获取符合扩展名的文件列表
            images = [
                file for file in os.listdir(class_path)
                if file.lower().endswith(tuple(image_extensions))
            ]
            image_paths = [os.path.join(class_path, img) for img in images]
            counts[class_name] = image_paths
    return counts


def truncate_dataset(class_images, threshold, seed=42):
    """
    对每个类别的图像进行截断,如果超过阈值则随机选择一定数量的图像。

    :param class_images: 一个字典,键为类别名,值为图像文件路径列表
    :param threshold: 每个类别的图像数量阈值
    :param seed: 随机种子
    :return: 截断后的类别图像字典
    """
    truncated = {}
    random.seed(seed)
    for class_name, images in class_images.items():
        if len(images) > threshold:
            truncated_images = random.sample(images, threshold)
            truncated[class_name] = truncated_images
            print(f"类别 '{class_name}' 超过阈值 {threshold},已随机选择 {threshold} 张图像。")
        else:
            truncated[class_name] = images
            print(f"类别 '{class_name}' 不超过阈值 {threshold},保留所有 {len(images)} 张图像。")
    return truncated


def copy_images(truncated_data, subset, output_root):
    """
    将截断后的图像复制到输出目录,保持原有的目录结构。

    :param truncated_data: 截断后的类别图像字典
    :param subset: 'train' 或 'val'
    :param output_root: 输出根目录路径
    """
    for class_name, images in truncated_data.items():
        dest_dir = os.path.join(output_root, subset, class_name)
        os.makedirs(dest_dir, exist_ok=True)
        for img_path in images:
            img_name = os.path.basename(img_path)
            dest_path = os.path.join(dest_dir, img_name)
            shutil.copy2(img_path, dest_path)
    print(f"'{subset}' 子集已复制到 {output_root}")


def main():
    """
    主函数,执行数据集截断和复制操作。
    """
    # ================== 配置参数 ==================

    # 原始数据集根目录路径
    input_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224'  # 替换为你的原始数据集路径

    # 截断后数据集的输出根目录路径
    output_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate'  # 替换为你希望保存截断后数据集的路径

    # 训练集每个类别的图像数量阈值
    train_threshold = 2000  # 设置为你需要的训练集阈值

    # 验证集每个类别的图像数量阈值
    val_threshold = 400  # 设置为你需要的验证集阈值

    # 允许的图像文件扩展名
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']

    # 随机种子以确保可重复性
    random_seed = 42

    # ================== 脚本实现 ==================

    # 设置随机种子
    random.seed(random_seed)

    # 定义train和val目录路径
    train_input_dir = os.path.join(input_dir, 'train')
    val_input_dir = os.path.join(input_dir, 'val')

    # 统计train和val中的图像
    print("统计训练集中的图像数量...")
    train_counts = count_images(train_input_dir, image_extensions)
    print("统计验证集中的图像数量...")
    val_counts = count_images(val_input_dir, image_extensions)

    # 截断train和val中的图像
    print("\n截断训练集中的图像...")
    truncated_train = truncate_dataset(train_counts, train_threshold, random_seed)
    print("\n截断验证集中的图像...")
    truncated_val = truncate_dataset(val_counts, val_threshold, random_seed)

    # 复制截断后的图像到输出目录
    print("\n复制截断后的训练集图像...")
    copy_images(truncated_train, 'train', output_dir)
    print("复制截断后的验证集图像...")
    copy_images(truncated_val, 'val', output_dir)

    print("\n数据集截断完成。")


if __name__ == "__main__":
    main()

再次查看已经符合截断后的数据分布了

相关推荐
Sylvia33.4 天前
火星数据:解构斯诺克每一杆进攻背后的数字语言
java·前端·python·数据挖掘·数据分析
Flying pigs~~4 天前
机器学习之逻辑回归
人工智能·机器学习·数据挖掘·数据分析·逻辑回归
LCG元4 天前
低功耗显示方案:STM32L0驱动OLED,动态波形绘制与优化
stm32·嵌入式硬件·信息可视化
YangYang9YangYan4 天前
2026中专计算机专业学数据分析的实用价值分析
数据挖掘·数据分析
YangYang9YangYan4 天前
2026高职大数据管理与应用专业学数据分析的价值与前景
数据挖掘·数据分析
babe小鑫4 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
TDengine (老段)4 天前
TDengine IDMP 数据可视化——散点图
大数据·数据库·物联网·信息可视化·时序数据库·tdengine·涛思数据
发哥来了4 天前
主流GEO优化系统技术对比评测
人工智能·信息可视化
赤月奇5 天前
https改为http
数据挖掘·https·ssl
码农三叔5 天前
(3-2-01)视觉感知:目标检测与分类
人工智能·目标检测·分类·机器人·人机交互·人形机器人