【ILSVRC2012】ImageNet-1k数据集下载与处理脚本

python 复制代码
#!/usr/bin/env python3
"""
ImageNet-1k数据集处理工具
基于CSDN博客:https://blog.csdn.net/qq_45588019/article/details/125642466
功能:下载、解压和分类ImageNet-1k数据集
"""

import os
import subprocess
import shutil
import tarfile
from scipy import io
import sys

class ImageNetProcessor:
    def __init__(self, data_dir="./imagenet_data"):
        self.data_dir = data_dir
        self.train_tar = "ILSVRC2012_img_train.tar"
        self.val_tar = "ILSVRC2012_img_val.tar"
        self.devkit_tar = "ILSVRC2012_devkit_t12.tar.gz"
        os.makedirs(data_dir, exist_ok=True)

    def _download_with_aria2(self, url: str, out_path: str, connections: int = 4, splits: int = 4):
        """
        使用 aria2c 下载(多连接/多分段/断点续传)
        依赖:系统已安装 aria2c
        """
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        cmd = [
            "aria2c",
            "-c",                        # 断点续传
            "-x", str(connections),      # 单服务器最大连接数
            "-s", str(splits),           # 分段数
            "-k", "1M",                  # 分段大小
            "--allow-overwrite=true",
            "--check-certificate=false", # 等价于 wget --no-check-certificate
            "-o", os.path.basename(out_path),
            "-d", os.path.dirname(out_path),
            url,
        ]
        subprocess.run(cmd, check=True)

    def download_dataset(self):
        """下载ImageNet数据集文件"""
        print("开始下载ImageNet-1k数据集...")

        # 训练集下载
        train_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar"
        print(f"下载训练集: {self.train_tar}")
        self._download_with_aria2(train_url, os.path.join(self.data_dir, self.train_tar))

        # 验证集下载
        val_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar"
        print(f"下载验证集: {self.val_tar}")
        self._download_with_aria2(val_url, os.path.join(self.data_dir, self.val_tar))

        # 标签映射文件下载
        devkit_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz"
        print(f"下载标签映射文件: {self.devkit_tar}")
        self._download_with_aria2(devkit_url, os.path.join(self.data_dir, self.devkit_tar))

        print("所有文件下载完成!")
    
    def extract_train_set(self):
        """解压训练集"""
        print("开始解压训练集...")
        
        train_dir = os.path.join(self.data_dir, "train")
        os.makedirs(train_dir, exist_ok=True)
        
        # 解压主tar文件
        train_tar_path = os.path.join(self.data_dir, self.train_tar)
        print(f"解压主训练集文件到: {train_dir}")
        subprocess.run([
            "tar", "-xvf", train_tar_path, "-C", train_dir
        ], check=True)
        
        # 进入train目录
        original_cwd = os.getcwd()
        os.chdir(train_dir)
        
        try:
            # 解压每个类别的tar文件
            tar_files = [f for f in os.listdir('.') if f.endswith('.tar')]
            print(f"找到 {len(tar_files)} 个类别tar文件")
            
            for tar_file in tar_files:
                class_name = tar_file.replace('.tar', '')
                class_dir = os.path.join(train_dir, class_name)
                
                # 创建类别目录
                os.makedirs(class_dir, exist_ok=True)
                
                # 解压类别tar文件
                print(f"解压类别: {class_name}")
                subprocess.run([
                    "tar", "-xvf", tar_file, "-C", class_dir
                ], check=True)
                
                # 删除原始的tar文件
                os.remove(tar_file)
                
        finally:
            os.chdir(original_cwd)
        
        # 验证解压结果
        self._verify_train_extraction(train_dir)
    
    def _verify_train_extraction(self, train_dir):
        """验证训练集解压结果"""
        print("验证训练集解压结果...")
        
        # 统计文件夹数量(应该是1000个类别)
        result = subprocess.run([
            "bash", "-c", f"cd {train_dir} && ls -lR | grep '^d' | wc -l"
        ], capture_output=True, text=True)
        folder_count = int(result.stdout.strip())
        print(f"类别文件夹数量: {folder_count} (期望: 1000)")
        
        # 统计图片文件数量(应该是1,281,167张)
        result = subprocess.run([
            "bash", "-c", f"cd {train_dir} && ls -lR | grep '^-' | wc -l"
        ], capture_output=True, text=True)
        file_count = int(result.stdout.strip())
        print(f"图片文件数量: {file_count} (期望: 1,281,167)")
    
    def extract_validation_set(self):
        """解压验证集并进行分类"""
        print("开始解压验证集...")
        
        val_dir = os.path.join(self.data_dir, "val")
        os.makedirs(val_dir, exist_ok=True)
        
        # 解压验证集tar文件
        val_tar_path = os.path.join(self.data_dir, self.val_tar)
        print(f"解压验证集到: {val_dir}")
        subprocess.run([
            "tar", "xvf", val_tar_path, "-C", val_dir
        ], check=True)
        
        # 解压标签映射文件
        devkit_path = os.path.join(self.data_dir, self.devkit_tar)
        print("解压标签映射文件...")
        subprocess.run([
            "tar", "-xzf", devkit_path, "-C", self.data_dir
        ], check=True)
        
        # 执行验证集分类
        self._classify_validation_set(val_dir)
    
    def _classify_validation_set(self, val_dir):
        """将验证集图片分类到对应的1000个文件夹中"""
        print("开始验证集图片分类...")
        
        devkit_dir = os.path.join(self.data_dir, "ILSVRC2012_devkit_t12")
        
        # 加载synset和验证集标签
        synset = io.loadmat(os.path.join(devkit_dir, 'data', 'meta.mat'))
        
        # 读取验证集ground truth
        ground_truth_path = os.path.join(devkit_dir, 'data', 'ILSVRC2012_validation_ground_truth.txt')
        with open(ground_truth_path, 'r') as f:
            lines = f.readlines()
        labels = [int(line.strip()) for line in lines]
        
        # 移动验证集图片到对应类别文件夹
        for filename in os.listdir(val_dir):
            if not filename.endswith('.JPEG'):
                continue
                
            # 提取图片ID (val_id)
            val_id = int(filename.split('.')[0].split('_')[-1])
            
            # 获取对应的ILSVRC ID和类别名称
            ILSVRC_ID = labels[val_id - 1]  # 标签从1开始
            WIND = synset['synsets'][ILSVRC_ID - 1][0][1][0]  # 类别名称
            
            print(f"处理图片: {filename}, val_id:{val_id}, ILSVRC_ID:{ILSVRC_ID}, 类别:{WIND}")
            
            # 创建类别目录
            output_dir = os.path.join(val_dir, WIND)
            os.makedirs(output_dir, exist_ok=True)
            
            # 移动图片到对应类别目录
            src_path = os.path.join(val_dir, filename)
            dst_path = os.path.join(output_dir, filename)
            shutil.move(src_path, dst_path)
        
        print("验证集分类完成!")
        
        # 验证分类结果
        self._verify_val_classification(val_dir)
    
    def _verify_val_classification(self, val_dir):
        """验证验证集分类结果"""
        print("验证验证集分类结果...")
        
        # 统计类别文件夹数量
        folders = [d for d in os.listdir(val_dir) if os.path.isdir(os.path.join(val_dir, d))]
        print(f"验证集类别文件夹数量: {len(folders)} (期望: 1000)")
        
        # 统计总图片数量
        total_images = 0
        for folder in folders:
            folder_path = os.path.join(val_dir, folder)
            images = [f for f in os.listdir(folder_path) if f.endswith('.JPEG')]
            total_images += len(images)
        
        print(f"验证集图片总数: {total_images} (期望: 50,000)")

def main():
    """主函数"""
    print("=" * 60)
    print("ImageNet-1k数据集处理工具")
    print("基于CSDN博客:https://blog.csdn.net/qq_45588019/article/details/125642466")
    print("=" * 60)
    
    processor = ImageNetProcessor()
    
    while True:
        print("\n请选择操作:")
        print("1. 下载数据集")
        print("2. 解压训练集")
        print("3. 解压并分类验证集")
        print("4. 完整处理流程")
        print("5. 退出")
        
        choice = input("请输入选择 (1-5): ").strip()
        
        if choice == "1":
            try:
                processor.download_dataset()
            except Exception as e:
                print(f"下载失败: {e}")
                
        elif choice == "2":
            try:
                processor.extract_train_set()
            except Exception as e:
                print(f"解压训练集失败: {e}")
                
        elif choice == "3":
            try:
                processor.extract_validation_set()
            except Exception as e:
                print(f"解压验证集失败: {e}")
                
        elif choice == "4":
            try:
                print("开始完整处理流程...")
                processor.download_dataset()
                processor.extract_train_set()
                processor.extract_validation_set()
                print("完整处理流程完成!")
            except Exception as e:
                print(f"完整处理流程失败: {e}")
                
        elif choice == "5":
            print("退出程序")
            break
        else:
            print("无效选择,请重新输入")

if __name__ == "__main__":
    main()
相关推荐
xuzhiqiang07242 小时前
【Flask】四、flask连接并操作数据库
数据库·python·flask
醒了就刷牙2 小时前
Hugging_Face实战
python
Volunteer Technology2 小时前
LangGraph的Agent的上下文
人工智能·后端·python·langchain
狸猫算君2 小时前
别再用ChatGPT群发祝福了!手把手教你“喂”出一个懂人情的AI,连马术梗都能接住
机器学习
luoluoal2 小时前
基于python的医疗知识图谱问答系统(源码+文档)
python·mysql·django·毕业设计·源码
小比特_蓝光2 小时前
STL小知识点——C++
java·开发语言·c++·python
宁远x2 小时前
【万字长文】PyTorch FSDP 设计解读与性能分析
人工智能·pytorch·深度学习·云计算
何伯特2 小时前
PyTorch基本用法介绍:从零开始构建深度学习工作流
人工智能·pytorch·深度学习
I'm Jie2 小时前
【已解决】SqlAlchemy 插入 MySQL JSON 字段时 None 变为 ‘null‘ 字符串,WHERE IS NULL 失效
数据库·python·mysql·json·fastapi·sqlalchemy