Vision Transformer 神经网络在水果、动物、血细胞上的应用【深度学习、PyTorch、图像分类】

源链接(英文):https://www.orzzz.net/directory/codes/VisionTransformer/index.html

代码链接(中文):https://github.com/Illusionna/VisionTransformer

1. 介绍

2023-11-25 我在 GitHub 上提交了 EfficientTransformer 的比赛代码,近一年后的今天,我在删除电脑磁盘上吃灰的文件时,再次发现到它。由于去年写的函数接口很简约,所以这次重写 ViT 代码,修改了很多冗余的语句,并在其他图片集上进行测试,没想到这个神经网络模型效果依然这么好。

该项目代码是用于图片集训练、测试、预测的 Vision Transformer 神经网络架构。你可以下载已提供的数据集,当然,你也可以将代码运行在自己的图片集上。然而,需要注意的是,该项目仅是 demo 演示,不建议你将代码部署工业生产。

2. 数据集

已提供的图片集有三种:动物、水果、血细胞。你可以自定义新的图片集,不过需要满足一定的规则,否则代码可能无法运行。参考已提供的三种图片集文件夹或文件放置结构即可,发现其中的规律对你来说轻而易举?

三种图片集共约 1 GB 磁盘占用,所以 git 仓库里只放置了"水果"的图片集压缩包,另外两种数据集推荐使用百度网盘下载。在你运行代码前记得解压缩哦。

说句题外话,你可以把这三类图片集整合到一块,形成一个共计 15 类的更大数据集,再尝试训练测试 ViT 模型。项目工作区结构如下:

项目结构文件的描述如下:

复制代码
.
├── cache
│   ├── log
│   │   ├── Animals-train_loss(0.48440)valid_loss(0.40608).pt (动物图片训练 200 次后的权重)
│   │   └── Fruits-train_loss(0.12103)valid_loss(0.04023).pt (该权重可以直接用于测试、预测)
│   ├── Animals-info.json (动物数据集详情)
│   ├── Animals-map.json (动物数据集划分映射)
│   ├── Animals-net.txt (动物图片集对应的神经网络结构)
│   ├── Animals-test.json
│   ├── Animals-train.json
│   ├── Fruits-info.json
│   ├── Fruits-map.json
│   ├── Fruits-net.txt
│   ├── Fruits-test.json (水果测试集正确率)
│   └── Fruits-train.json (水果训练集训练过程)
├── datasets
│   ├── Animals (动物数据集)
│   │   ├── cat (猫类)
│   │   └── dog (狗类)
│   ├── Bloodcells
│   │   ├── basophil
│   │   ├── eosinophil
│   │   ├── erythroblast
│   │   ├── ig
│   │   ├── lymphocyte
│   │   ├── monocyte
│   │   ├── neutrophil
│   │   ├── platelet
│   │   └── README.md (血细胞数据集介绍, 不必管它)
│   ├── Fruits (水果数据集)
│   │   ├── Apple (苹果类)
│   │   │   ├── Apple (1).jpg (苹果类的某张图片)
│   │   │   ├── Apple (2).jpg
│   │   │   └── Apple (3).jpg
│   │   ├── Carambola
│   │   ├── Pear (梨子类)
│   │   │   ├── Pear (1).jpg
│   │   │   └── Pear (2).jpg (梨子类的某张照片)
│   │   ├── Plum
│   │   └── Tomatoes
│   ├── Unknown-Fruits (未知水果, 被用于预测)
│   │   ├── Fruit (a).jpg
│   │   ├── Fruit (b).jpg
│   │   ├── Fruit (c).jpg
│   │   ├── Fruit (d).jpg (某个未知的水果, 使用 Predict.py 代码进行预测)
│   │   └── Fruit (e).jpg
│   ├── Fruits.zip (水果数据集压缩包, 解压后, 顶替 Fruits 文件夹所有示例文件)
│   └── Unknown-Fruits.zip (未知水果的压缩包, 解压缩后, 顶替 Unknown-Fruits 文件夹)
├── utils
│   ├── einops (一个面向眼睛编程的张量库, 巨牛)
│   │   ├── experimental
│   │   │   ├── __init__.py
│   │   │   └── indexing.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── _einmix.py
│   │   │   ├── chainer.py
│   │   │   ├── flax.py
│   │   │   ├── keras.py
│   │   │   ├── oneflow.py
│   │   │   ├── paddle.py
│   │   │   ├── tensorflow.py
│   │   │   └── torch.py
│   │   ├── __init__.py
│   │   ├── _backends.py
│   │   ├── _torch_specific.py
│   │   ├── array_api.py
│   │   ├── einops.py
│   │   ├── packing.py
│   │   ├── parsing.py
│   │   └── py.typed
│   ├── linformer (一个把 Transformer 降低成线性复杂度的库)
│   │   ├── __init__.py
│   │   ├── linformer.py
│   │   └── reversible.py
│   ├── efficient.py (ViT 架构模型类)
│   ├── interface.py (自定义的接口文件)
│   ├── net.py (自定义的神经网络结构文件)
│   └── tool.py (自定义的工具文件)
├── Illustrate.py (插图函数, 需要单独安装库 >>> pip install matplotlib)
├── Predict.py (预测函数, 如果测试效果很好, 可将权重用于预测未知图片集)
├── Preprocessing.py (预处理函数, 对数据集进行处理, 之后可进行训练、测试、预测)
├── REAME.md (你正在阅读的文件)
├── Test.py (测试函数, 用于检测训练好的结果在测试集上的正确率)
└── Train.py (预处理完成后, 即可进行训练)

3. 依赖

该项目既支持 CPU 计算也支持 GPU 计算,如果你的电脑有图形处理器,可在终端使用 nvidia-smi 指令查看显卡型号,我的 GPU 显存是 4 GB 大小,且指令返回 CUDA 版本是 12.5 的型号。

复制代码
NVIDIA-SMI 556.12   Driver Version: 556.12  CUDA Version: 12.5

你可在 PyTorch 官网找到适合自己电脑的 Python 库,所以我在 Windows 11 上安装 torch-cu121 库。

推荐使用 Anaconda、Miniconda 或者 venv 虚拟环境下载 Python 3.10+ 的依赖库。

Conda

复制代码
  conda create -n GPU python==3.10.0
复制代码
  conda activate GPU
复制代码
  pip install matplotlib
复制代码
  python Preprocessing.py 

venv

复制代码
  python -m venv venv
复制代码
  ./venv/Scripts/activate
复制代码
  ./venv/Scripts/pip.exe install matplotlib
复制代码
  ./venv/Scripts/python.exe Preprocessing.py 

Windows 11

复制代码
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Windows 10

复制代码
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
复制代码
  pip install pillow

macOS

复制代码
  pip install torch torchvision torchaudio
复制代码
  pip install numpy
复制代码
  pip install pillow

Linux

复制代码
  I don't know.

4. 数据集预处理

在正式训练之前需要执行 Preprocessing.py 文件,对图片进行统计和处理。

python 复制代码
import os
from utils.tool import RandomSeed, Process

if __name__ == '__main__':
    os.system('')	# 使得终端 ASCII 颜色可以正常转义.
    print('\033[H\033[J', end = '')		# 清屏.
    RandomSeed(42)    # 播下随机种子, 使得每次执行程序结果相同.
    Process(
        dir = './datasets/Animals',  # 需要处理的图片集目录.
        ratio = [7, 2, 1], # 训练集、测试集、验证集比例, 不限格式.
        show = True,    # 展示预处理结果.
        resolution = True   # 统计图片平均分辨率.
    )
复制代码
>>> python ./Preprocessing.py

5. 训练

"./cache/log" 文件夹的 Animals.ptFruits.ptBloodcells.pt 是我已经训练 200 轮后,得到了非常好的权重文件,你可以直接拿这两个权重文件进行测试、预测,跳过训练环节。毕竟,训练是很耗时间的。

如果你想重头开始进行水果、动物、血细胞图片集训练,或者对自己的数据集进行训练,你可以在 Preprocessing.py 完成后,执行 Train.py 文件。

python 复制代码
import torch
import platform
from utils.net import Transformer
from utils.interface import TrainActivate

if __name__ == '__main__':
    DATASET_NAME = 'Fruits' # 图片集的文件夹名称, 即训练的类型.
    device = torch.device(
        ('mps' if platform.system() == 'Darwin' else 'cuda')
        if torch.cuda.is_available() else 'cpu'
    )   # 程序计算使用的处理器.
    net = Transformer(info_path = f'./cache/{DATASET_NAME}-info.json').to(device)
    optimizer = torch.optim.Adam(params = net.parameters(), lr = 3e-5)
    parameters: dict = {
        'train_test_valid_set_map': f'./cache/{DATASET_NAME}-map.json',
        'epoch': 200,   # 迭代次数.
        'batch_size': 256,  # 处理器一次性计算的图片个数.
        'model': net,
        'device': device,
        'optimizer': optimizer,
        'scheduler': torch.optim.lr_scheduler.StepLR(
            optimizer = optimizer,
            step_size = 1,
            gamma = 0.7
        ),
        'criterion': torch.nn.CrossEntropyLoss()
    }
    TrainActivate(parameters)
复制代码
>>> python ./Train.py





上面这五张连续图片是我对 Animals、Fruits 和 Bloodcells 三类图片集各训练 200 次的过程,其中 Animals 耗时 03:31:17 较长, Fruits 训练耗时 01:54:29 较短,Bloodcells 训练耗时 02:25:42 适中。

程序代码自动将训练结果全部保存在 "./cache" 目录下,其中子文件夹 log 是存放训练权重,你可以看到 log 里的动物和水果权重,这是我训练出的最优权重。

6. 测试

Illustrate.py 文件是用于绘制训练过程的插图,你可以根据它的结果找一个笼统的区间范围,再从区间内找一个最优的训练权重文件。当然,你可以直接使用 "./cache/log" 目录中我训练好的两个最优权重进行测试。

复制代码
>>> python ./Illustrate.py
python 复制代码
import torch
import platform
from utils.interface import TestActivate

if __name__ == '__main__':
    params = {
        'batch_size': 128,  # 处理器一次计算测试的图片个数.
        # 映射字典路径.
        'train_test_valid_set_map': './cache/Fruits-map.json',
        # 权重文件路径.
        'weight': './cache/log/Fruits-train_loss(0.12103)valid_loss(0.04023).pt',
        'device': torch.device(
            ('mps' if platform.system() == 'Darwin' else 'cuda')
            if torch.cuda.is_available() else 'cpu'
        )
    }
    TestActivate(params)
复制代码
>>> python ./Test.py

最终测试结果会返回并保存一个正确率指数,水果测试集达到 98% 的正确率,相当高。

7. 预测

水果 Fruits 测试集正确率高达 98%,因此表明训练的权重相当好,那么我们可将这个权重用于未知图片集的预测。

python 复制代码
import torch
import platform
from utils.interface import PredictActivate

if __name__ == '__main__':
    params = {
        'batch_size': 1536, # 处理器一次性预测计算的图片个数.
        # 数据集详情文件路径.
        'info_path': './cache/Fruits-info.json',
        # 待预测的图片集目录.
        'predict_images_dir': './datasets/Unknown-Fruits',
        # 使用训练好的权重文件路径.
        'weight': './cache/log/Fruits-train_loss(0.12103)valid_loss(0.04023).pt',
        'device': torch.device(
            ('mps' if platform.system() == 'Darwin' else 'cuda')
            if torch.cuda.is_available() else 'cpu'
        )
    }
    PredictActivate(params)
复制代码
>>> python ./Predict.py



8. 开源同步

GPLv3

https://orzzz.net

GitHub Repository

相关推荐
飞翔的佩奇4 分钟前
【完整源码+数据集+部署教程】 水果叶片分割系统: yolov8-seg-dyhead
人工智能·yolo·计算机视觉·数据集·yolov8·yolo11·水果叶片分割系统
感谢地心引力4 分钟前
【Python】基于 PyQt6 和 Conda 的 PyInstaller 打包工具
数据库·python·conda·pyqt·pyinstaller
春末的南方城市1 小时前
清华&字节开源HuMo: 打造多模态可控的人物视频,输入文字、图片、音频,生成电影级的视频,Demo、代码、模型、数据全开源。
人工智能·深度学习·机器学习·计算机视觉·aigc
xiaohanbao092 小时前
Transformer架构与NLP词表示演进
python·深度学习·神经网络
m0_743106462 小时前
LOBE-GS:分块&致密化效率提升
人工智能·算法·计算机视觉·3d·几何学
love530love2 小时前
【笔记】 Podman Desktop 中部署 Stable Diffusion WebUI (GPU 支持)
人工智能·windows·笔记·python·容器·stable diffusion·podman
视觉人机器视觉2 小时前
机器视觉Halcon3D中,六大类3D处理算子
人工智能·计算机视觉·3d·视觉检测
CoovallyAIHub2 小时前
Transformer作者开源进化计算新框架,样本效率暴增数十倍!
深度学习·算法·计算机视觉
程序员晚枫3 小时前
Python 3.14正式发布!这5大新特性太炸裂了
python
先做个垃圾出来………3 小时前
SortedList
python