PyTorch-Based ResNet Garbage Image Classification
1. Dataset Preprocessing
1.1 Scatter plot of the image width/height distribution
```python
import os
import matplotlib.pyplot as plt
import PIL.Image as Image

def plot_resolution(dataset_root_path):
    image_size_list = []  # collected (width, height) of every image
    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            image_full_path = os.path.join(root, file)
            image = Image.open(image_full_path)
            image_size_list.append(image.size)  # PIL gives (width, height)
    print(image_size_list)

    image_width_list = [size[0] for size in image_size_list]   # image widths
    image_height_list = [size[1] for size in image_size_list]  # image heights

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('Width')
    plt.ylabel('Height')
    plt.title('Scatter plot of image width/height distribution')
    plt.show()

if __name__ == '__main__':
    dataset_root_path = r"F:\数据与代码\dataset"  # raw string avoids backslash-escape issues
    plot_resolution(dataset_root_path)
```
Result:
1.2 Bar chart of the number of images per class
Directory structure:
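The dataset root is expected to contain one sub-directory per class, with the images stored directly inside each class folder. A hypothetical layout (the class names here are invented for illustration):

```
dataset/
├── cardboard/
│   ├── 0001.jpg
│   └── ...
├── glass/
│   └── ...
└── ...
```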
```python
import os
import numpy as np
import matplotlib.pyplot as plt

def plot_bar(dataset_root_path):
    file_name_list = []  # class (sub-directory) names
    file_num_list = []   # number of files in each directory visited
    for root, dirs, files in os.walk(dataset_root_path):
        if len(dirs) != 0:  # the root directory: its sub-directories are the classes
            for dir in dirs:
                file_name_list.append(dir)
        file_num_list.append(len(files))
    # Drop the file count of the root directory itself (0), e.g.
    # [0, 20, 1, 15, 23, 25, 22, 121, 7, 286, 233, 22, 27, 5, 6, 4]
    # -> [20, 1, 15, 23, 25, 22, 121, 7, 286, 233, 22, 27, 5, 6, 4]
    file_num_list = file_num_list[1:]

    mean = np.mean(file_num_list)
    print("mean= ", mean)

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

    bar_positions = np.arange(len(file_name_list))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # x positions, bar heights, bar width
    ax.plot(bar_positions, [mean for i in bar_positions], color="red")  # mean line
    ax.set_xticks(bar_positions)                     # x-axis tick positions
    ax.set_xticklabels(file_name_list, rotation=90)  # class names as tick labels
    ax.set_ylabel("Number of images")
    ax.set_title("Number of images per class")
    plt.show()
```
Result:
1.3 Removing images with unsuitable dimensions
```python
import os
import PIL.Image as Image

MIN = 200    # minimum allowed side length
MAX = 2000   # maximum allowed side length
ratio = 0.5  # minimum allowed short-side / long-side ratio

def delete_img(dataset_root_path):
    delete_img_list = []  # paths of the images to delete
    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            img = Image.open(img_full_path)
            img_size = img.size
            img.close()  # release the file handle so the file can be removed later
            max_l = max(img_size)
            min_l = min(img_size)
            # Keep side lengths within 200~2000; the elif chain adds each path at most once
            if img_size[0] < MIN or img_size[1] < MIN:
                delete_img_list.append(img_full_path)
                print("Does not meet requirements", img_full_path, img_size)
            elif img_size[0] > MAX or img_size[1] > MAX:
                delete_img_list.append(img_full_path)
                print("Does not meet requirements", img_full_path, img_size)
            # Reject overly narrow images, e.g. a 300x900 image gives 300/900 ≈ 0.33 < 0.5
            elif min_l / max_l < ratio:
                delete_img_list.append(img_full_path)
                print("Does not meet requirements", img_full_path, img_size)
    for img in delete_img_list:
        print("Deleting", img)
        os.remove(img)

if __name__ == '__main__':
    dataset_root_img = r'F:\数据与代码\dataset'
    delete_img(dataset_root_img)
```
Re-run the code from 1.1 and 1.2 to see the width/height distribution and per-class counts of the cleaned dataset.
1.4 Data Augmentation
```python
import os
import cv2
import numpy as np

def Horizontal(image):
    """Horizontal flip (flipCode=1 mirrors left/right)."""
    return cv2.flip(image, 1)

def Vertical(image):
    """Vertical flip (flipCode=0 mirrors top/bottom)."""
    return cv2.flip(image, 0)

threshold = 200  # classes with fewer images than this are augmented

def data_augmentation(from_root_path, save_root_path):
    for root, dirs, files in os.walk(from_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            split = os.path.split(img_full_path)
            # Mirror the class sub-directory under the save root
            save_path = os.path.join(save_root_path, os.path.split(split[0])[1])
            print(save_path)
            if not os.path.isdir(save_path):  # create the folder if it does not exist
                os.makedirs(save_path)
            # np.fromfile + imdecode handles paths containing Chinese characters;
            # IMREAD_COLOR drops any alpha channel so the JPEG encoding below cannot fail
            img = cv2.imdecode(np.fromfile(img_full_path, dtype=np.uint8), cv2.IMREAD_COLOR)
            stem = os.path.splitext(file)[0]  # file name without its extension
            # Save the original image (imencode + tofile, again for Chinese paths)
            cv2.imencode('.jpg', img)[1].tofile(os.path.join(save_path, stem + "_original.jpg"))
            if 0 < len(files) < threshold:  # small class: augment every image in it
                img_horizontal = Horizontal(img)
                cv2.imencode('.jpg', img_horizontal)[1].tofile(os.path.join(save_path, stem + "_horizontal.jpg"))
                img_vertical = Vertical(img)
                cv2.imencode('.jpg', img_vertical)[1].tofile(os.path.join(save_path, stem + "_vertical.jpg"))

if __name__ == '__main__':
    from_root_path = r'F:\数据与代码\dataset'
    save_root_path = r'F:\数据与代码\enhance_dataset'
    data_augmentation(from_root_path, save_root_path)
```
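The flipCode argument of cv2.flip selects the flip axis; a quick sanity check on a toy array (a sketch, not part of the pipeline above):

```python
import numpy as np
import cv2

a = np.arange(6, dtype=np.uint8).reshape(2, 3)  # [[0 1 2], [3 4 5]]
print(cv2.flip(a, 1))  # flipCode=1, horizontal flip: [[2 1 0], [5 4 3]]
print(cv2.flip(a, 0))  # flipCode=0, vertical flip:   [[3 4 5], [0 1 2]]
```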
1.5 Dataset Balancing
For classes whose image count exceeds a threshold, delete the surplus images.
```python
import os
import random

threshold = 300  # maximum number of images to keep per class

def dataset_balance(dataset_root_path):
    for root, dirs, files in os.walk(dataset_root_path):
        if len(files) > threshold:
            delete_img_list = []
            for file in files:
                img_full_path = os.path.join(root, file)
                delete_img_list.append(img_full_path)
            random.shuffle(delete_img_list)
            # Keep the first `threshold` images, delete the rest
            delete_img_list = delete_img_list[threshold:]
            for img in delete_img_list:
                os.remove(img)
                print("Deleted", img)

if __name__ == '__main__':
    dataset_root_path = r'F:\数据与代码\enhance_dataset'
    dataset_balance(dataset_root_path)
```
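To confirm the balancing worked, one can re-plot the per-class counts, assuming the plot_bar function from section 1.2 is in scope:

```python
# Re-check the class distribution after balancing
# (assumes plot_bar from section 1.2 is available)
plot_bar(r'F:\数据与代码\enhance_dataset')
```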
1.6 Computing the per-channel mean and standard deviation
```python
import torch
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from tqdm import tqdm

transform = T.Compose([
    T.RandomResizedCrop(224),  # random crop, rescaled to 224x224
    T.ToTensor(),
])

def getStat(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
    )
    mean = torch.zeros(3)  # one value per RGB channel
    std = torch.zeros(3)
    for X, _ in tqdm(train_loader):  # tqdm adds a progress bar
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    # Average the per-image statistics over the dataset
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root='F:/数据与代码/enhance_dataset', transform=transform)
    print(getStat(train_dataset))
```
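The two printed lists are exactly what section 2.2 passes to transforms.Normalize. Note that since every crop has the same 224x224 size, averaging the per-image means gives the true dataset mean, while averaging the per-image standard deviations only approximates the dataset-wide std (usually close enough for normalization). A minimal sketch of how the output is consumed, using the values computed on the author's enhance_dataset:

```python
from torchvision import transforms

# Values printed by getStat on the author's enhance_dataset;
# recompute them whenever the dataset changes.
normalize = transforms.Normalize(
    mean=[0.64148515, 0.57362735, 0.5084857],
    std=[0.21153161, 0.21981773, 0.22988321]
)
```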
2. Generating the Dataset and Data Loader
2.1 Generating the train/test split
```python
import os
import random

train_ratio = 0.9
test_ratio = 1 - train_ratio
root_data = r'F:\数据与代码\enhance_dataset'

train_list, test_list = [], []
# Start at -1: the first os.walk iteration is the root directory itself
# (which contains no files), so the counter reaches 0 for the first class folder
class_flag = -1
for root, dirs, files in os.walk(root_data):
    # First 90% of each class goes to the training set
    for i in range(0, int(len(files) * train_ratio)):
        train_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(train_data)
    # Remaining 10% goes to the test set
    for i in range(int(len(files) * train_ratio), len(files)):
        test_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(test_data)
    class_flag += 1

random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt', 'w', encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))
with open('test.txt', 'w', encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(str(test_img))
```
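Each line written to train.txt / test.txt is an image path and its class index separated by a tab character ('\t'). A hypothetical excerpt (the class names are invented for illustration):

```
F:\数据与代码\enhance_dataset\cardboard\0001_original.jpg	0
F:\数据与代码\enhance_dataset\glass\0047_horizontal.jpg	3
```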
2.2 Generating the data loader
```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset
# Tolerate truncated/corrupted image files instead of raising an error
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Normalization with the per-channel mean/std computed in section 1.6
transform_BZ = transforms.Normalize(
    mean=[0.64148515, 0.57362735, 0.5084857],
    std=[0.21153161, 0.21981773, 0.22988321]
)

class LoadData(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        self.img_size = 512
        self.train_tf = transforms.Compose([
            transforms.Resize(self.img_size),  # a no-op after padding_black, kept from the original pipeline
            transforms.RandomHorizontalFlip(),  # random horizontal flip
            transforms.RandomVerticalFlip(),    # random vertical flip
            transforms.ToTensor(),
            transform_BZ  # normalization
        ])
        self.val_tf = transforms.Compose([  # no random augmentation at validation time
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transform_BZ  # normalization
        ])

    def get_images(self, txt_path):  # returns a list of [path, label] pairs
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
        # map(function, iterable): split each "path\tlabel" line
        imgs_info = list(map(lambda x: x.strip().split('\t'), imgs_info))
        return imgs_info

    def padding_black(self, img):
        # Scale the longer side to img_size, then paste the result centered on a
        # black img_size x img_size canvas, preserving the aspect ratio
        # (e.g. 300x600 -> scaled to 256x512 -> centered on a 512x512 background)
        w, h = img.size
        scale = self.img_size / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
        size_fg = img_fg.size
        size_bg = self.img_size
        img_bg = Image.new("RGB", (size_bg, size_bg))
        img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))
        return img_bg

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')  # ensure 3 channels
        img = self.padding_black(img)
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs_info)

if __name__ == '__main__':
    train_dataset = LoadData('train.txt', True)
    print("Number of samples", len(train_dataset))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=5,
        shuffle=True
    )
    for image, label in train_loader:
        print("image.shape", image.shape)
        print(label)
```
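For evaluation, the same class can be instantiated with train_flag=False so that only the deterministic val_tf pipeline (no random flips) is applied. A minimal sketch:

```python
test_dataset = LoadData('test.txt', False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=5,
    shuffle=False  # no shuffling needed for evaluation
)
```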