如何使用C#实现Padim算法的训练和推理

目录

说明

项目背景

算法实现

预处理模块------图像预处理

主要模块------训练:Resnet层信息提取

[主要模块------信息处理,计算Anomaly Map](#主要模块——信息处理,计算Anomaly Map)

主要模块------评估

主要模块------评估:门限值的确定

主要模块------推理

写在最后

项目下载链接


说明

作者:来瓶霸王防脱发

项目地址:

https://github.com/IntptrMax/PadimSharp

原文地址:

https://blog.csdn.net/qq_30270773/article/details/143029865

项目背景

缺陷检测(Anomaly Detection)算法是一个区分正常类别与异常类别的二分类问题,但在工业场景中大多数数据都为良品,不良数据难以获取,更难枚举,所以训练一个全监督的模型是不切实际的。因此,异常检测模型通常以单类别学习的模式。Padim算法是一种十分优秀的缺陷检测算法,直接上图可以看一下这个算法的效果。

良品图片

不良品图片

检测效果

C#是一种十分受欢迎的编程语言,这种编程语言在工业场景下使用也是十分广泛的。在一些AI领域,会在Python下将模型转化为onnx形式,通过onnxruntime加载使用,进行推理。但是在onnx形式下进行训练十分困难。很多C#开发者不太熟悉Python环境,或者某些条件下希望在纯粹的C#环境下进行深度学习的训练和使用。这个还是有一定的困难的。

目前搜索了Github和CSDN排名靠前的几十条数据,还没有Padim算法在除Python平台下的训练+推理的相关项目或资料。本文就是在C#平台实现了Padim的训练+推理过程,应该在相关领域也算是独一份了。

算法实现

Padim算法的"训练"过程其实并没有涉及到真正的训练,而是使用Resnet18算法提取关键信息加以处理,在推理时再次使用,因此"训练"过程速度非常快,这也是这个算法的优点之一。Padim算法的具体实现还请参考相关论文:PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization

https://arxiv.org/abs/2011.08785

如果论文看起来困难,还有一些大佬对该算法在Python平台下的解读,也可以参考:PaDiM 原理与代码解析

https://blog.csdn.net/ooooocj/article/details/127601035

预处理模块------图像预处理

图像预处理使用的方法比较常规,使用了缩放等方式,此处并没有使用LetterBox,也可以达到预期效果:

var transformers = torchvision.transforms.Compose([
    torchvision.transforms.Resize(resizeHeight,resizeWidth),
torchvision.transforms.CenterCrop(cropHeight,cropWidth),
torchvision.transforms.Normalize(means, stdevs)]);

主要模块------训练:Resnet层信息提取

使用Resnet模型进行推理,并提取Layer1、Layer2、Layer3层的信息,并进行了拼接(EmbeddingConcat)。注意:这里提取时使用了钩子,钩子在使用时会有资源释放,因此这里使用了比较迂回的方式记录结果

实现代码如下:

public List<(string, Tensor)> Forward(Tensor input)
{
 List<(string, Tensor)> outputs = new List<(string, Tensor)>();
 List<TempTensor> tempTensors = new List<TempTensor>();
 foreach (var named_module in model.named_children())
 {
  string name = named_module.name;
  if (name == "layer1" || name == "layer2" || name == "layer3")
  {
   ((Sequential)named_module.module).register_forward_hook((Module, input, output) =>
   {
    tempTensors.Add(new TempTensor
    {
     Data = output.data<float>().ToArray(),
     Name = name,
     Shape = output.shape,
    });
    return null;
   });
  }
 }
 model.forward(input);

 var layer1output = tempTensors.Find(a => a.Name == "layer1");
 var layer2output = tempTensors.Find(a => a.Name == "layer2");
 var layer3output = tempTensors.Find(a => a.Name == "layer3");

 Tensor l1 = torch.tensor(layer1output.Data, layer1output.Shape, device: input.device);
 Tensor l2 = torch.tensor(layer2output.Data, layer2output.Shape, device: input.device);
 Tensor l3 = torch.tensor(layer3output.Data, layer3output.Shape, device: input.device);
 outputs.Add(new("layer1", l1));
 outputs.Add(new("layer2", l2));
 outputs.Add(new("layer3", l3));
 GC.Collect();
 return outputs;
}

private Tensor EmbeddingConcat(Tensor[] features)
{
 var embeddings = features[0];

 for (int i = 1; i < features.Length; i++)
 {
  var layerEmbedding = features[i];
  layerEmbedding = torch.nn.functional.interpolate(layerEmbedding, size: [embeddings.shape[2], embeddings.shape[2]], mode: InterpolationMode.Nearest);
  embeddings = torch.cat([embeddings, layerEmbedding], 1);
 }
 return embeddings;
}

主要模块------信息处理,计算Anomaly Map

这一块主要对信息进行处理,获取矩阵的mean和cov(协方差矩阵),代码如下:

public Tensor ComputeAnomalyMapInternal(Tensor embedding, Tensor mean, Tensor covariance)
{
 var scoreMap = ComputeDistance(embedding, mean, covariance);
 var upSampledScoreMap = UpSample(scoreMap);
 var smoothedAnomalyMap = SmoothAnomalyMap(upSampledScoreMap);
 return smoothedAnomalyMap;
}

public Tensor ComputeAnomalyMap(List<(string, Tensor)> outputs, Tensor mean, Tensor covariance, Tensor idx)
{
 Tensor embedding = GetEmbedding(outputs);
 var embeddingVectors = torch.index_select(embedding, 1, idx);
 return ComputeAnomalyMapInternal(embeddingVectors, mean, covariance);
}

主要模块------评估

与训练过程开始部分相似,也是获取图像的Embeddings,然后利用之前获取的Cov和mean计算马氏距离,以此评估图像的异常情况。马氏距离的计算方法如下:

private Tensor ComputeDistance(Tensor embedding, Tensor mean, Tensor covariance)
{
 long batch = embedding.shape[0];
 long channel = embedding.shape[1];
 long height = embedding.shape[2];
 long width = embedding.shape[3];

 Tensor inv_covariance = covariance.permute(2, 0, 1).inverse();
 var embedding_reshaped = embedding.reshape(batch, channel, height * width);
 var delta = (embedding_reshaped - mean).permute(2, 0, 1);
 var distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0);
 distances = distances.reshape(batch, 1, height, width);
 distances = distances.clamp(0).sqrt();
 return distances;
}

主要模块------评估:门限值的确定

这里需要确定图像的评估门限和像素值的评估门限。如果在评估时有负向样本,这个值会更准确,如果只有正向样本也是可以的。在Python下有个precision_recall_curve包,可以计算相关参数,但是在C#下时没有的,因此在此处仍旧只能造轮子,代码如下:

private (float[] precisions, float[] recalls, float[] thresholds) _precision_recall_curve_compute_single_class(Tensor yTrue, Tensor yScores, int pos_label = 1)
{
 var (fps, tps, thresholds) = BinaryClfCurve(yScores, yTrue, pos_label);
 var precision = tps / (tps + fps);
 var recall = tps / tps[-1];

 var lastInd = torch.where(tps == tps[-1])[0][0].ToInt32();
 int[] sl = new int[lastInd + 1];
 for (int i = 0; i < sl.Length; i++)
 {
  sl[i] = i;
 }
 var reversedPrecision = precision[sl].flip(0);
 var reversedRecall = recall[sl].flip(0);
 var reversedThresholds = thresholds[sl].flip(0);

 precision = torch.cat(new Tensor[] { reversedPrecision, torch.ones(1, dtype: precision.dtype, device: precision.device) });
 recall = torch.cat(new Tensor[] { reversedRecall, torch.zeros(1, dtype: recall.dtype, device: recall.device) });

 return (precision.data<float>().ToArray(), recall.data<float>().ToArray(), reversedThresholds.data<float>().ToArray());
}

private (Tensor fps, Tensor tps, Tensor thresholds) BinaryClfCurve(Tensor preds, Tensor target, int posLabel = 1)
{
 using (torch.no_grad())
 {
  if (preds.ndim > target.ndim)
  {
   preds = preds[TensorIndex.Ellipsis, 0];
  }

  var descScoreIndices = torch.argsort(preds, descending: true);
  preds = preds[descScoreIndices];
  target = target[descScoreIndices];

  Tensor weight = torch.tensor(1.0f);

  var distinctValueIndices = torch.nonzero(preds[1..] - preds[..^1]).squeeze();
  var thresholdIdxs = torch.cat(new Tensor[] { distinctValueIndices, torch.tensor(new long[] { target.shape[0] - 1 }, device: preds.device) });

  target = (target == posLabel).to_type(ScalarType.Int64);

  var tps = torch.cumsum(target * weight, dim: 0)[thresholdIdxs];

  Tensor fps = 1 + thresholdIdxs - tps;
  return (fps, tps, preds[thresholdIdxs]);
 }
}

主要模块------推理

这个过程与上面过程也十分相似,正向计算出图像的Anomaly Map后,取出这个张量中最大的值,与图像的门限值进行比较,即可评估图像是否是良品。然后对这个张量中每个元素与像素门限值做对比,即可得到按像素的异常区域,以便绘制Mask和热力图。

Tensor orgImg = tensors["orgImage"].clone().to(device);
Tensor t = anomaly_map > pixel_threshold;
anomaly_map = (anomaly_map * t).squeeze(0);
anomaly_map = torchvision.transforms.functional.resize(anomaly_map, (int)orgImg.size(2), (int)orgImg.size(1));
Tensor heatmapNormalized = (anomaly_map - anomaly_map.min()) / (anomaly_map.max() - anomaly_map.min());
Tensor coloredHeatmap = torch.zeros([3, (int)orgImg.size(2), (int)orgImg.size(1)],device:anomaly_map.device);

coloredHeatmap[0] = heatmapNormalized.squeeze(0);

float alpha = 0.3f;
Tensor blendedImage = (1 - alpha) * (orgImg / 255.0f) + alpha * coloredHeatmap;
var imageTensor = blendedImage.clamp(0, 1).mul(255).to(ScalarType.Byte);

torchvision.io.write_jpeg(imageTensor.cpu(), "result.jpg");

写在最后

使用C#开发深度学习项目,尤其是训练的项目,是一个十分困难的过程。或者说除了Python平台,训练都十分困难。C#进行深度学习训练这个方向在国内基本很少有人开展,所以能查得到的资料很少。本人十分喜爱C#这门语言,又十分喜爱深度学习,因此仅半年一直在这方面努力。遇到了很多困难,也收获了很多。

这条路走的不容易,希望能有更多人能加入进来,一起开发,一起学习。

我在Github上已经将完整的代码发布了,项目地址为:

https://github.com/IntptrMax/PadimSharp

,期待你能在Github上送我一颗小星星。在我的Github里还GGMLSharp这个项目,这个项目也是C#平台下深度学习的开发包,希望能得到你的支持。

项目下载链接

https://download.csdn.net/download/qq_30270773/89897710
相关推荐
九鼎科技-Leo16 分钟前
什么是 WPF 中的依赖属性?有什么作用?
windows·c#·.net·wpf
Heaphaestus,RC1 小时前
【Unity3D】获取 GameObject 的完整层级结构
unity·c#
baivfhpwxf20231 小时前
C# 5000 转16进制 字节(激光器串口通讯生成指定格式命令)
开发语言·c#
直裾2 小时前
Scala全文单词统计
开发语言·c#·scala
ZwaterZ3 小时前
vue el-table表格点击某行触发事件&&操作栏点击和row-click冲突问题
前端·vue.js·elementui·c#·vue
ZwaterZ5 小时前
el-table-column自动生成序号&&在序号前插入图标
前端·javascript·c#·vue
SRC_BLUE_178 小时前
SQLI LABS | Less-55 GET-Challenge-Union-14 Queries Allowed-Variation 2
oracle·c#·less
yngsqq9 小时前
037集——JoinEntities连接多段线polyline和圆弧arc(CAD—C#二次开发入门)
开发语言·c#·swift
Zԅ(¯ㅂ¯ԅ)9 小时前
C#桌面应用制作计算器进阶版01
开发语言·c#
hccee10 小时前
C#之异步编程
c#