深度学习异常检测Anomalib算法训练+推理+转化+onnx

目录

[一、环境安装配置(python 版本 必须大于3.10)](#一、环境安装配置(python 版本 必须大于3.10))

二、数据集

三、训练

[源码 GitHub - open-edge-platform/anomalib: An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference. · GitHub](#源码 GitHub - open-edge-platform/anomalib: An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference. · GitHub)

train.py

[detect.py 推理代码](#detect.py 推理代码)

[模型转化 export.py](#模型转化 export.py)

官方的转化代码


一、环境安装配置(python 版本 必须大于3.10)

cuda和cudnn各个版本的Pytorch下载网页版,onnx,ncnn,pt模型转化工具_cuda国内镜像下载网站-CSDN博客

你需要的所有东西都可以下载(如果有链接打不开 就直接去网上搜官网 从官网进入下载)

标准环境

cuda 11.8 cudnn 对应11.8的即可

torch 2.4.0+cu118

torchaudio 2.4.0+cu118

torchmetrics 1.8.0

torchvision 0.19.0+cu118

下面是我的环境

torch 2.1.0
torchaudio 2.1.0
torchmetrics 1.9.0
torchvision 0.16.0

因为我的是11.7的cuda 不想换了 所以就用了 但是注意 我的环境版本只能训练Patchcore 模型 其他的模型都无法训练 版本不够

所以如果你们想训练其他的Anomalib模型需要标准的高版本环境

二、数据集

数据集中创建这俩文件夹 放入你的正样本,,负样本是用来验证的也需要放。

两边的数据集需要最少10张图片。图片命名不能用中文

三、训练

源码 GitHub - open-edge-platform/anomalib: An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference. · GitHub

你可以在连接中下载 也可以直接pip 安装库

复制代码
pip install anomalib -i https://pypi.tuna.tsinghua.edu.cn/simple

在你的工程中创建三个py文件 train.py detect.py export.py

train.py

复制代码
import multiprocessing
from anomalib.data import Folder
from anomalib.models import Patchcore
from anomalib.engine import Engine


def main():
    datamodule = Folder(
        name="flat_enameled_wire",                    #训练模型后存放的位置
        root=r"C:\Users\Administrator\Desktop\tscc\200t",   #数据集
        normal_dir="good",                                    #数据集的good
        abnormal_dir="bad",                                #数据集中的bad 
        #
        # val_split_mode="none",  # 无验证集
        # test_split_mode="none",  # 无自动测试集
        normal_split_ratio=1.0,  # 全部正常图用于训练

        train_batch_size=1,
        num_workers=0,  # Windows 必须 0
    )

    datamodule.setup()

    model = Patchcore(
        backbone="wide_resnet50_2",                #预训练模型 
        pre_trained=True,  # 必须开,否则没法训 开启就是下载预训练模型      并使用

        coreset_sampling_ratio=0.05,
        num_neighbors=5
    )

    engine = Engine(
        devices=1,
        max_epochs=1,
    )

    engine.train(datamodule=datamodule, model=model)


if __name__ == '__main__':
    multiprocessing.freeze_support()
    main()

预训练模型如果网络不好是无法下载的 可以直接用下方链接下载 预训练模型

下载好后将其放到 这个路径中 如果没有就创建 C:\Users\Administrator\.cache\torch\hub\checkpoints

复制代码
https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth

detect.py 推理代码

复制代码
import multiprocessing
from anomalib.data import Folder
# Import the model and engine
from anomalib.models import Patchcore
from anomalib.engine import Engine

def main():
    # Create the datamodule
    datamodule = Folder(
        name="flat_enameled_wire_infer4",
        # root=r"H:\anomalib-main\ttt\train\11",
        root=r"C:\Users\Administrator\Desktop\tscc\test22",
        normal_dir="good",
        abnormal_dir="bad",
        # task="classification",
    )

    # Setup the datamodule
    datamodule.setup()

    model = Patchcore(pre_trained=False)

    # engine = Engine(task="classification")
    engine = Engine()

    engine.predict(
        datamodule=datamodule,
        model=model,
        ckpt_path=r"H:\anomalib-main\results\Patchcore\flat_enameled_wire\v7\weights\lightning\model.ckpt",
    )

if __name__ == '__main__':
    multiprocessing.freeze_support()  # Optional, if your script might be frozen into an executable
    main()

模型转化 export.py

我用的是torch 转化的 如果是高版本的可以直接用官方的

复制代码
import os
import torch

os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

CKPT_PATH = r"H:\anomalib-main\results\Patchcore\flat_enameled_wire\v5\weights\lightning\model.ckpt"
INPUT_SIZE = (3, 224, 224)
ONNX_SAVE_PATH = r"H:\anomalib-main\results\Patchcore\flat_enameled_wire\v5\weights\lightning\model_final.onnx"


from anomalib.models import Patchcore

# 加载模型
model = Patchcore(backbone="wide_resnet50_2", pre_trained=False)
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

dummy_input = torch.randn(1, *INPUT_SIZE)

torch.onnx.export(
    model,
    dummy_input,
    ONNX_SAVE_PATH,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    #  dynamo 参数彻底删掉!
)

print(f"导出成功:{ONNX_SAVE_PATH}")

官方的转化代码

复制代码
from anomalib.models import Patchcore
from anomalib.engine import Engine

model = Patchcore()
engine = Engine(task="classification")
onnx_model = engine.export(
    model=model,
    export_type='onnx',
    export_root=None,
    input_size=[244, 244],
    transform=None,
    compression_type=None,
    datamodule=None,
    metric=None,
    ov_args=None,
    ckpt_path='E:\\proj\\anomalib\\myProj\\model.ckpt',  # 存放model.ckpt
)
print(onnx_model)
相关推荐
流光容易把人抛7 分钟前
Claude Code & CCSwitch Mac 安装配置详细教程
人工智能
2301_818008449 分钟前
数据库模型设计实战:如何正向工程从模型建表_规范化项目开发流程
jvm·数据库·python
ai产品老杨12 分钟前
突破品牌壁垒:基于 GB28181 与 RTSP 的异构 AI 视频平台架构深度解析(支持 Docker 与源码交付)
人工智能·架构·音视频
科研前沿13 分钟前
多视角相机驱动的室内人员空间定位技术白皮书
大数据·人工智能·python·科技·数码相机·音视频
得一录25 分钟前
大模型需要量化的原因
人工智能
weixin_4171970526 分钟前
四大科技巨头狂砸7250亿美元:AI算力军备竞赛白热化
人工智能·科技
覆东流35 分钟前
第10天:python元组
开发语言·后端·python
万事大吉CC35 分钟前
【5】Django 的模板语言:页面架构设计
后端·python·django
sali-tec40 分钟前
C# 基于OpenCv的视觉工作流-章61-点线距离
图像处理·人工智能·opencv·计算机视觉