#!/usr/bin/env python3
"""
ImageNet-1k数据集处理工具
基于CSDN博客:https://blog.csdn.net/qq_45588019/article/details/125642466
功能:下载、解压和分类ImageNet-1k数据集
"""
import os
import subprocess
import shutil
import tarfile
from scipy import io
import sys
class ImageNetProcessor:
    """Download, extract and organise the ImageNet-1k (ILSVRC2012) dataset.

    All archives and extracted folders live under ``data_dir``. Downloading
    shells out to the ``aria2c`` executable and extraction to ``tar``, so the
    respective executable must be on PATH for each step.
    """

    def __init__(self, data_dir="./imagenet_data"):
        """Record the archive names and make sure ``data_dir`` exists.

        Args:
            data_dir: Directory that receives the downloaded archives and the
                extracted ``train``/``val`` folders; created if missing.
        """
        self.data_dir = data_dir
        self.train_tar = "ILSVRC2012_img_train.tar"
        self.val_tar = "ILSVRC2012_img_val.tar"
        self.devkit_tar = "ILSVRC2012_devkit_t12.tar.gz"
        os.makedirs(data_dir, exist_ok=True)

    def _download_with_aria2(self, url: str, out_path: str, connections: int = 4, splits: int = 4):
        """Download ``url`` to ``out_path`` with aria2c (multi-connection, resumable).

        Args:
            url: Source URL.
            out_path: Destination file path; its directory is created if needed.
            connections: Max connections per server (aria2c ``-x``).
            splits: Number of segments downloaded in parallel (aria2c ``-s``).

        Raises:
            subprocess.CalledProcessError: If aria2c exits with a non-zero status.
        """
        # A bare filename has an empty dirname and os.makedirs("") raises,
        # so fall back to the current directory.
        out_dir = os.path.dirname(out_path) or "."
        os.makedirs(out_dir, exist_ok=True)
        cmd = [
            "aria2c",
            "-c",  # resume a partially downloaded file
            "-x", str(connections),  # max connections per server
            "-s", str(splits),  # number of segments
            "-k", "1M",  # split size (aria2c --min-split-size)
            "--allow-overwrite=true",
            # NOTE(review): this disables TLS certificate verification (the
            # equivalent of wget --no-check-certificate); drop it if the
            # server certificate validates in your environment.
            "--check-certificate=false",
            "-o", os.path.basename(out_path),
            "-d", out_dir,
            url,
        ]
        subprocess.run(cmd, check=True)

    def download_dataset(self):
        """Download the train archive, validation archive and devkit into ``data_dir``."""
        print("开始下载ImageNet-1k数据集...")
        # Training set
        train_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar"
        print(f"下载训练集: {self.train_tar}")
        self._download_with_aria2(train_url, os.path.join(self.data_dir, self.train_tar))
        # Validation set
        val_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar"
        print(f"下载验证集: {self.val_tar}")
        self._download_with_aria2(val_url, os.path.join(self.data_dir, self.val_tar))
        # Devkit (label / synset mapping files)
        devkit_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz"
        print(f"下载标签映射文件: {self.devkit_tar}")
        self._download_with_aria2(devkit_url, os.path.join(self.data_dir, self.devkit_tar))
        print("所有文件下载完成!")

    def extract_train_set(self):
        """Extract the training archive into ``<data_dir>/train/<class>/``.

        The outer archive contains one ``<class>.tar`` per class. Each one is
        extracted into a folder of the same name and then deleted.

        Raises:
            subprocess.CalledProcessError: If any ``tar`` invocation fails.
        """
        print("开始解压训练集...")
        train_dir = os.path.join(self.data_dir, "train")
        os.makedirs(train_dir, exist_ok=True)

        # Extract the outer archive.
        train_tar_path = os.path.join(self.data_dir, self.train_tar)
        print(f"解压主训练集文件到: {train_dir}")
        subprocess.run([
            "tar", "-xvf", train_tar_path, "-C", train_dir
        ], check=True)

        # Build every path from train_dir instead of chdir-ing into it. With a
        # relative data_dir (e.g. the default "./imagenet_data"), a chdir would
        # make os.path.join(train_dir, class_name) resolve *inside* train_dir
        # again and nest the classes under train/imagenet_data/train/.
        tar_files = sorted(f for f in os.listdir(train_dir) if f.endswith('.tar'))
        print(f"找到 {len(tar_files)} 个类别tar文件")
        for tar_file in tar_files:
            class_name = tar_file[:-len('.tar')]
            class_dir = os.path.join(train_dir, class_name)
            tar_path = os.path.join(train_dir, tar_file)
            os.makedirs(class_dir, exist_ok=True)
            print(f"解压类别: {class_name}")
            subprocess.run([
                "tar", "-xvf", tar_path, "-C", class_dir
            ], check=True)
            # The per-class archive is no longer needed once it is extracted.
            os.remove(tar_path)

        self._verify_train_extraction(train_dir)

    def _verify_train_extraction(self, train_dir):
        """Print directory/file counts under ``train_dir`` against the expected totals.

        Counting is recursive, done in Python with ``os.walk`` (no shell is
        involved, so paths with spaces are fine).

        Returns:
            tuple[int, int]: ``(folder_count, file_count)``.
        """
        print("验证训练集解压结果...")
        folder_count = 0
        file_count = 0
        for _root, dirnames, filenames in os.walk(train_dir):
            folder_count += len(dirnames)
            file_count += len(filenames)
        print(f"类别文件夹数量: {folder_count} (期望: 1000)")
        print(f"图片文件数量: {file_count} (期望: 1,281,167)")
        return folder_count, file_count

    def extract_validation_set(self):
        """Extract the validation images and the devkit, then sort images into class folders.

        Raises:
            subprocess.CalledProcessError: If a ``tar`` invocation fails.
        """
        print("开始解压验证集...")
        val_dir = os.path.join(self.data_dir, "val")
        os.makedirs(val_dir, exist_ok=True)
        # Extract the flat validation image archive.
        val_tar_path = os.path.join(self.data_dir, self.val_tar)
        print(f"解压验证集到: {val_dir}")
        subprocess.run([
            "tar", "xvf", val_tar_path, "-C", val_dir
        ], check=True)
        # Extract the devkit next to the archives.
        devkit_path = os.path.join(self.data_dir, self.devkit_tar)
        print("解压标签映射文件...")
        subprocess.run([
            "tar", "-xzf", devkit_path, "-C", self.data_dir
        ], check=True)
        self._classify_validation_set(val_dir)

    def _classify_validation_set(self, val_dir):
        """Move every ``*.JPEG`` in ``val_dir`` into a sub-folder named after its class.

        The image's numeric suffix (``..._00000001.JPEG`` -> 1) indexes the
        devkit ground-truth file (one ILSVRC2012 ID per line). That ID indexes
        the ``synsets`` table of ``meta.mat``.
        """
        print("开始验证集图片分类...")
        devkit_dir = os.path.join(self.data_dir, "ILSVRC2012_devkit_t12")
        synsets = io.loadmat(os.path.join(devkit_dir, 'data', 'meta.mat'))['synsets']

        ground_truth_path = os.path.join(devkit_dir, 'data', 'ILSVRC2012_validation_ground_truth.txt')
        with open(ground_truth_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line) so int() cannot fail on them.
            labels = [int(line) for line in f if line.strip()]

        for filename in sorted(os.listdir(val_dir)):
            if not filename.endswith('.JPEG'):
                continue
            val_id = int(filename.split('.')[0].split('_')[-1])
            ilsvrc_id = labels[val_id - 1]  # image ids are 1-based
            # Assumes synset rows are ordered so that ILSVRC2012_ID k sits at
            # row k-1. Field 1 of the record is the WordNet ID in the official
            # devkit meta.mat, e.g. "n01440764".
            wnid = str(synsets[ilsvrc_id - 1][0][1][0])
            print(f"处理图片: {filename}, val_id:{val_id}, ILSVRC_ID:{ilsvrc_id}, 类别:{wnid}")
            output_dir = os.path.join(val_dir, wnid)
            os.makedirs(output_dir, exist_ok=True)
            shutil.move(os.path.join(val_dir, filename), os.path.join(output_dir, filename))

        print("验证集分类完成!")
        self._verify_val_classification(val_dir)

    def _verify_val_classification(self, val_dir):
        """Print the number of class folders and ``*.JPEG`` images in ``val_dir``.

        Returns:
            tuple[int, int]: ``(folder_count, image_count)``.
        """
        print("验证验证集分类结果...")
        folders = [d for d in os.listdir(val_dir) if os.path.isdir(os.path.join(val_dir, d))]
        print(f"验证集类别文件夹数量: {len(folders)} (期望: 1000)")
        total_images = 0
        for folder in folders:
            folder_path = os.path.join(val_dir, folder)
            total_images += sum(1 for f in os.listdir(folder_path) if f.endswith('.JPEG'))
        print(f"验证集图片总数: {total_images} (期望: 50,000)")
        return len(folders), total_images
def main():
    """Interactive menu: download, extract and/or classify the ImageNet-1k data.

    Each action's exceptions are caught and printed so the menu keeps
    running. End-of-input or Ctrl-C at the prompt exits cleanly instead of
    dumping a traceback.
    """
    print("=" * 60)
    print("ImageNet-1k数据集处理工具")
    print("基于CSDN博客:https://blog.csdn.net/qq_45588019/article/details/125642466")
    print("=" * 60)
    processor = ImageNetProcessor()

    def run_full_pipeline():
        # Runs every step in order; the first failure aborts the rest.
        print("开始完整处理流程...")
        processor.download_dataset()
        processor.extract_train_set()
        processor.extract_validation_set()
        print("完整处理流程完成!")

    # Menu choice -> (action, prefix printed when the action raises)
    actions = {
        "1": (processor.download_dataset, "下载失败"),
        "2": (processor.extract_train_set, "解压训练集失败"),
        "3": (processor.extract_validation_set, "解压验证集失败"),
        "4": (run_full_pipeline, "完整处理流程失败"),
    }

    while True:
        print("\n请选择操作:")
        print("1. 下载数据集")
        print("2. 解压训练集")
        print("3. 解压并分类验证集")
        print("4. 完整处理流程")
        print("5. 退出")
        try:
            choice = input("请输入选择 (1-5): ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed (e.g. piped input ran out) or user pressed Ctrl-C/Ctrl-D.
            print("\n退出程序")
            break
        if choice == "5":
            print("退出程序")
            break
        entry = actions.get(choice)
        if entry is None:
            print("无效选择,请重新输入")
            continue
        action, error_prefix = entry
        try:
            action()
        except Exception as e:  # top-level boundary: report and keep the menu alive
            print(f"{error_prefix}: {e}")


if __name__ == "__main__":
    main()