pytorch模型的进阶训练和性能优化

综合案例

  • 将MNIST数据集保存成本地图片
  • 读取本地图片进行训练
  • 读取自己的数据集进行训练
  • 用自己的模型进行训练
  • 获得更多评价指标
  • 提升模型性能的方法

MNIST转本地图片

复制代码
import os
import torchvision
import torchvision.transforms as transforms

# 下载MNIST数据集
transform = transforms.Compose([transforms.ToTensor()])
mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                            download=True, transform=transform)
mnist_testset = torchvision.datasets.MNIST(root='./data', train=False,
                                           download=True, transform=transform)

# 创建一个目录来保存图像(如果它还不存在)
os.makedirs('./mnist_images/train', exist_ok=True)
os.makedirs('./mnist_images/test', exist_ok=True)

# 遍历数据集并保存图像
for idx, (image, label) in enumerate(mnist_trainset):
    # 创建类别文件夹(如果它还不存在)
    label_dir = os.path.join('./mnist_images/train', str(label))
    os.makedirs(label_dir, exist_ok=True)

    # 转换为PIL图像并保存
    pil_image = transforms.ToPILImage()(image)
    pil_image.save(os.path.join(label_dir, f'{idx}.jpg'))

# 遍历数据集并保存图像
for idx, (image, label) in enumerate(mnist_testset):
    # 创建类别文件夹(如果它还不存在)
    label_dir = os.path.join('./mnist_images/test', str(label))
    os.makedirs(label_dir, exist_ok=True)

    # 转换为PIL图像并保存
    pil_image = transforms.ToPILImage()(image)
    pil_image.save(os.path.join(label_dir, f'{idx}.jpg'))

# 打印完成消息
print("All images have been saved successfully.")

接下来我来讲解一下上述的代码,在我的视角看来应该要将的东西

transform = transforms.Compose([transforms.ToTensor()])

使用 torchvision.transforms 模块中的 ComposeToTensor 方法来定义一个图像预处理的转换操作,主要用于将图像数据转换为 PyTorch 张量(Tensor),以便用于深度学习模型的训练或推理。

mnist_trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True, transform=transform

自动下载MNIST数据集,然后将其转换为tensor格式

os.makedirs('./mnist_images/train', exist_ok=True)

使用 Python 的 os 模块中的 makedirs 函数来创建目录。具体来说,它的作用是创建一个目录路径 ./mnist_images/train,并且如果该目录已经存在,不会报错。

其中exist_ok=True,在目录已经存在的情况下,不会报错

label_dir = os.path.join('./mnist_images/train', str(label))

这一句代码的作用就是,把'./mnist_images/train'字符串和str(label)字符串拼接起来。

pil_image = transforms.ToPILImage()(image)

这行代码的作用是将输入的图像数据(通常是 PyTorch 张量或 NumPy 数组)转换为 PIL 图像对象。

读取本地图片进行训练

复制代码
import os
import torch
from torch.utils.data import Dataset, DataLoader
import cv2 as cv


class MNISTDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_list = []
        self.name_list = []
        self.id_list = []
        for root, dirs, files in os.walk(self.root_dir):
            if dirs:
                self.name_list = dirs
            for file_i in files:
                file_i_full_path = os.path.join(root, file_i)
                file_class = os.path.split(file_i_full_path)[0].split('\\')[-1]
                self.id_list.append(self.name_list.index(file_class))
                self.file_list.append(file_i_full_path)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img = self.file_list[idx]
        img = cv.imread(img, 0)
        img = cv.resize(img, dsize=(28, 28))
        img = torch.from_numpy(img).float()

        label = self.id_list[idx]
        # print(label)
        label = torch.tensor(label)
        return img, label


if __name__ == '__main__':
    my_dataset_train = MNISTDataset(r'mnist_images/train')
    my_dataloader_train = DataLoader(my_dataset_train, batch_size=10, shuffle=True)
    # 尝试读取训练集数据
    print("读取训练集数据")
    for x, y in my_dataloader_train:
        print(x.type(), x.shape, y)
    my_dataset_test = MNISTDataset(r'mnist_images/test')
    my_dataloader_test = DataLoader(my_dataset_test, batch_size=10, shuffle=False)
    # 尝试读取训练集数据
    print("读取测试集数据")
    for x, y in my_dataloader_test:
        print(x.shape, y)

在前面讲了,Dataset的三件套,initlen,__getitem__如果这三个魔法方法忘记了,可以回去看看以前的文章。

for root, dirs, files in os.walk(self.root_dir):

Python 中使用 os.walk 函数的一个典型用法,用于遍历指定目录及其所有子目录中的文件和文件夹。

  • root:当前正在遍历的目录路径。

  • dirs:当前目录下的子目录列表。

  • files:当前目录下的文件列表

self.name_list

列表当中存的就是,标签名字

进入第二次循环后

os.path.split(file_i_full_path)[0]

file_i_full_path 分割为目录部分和文件名部分,返回一个元组 (head, tail)

self.name_list.index(file_class)

self.name_list 列表中查找 file_class 元素的索引位置。

self.file_list最后的效果

相关推荐
董厂长17 分钟前
langchain :记忆组件混淆概念澄清 & 创建Conversational ReAct后显示指定 记忆组件
人工智能·深度学习·langchain·llm
亿牛云爬虫专家1 小时前
Kubernetes下的分布式采集系统设计与实战:趋势监测失效引发的架构进化
分布式·python·架构·kubernetes·爬虫代理·监测·采集
G皮T4 小时前
【人工智能】ChatGPT、DeepSeek-R1、DeepSeek-V3 辨析
人工智能·chatgpt·llm·大语言模型·deepseek·deepseek-v3·deepseek-r1
九年义务漏网鲨鱼4 小时前
【大模型学习 | MINIGPT-4原理】
人工智能·深度学习·学习·语言模型·多模态
元宇宙时间4 小时前
Playfun即将开启大型Web3线上活动,打造沉浸式GameFi体验生态
人工智能·去中心化·区块链
开发者工具分享4 小时前
文本音频违规识别工具排行榜(12选)
人工智能·音视频
产品经理独孤虾4 小时前
人工智能大模型如何助力电商产品经理打造高效的商品工业属性画像
人工智能·机器学习·ai·大模型·产品经理·商品画像·商品工业属性
老任与码5 小时前
Spring AI Alibaba(1)——基本使用
java·人工智能·后端·springaialibaba
蹦蹦跳跳真可爱5895 小时前
Python----OpenCV(图像増强——高通滤波(索贝尔算子、沙尔算子、拉普拉斯算子),图像浮雕与特效处理)
人工智能·python·opencv·计算机视觉
nananaij5 小时前
【Python进阶篇 面向对象程序设计(3) 继承】
开发语言·python·神经网络·pycharm