B站项目-基于Pytorch的ResNet垃圾图片分类

基于Pytorch的ResNet垃圾图片分类

项目链接
数据集下载链接

1. 数据集预处理

1.1 画图片的宽高分布散点图
python 复制代码
import os

import matplotlib.pyplot as plt
import PIL.Image as Image


def plot_resolution(dataset_root_path):
    image_size_list = []#存放图片尺寸
    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            image_full_path = os.path.join(root, file)
            image = Image.open(image_full_path)
            image_size = image.size
            image_size_list.append(image_size)

    print(image_size_list)

    image_width_list = [image_size_list[i][0] for i in range(len(image_size_list))]#存放图片的宽
    image_height_list = [image_size_list[i][1] for i in range(len(image_size_list))]#存放图片的高

    plt.rcParams['font.sans-serif'] = ['SimHei']#设置中文字体
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False#解决图像中的负号乱码问题

    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('宽')
    plt.ylabel('高')
    plt.title('图像宽高分布散点图')
    plt.show()



if __name__ == '__main__':
    dataset_root_path = "F:\数据与代码\dataset"
    plot_resolution(dataset_root_path)

运行结果:

注: os.walk详细解释参考

1.2 画出数据集的各个类别图片数量的条形图

文件组织结构:

python 复制代码
def plot_bar(dataset_root_path):

    file_name_list = []
    file_num_list = []

    for root, dirs, files in os.walk(dataset_root_path):
        if len(dirs) != 0 :
            for dir in dirs:
                file_name_list.append(dir)
        file_num_list.append(len(files))



    file_num_list = file_num_list[1:]#去掉根目录下面的文件数量(0) [0, 20, 1, 15, 23,  25, 22, 121, 7, 286, 233, 22, 27, 5, 6, 4]
    #[20, 1, 15, 23, 25, 22, 121, 7, 286, 233, 22,27, 5, 6, 4]


    mean = np.mean(file_num_list)
    print("mean= ", mean)

    bar_positions = np.arange(len(file_name_list))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)# 柱间的距离, 柱的值, 柱的宽度
    ax.plot(bar_positions, [mean for i in bar_positions], color="red")#画出平均线

    plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # 解决图像中的负号乱码问题

    ax.set_xticks(bar_positions)#设置x轴的刻度
    ax.set_xticklabels(file_name_list, rotation=98) #设置x轴的标签
    ax.set_ylabel("类别数量")
    ax.set_title("各个类别数量分布散点图")
    plt.show()

运行结果

1.3 删除宽高有问题的图片
python 复制代码
import os
import PIL.Image as Image


MIN = 200
MAX = 2000
ratio = 0.5

def delete_img(dataset_root_path):
    delete_img_list = [] #需要删除的图片地址

    for root, dirs, files in os.walk(dataset_root_path):

        for file in files:
            img_full_path = os.path.join(root, file)
            img = Image.open(img_full_path)
            img_size = img.size
            max_l = img_size[0] if img_size[0] > img_size[1] else img_size[1]
            min_l = img_size[0] if img_size[0] < img_size[1] else img_size[1]
            # 把图片宽高限制在 200~2000 这里可能会重复添加图片路径
            if img_size[0] < MIN or img_size[1] < MIN:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)

            elif img_size[0] > MAX or img_size[1] > MAX:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)

            #避免图片窄长
            elif min_l / max_l < ratio:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)


    for img in delete_img_list:
        print("正在删除", img)
        os.remove(img)



if __name__ == '__main__':
    dataset_root_img = 'F:\数据与代码\dataset'
    delete_img(dataset_root_img)

再次运行1.1 和1.2的代码得到处理后的数据集宽高分布和类别数量

1.4 数据增强
python 复制代码
import os

import cv2

#水平翻转
import numpy as np


def Horizontal(image):
    return cv2.flip(image, 1, dst=None)

#垂直翻转
def Vertical(image):
    return cv2.flip(image, 0, dst=None)

threshold = 200 #阈值

#数据增强
def data_augmentation(from_root_path, save_root_path):
    for root, dirs, files in os.walk(from_root_path):
            for file in files:
                img_full_path = os.path.join(root, file)
                split = os.path.split(img_full_path)
                save_path = os.path.join(save_root_path, os.path.split(split[0])[1])
                print(save_path)
                if os.path.isdir(save_path) == False:#文件夹不存在就创建
                    os.makedirs(save_path)

                img = cv2.imdecode(np.fromfile(img_full_path, dtype=np.uint8), -1)#读取含中文的路径
                cv2.imencode('.jpg', img)[1].tofile(os.path.join(save_path,file[:-5]+ "_original.jpg")) #保存原图


                if len(files) > 0 and len(files) < threshold:  # 类别数量小于阈值,需要对该类别的所有图片进行数据增强
                    img_horizontal = Horizontal(img)
                    cv2.imencode('.jpg', img_horizontal)[1].tofile(os.path.join(save_path, file[:-5] + "_horizontal.jpg"))
                    img_vertical = Vertical(img)
                    cv2.imencode('.jpg', img_vertical)[1].tofile(os.path.join(save_path, file[:-5] + "_vertical.jpg"))
                else:
                    pass

if __name__ == '__main__':
    from_root_path = 'F:\数据与代码\dataset'
    save_root_path = 'F:\数据与代码\enhance_dataset'
    data_augmentation(from_root_path, save_root_path)
1.5 数据集平衡处理

将图片数量超过阈值的类别删除一部分图片

python 复制代码
import os
import random

threshold = 300
def dataset_balance(dataset_root_path):

    for root, dirs, files in os.walk(dataset_root_path):
        if len(files) > threshold:
            delete_img_list = []
            for file in files:
                img_full_path = os.path.join(root, file)
                delete_img_list.append(img_full_path)

            random.shuffle(delete_img_list)
            delete_img_list = delete_img_list[threshold:]
            for img in delete_img_list:
                os.remove(img)
                print("成功删除", img)

if __name__ == '__main__':
    dataset_root_path = 'F:\数据与代码\enhance_dataset'
    dataset_balance(dataset_root_path)
1.6 求图像的均值和方差
python 复制代码
from torchvision import transforms as T
import torch
from torchvision.datasets import ImageFolder
from tqdm import tqdm

transform = T.Compose([
    T.RandomResizedCrop(224),#随机采样并缩放为 224X224
    T.ToTensor(),
])


def getStat(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
    )

    #均值 方差
    mean = torch.zeros(3)#三维
    std = torch.zeros(3)

    for X, _ in tqdm(train_loader):# tqdm添加进度条
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()

    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root='F:/数据与代码/enhance_dataset', transform=transform)
    print(getStat(train_dataset))

2. 生成数据集与数据加载器

2.1 生成数据集
python 复制代码
import os
import random

train_ratio = 0.9
test_ratio = 1 - train_ratio

root_data = 'F:\数据与代码\enhance_dataset'

train_list, test_list = [], []

class_flag = -1
for root, dirs, files in os.walk(root_data):
    for i in range(0, int(len(files)*train_ratio)):
        train_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(train_data)

    for i in range(int(len(files)*train_ratio), len(files)):
        test_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(test_data)

    class_flag += 1

random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt', 'w', encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt', 'w', encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(str(test_img))
2.2 生成数据加载器
python 复制代码
import torch
from PIL import Image
import torchvision.transforms as transforms

#遇到格式损坏的文件就跳过
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True



from torch.utils.data import Dataset



#数据归一化与标准化
transform_BZ = transforms.Normalize(
    mean = [0.64148515, 0.57362735, 0.5084857],
    std = [0.21153161, 0.21981773, 0.22988321]
)


class LoadData(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        self.img_size = 512
        self.train_tf = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.RandomHorizontalFlip(),#随机水平翻转
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transform_BZ#数据归一化与标准化
        ])

        self.val_tf = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transform_BZ  # 数据归一化与标准化
        ])

    def get_images(self, txt_path):#返回格式[路径, 标签]
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            #map(函数,参数)
            imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
        return imgs_info

    def padding_black(self, img):  # 如果尺寸太小可以扩充
        w, h = img.size
        scale = self.img_size / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
        size_fg = img_fg.size
        size_bg = self.img_size
        img_bg = Image.new("RGB", (size_bg, size_bg))
        img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))
        img = img_bg
        return img

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')#转换为RGB格式
        img = self.padding_black(img)
        if self.train_flag:
            img = self.train_tf(img)

        else:
            img = self.val_tf(img)

        label = int(label)

        return img, label


    def __len__(self):
        return len(self.imgs_info)

if __name__ == '__main__':
    train_dataset = LoadData('train.txt', True)
    print("数据个数", len(train_dataset))

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=5,
        shuffle=True
    )

    for image, label in train_loader:
        print("image.shape", image.shape)
        # print(image)
        print(label)
相关推荐
Python极客之家32 分钟前
基于深度学习的乳腺癌分类识别与诊断系统
人工智能·深度学习·分类
BulingQAQ6 小时前
论文阅读:PET/CT Cross-modal medical image fusion of lung tumors based on DCIF-GAN
论文阅读·深度学习·生成对抗网络·计算机视觉·gan
slomay9 小时前
关于对比学习(简单整理
经验分享·深度学习·学习·机器学习
丶21369 小时前
【CUDA】【PyTorch】安装 PyTorch 与 CUDA 11.7 的详细步骤
人工智能·pytorch·python
AI完全体10 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
卷心菜小温11 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
陈苏同学11 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
FL162386312911 小时前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~12 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
龙的爹233312 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt