C# OnnxRuntime 部署 DINOv3 密集特征可视化

说明

官网地址:https://github.com/facebookresearch/dinov3

效果

模型信息

markdown 复制代码
Model Properties
-------------------------
---------------------------------------------------------------

Inputs
-------------------------
name:input
tensor:Float[-1, 3, -1, -1]
---------------------------------------------------------------

Outputs
-------------------------
name:patch_tokens
tensor:Float[-1, -1, 1024]
---------------------------------------------------------------

项目

代码

csharp 复制代码
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Drawing.Imaging;
using System.Linq;
using System.Security.Cryptography;
using System.Windows.Forms;

namespace Onnx_Demo
{
    public partial class Form1 : Form
    {
        // Configuration
        private readonly string modelPath = "model/dinov3_vitl16.onnx";
        private const int InputSize = 768;          // must match the size used when the model was exported
        private const int PatchSize = 16;
        private int GridSize => InputSize / PatchSize;   // 48
        private const int FeatureDim = 1024;        // ViT-Large

        private InferenceSession onnxSession;
        private float[,] patchFeatures;             // [NumPatches, FeatureDim]
        private Mat originalImage;                  // original image (BGR)
        private bool featuresReady = false;

        /// <summary>
        /// Builds the form: creates the designer controls, loads the ONNX model,
        /// then wires the runtime event handlers.
        /// </summary>
        public Form1()
        {
            // Designer-generated controls must exist before anything logs to
            // textBox1 or subscribes to pictureBox1.
            this.InitializeComponent();

            // Model loading reports its outcome through LogMessage / MessageBox.
            this.InitializeModel();

            // Handlers that are not wired up in the designer file.
            this.AttachEvents();
        }

        /// <summary>
        /// Loads the bundled sample image when the form is first shown.
        /// A missing or undecodable sample is logged and skipped instead of
        /// crashing the form during its Load event.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void Form1_Load(object sender, EventArgs e)
        {
            string imagePath = "test_img/1.jpg";

            // Image.FromFile throws FileNotFoundException for a missing file;
            // the sample image is optional, so bail out quietly.
            if (!System.IO.File.Exists(imagePath))
            {
                LogMessage($"默认图片不存在: {imagePath}");
                return;
            }

            // Cv2.ImRead signals failure with an empty Mat, not an exception.
            Mat loaded = Cv2.ImRead(imagePath, ImreadModes.Color);
            if (loaded.Empty())
            {
                loaded.Dispose();
                LogMessage($"默认图片无法解码: {imagePath}");
                return;
            }

            originalImage = loaded;
            pictureBox1.Image = Image.FromFile(imagePath);
        }

        /// <summary>
        /// Creates the ONNX Runtime inference session for <c>modelPath</c> on the
        /// CPU execution provider. On failure the error is logged, a message box
        /// is shown and <c>onnxSession</c> stays null.
        /// </summary>
        private void InitializeModel()
        {
            try
            {
                // SessionOptions wraps a native handle; release it once the
                // session has been constructed from it.
                using (var opts = new SessionOptions())
                {
                    opts.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
                    opts.AppendExecutionProvider_CPU(0);
                    onnxSession = new InferenceSession(modelPath, opts);
                }
                LogMessage($"模型加载成功: {modelPath}");
            }
            catch (Exception ex)
            {
                LogMessage($"模型加载失败: {ex.Message}");
                MessageBox.Show("请确保 dinov3_vitl16.onnx 文件存在。", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
            }
        }

        /// <summary>
        /// Subscribes the handlers that are wired in code rather than in the designer.
        /// </summary>
        private void AttachEvents()
        {
            // A click on pictureBox1 selects the query point.
            pictureBox1.MouseClick += new MouseEventHandler(PictureBox1_MouseClick);
        }

        /// <summary>
        /// Appends a time-stamped line to <c>textBox1</c>, marshalling to the
        /// control's thread when called from another thread.
        /// </summary>
        /// <param name="msg">Text to append after the timestamp.</param>
        private void LogMessage(string msg)
        {
            // The timestamp is evaluated when the delegate actually runs.
            Action append = () => textBox1.AppendText($"{DateTime.Now:HH:mm:ss} - {msg}\r\n");

            if (!textBox1.InvokeRequired)
            {
                append();
                return;
            }

            textBox1.Invoke(append);
        }

        /// <summary>
        /// Lets the user pick an image file, loads it as the new working image
        /// and invalidates any previously extracted features.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button1_Click(object sender, EventArgs e)
        {
            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = "图像文件|*.bmp;*.jpg;*.jpeg;*.png";
                if (ofd.ShowDialog() != DialogResult.OK)
                {
                    return;
                }

                string imagePath = ofd.FileName;

                // Cv2.ImRead returns an empty Mat (no exception) when the file
                // cannot be read or decoded, so check before replacing state.
                Mat loaded = Cv2.ImRead(imagePath, ImreadModes.Color);
                if (loaded.Empty())
                {
                    loaded.Dispose();
                    LogMessage($"图片读取失败: {imagePath}");
                    MessageBox.Show("无法读取所选图片。", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                    return;
                }

                Image display;
                try
                {
                    display = Image.FromFile(imagePath);
                }
                catch (Exception ex)
                {
                    loaded.Dispose();
                    LogMessage($"图片读取失败: {ex.Message}");
                    MessageBox.Show("无法读取所选图片。", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                    return;
                }

                // Remember what is being replaced so the native/GDI+ resources
                // can be released once the controls no longer reference them.
                Mat previousMat = originalImage;
                Image previousLeft = pictureBox1.Image;
                Image previousRight = pictureBox2.Image;

                originalImage = loaded;
                pictureBox1.Image = display;
                featuresReady = false;
                pictureBox2.Image = null;

                previousMat?.Dispose();
                previousLeft?.Dispose();
                previousRight?.Dispose();

                LogMessage($"已加载图片: {imagePath}");
            }
        }

        /// <summary>
        /// Runs the model on the current image and caches the per-patch features
        /// in <c>patchFeatures</c> so that a query point can be picked afterwards.
        /// Errors are logged and shown in a message box; the button is always
        /// re-enabled.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button2_Click(object sender, EventArgs e)
        {
            // An empty Mat is what Cv2.ImRead yields for an unreadable file.
            if (originalImage == null || originalImage.Empty())
            {
                MessageBox.Show("请先打开图片。", "提示", MessageBoxButtons.OK, MessageBoxIcon.Warning);
                return;
            }
            if (onnxSession == null)
            {
                MessageBox.Show("模型未正确加载。", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                return;
            }

            button2.Enabled = false;

            // Release the previous heatmap instead of just dropping the reference.
            Image staleHeatmap = pictureBox2.Image;
            pictureBox2.Image = null;
            staleHeatmap?.Dispose();

            LogMessage("开始提取特征...");
            Application.DoEvents(); // let the UI repaint before the blocking inference call

            try
            {
                // Pre-processing
                var inputTensor = Preprocess(originalImage);
                var inputs = new List<NamedOnnxValue>
                {
                    NamedOnnxValue.CreateFromTensor("input", inputTensor)
                };

                // Stopwatch is monotonic; DateTime.Now can jump with clock changes.
                var timer = System.Diagnostics.Stopwatch.StartNew();
                using (var results = onnxSession.Run(inputs))
                {
                    timer.Stop();

                    var output = results.First(o => o.Name == "patch_tokens");
                    var tensor = output.AsTensor<float>();
                    int[] dims = tensor.Dimensions.ToArray(); // expected [1, N, D]
                    if (dims.Length != 3 || dims[0] != 1)
                    {
                        throw new InvalidOperationException($"输出形状错误: [{string.Join(", ", dims)}]");
                    }

                    int numPatches = dims[1];
                    int actualDim = dims[2];

                    if (numPatches != GridSize * GridSize)
                    {
                        throw new InvalidOperationException($"Patch数量错误: 预期 {GridSize * GridSize}, 实际 {numPatches}");
                    }
                    if (actualDim != FeatureDim)
                    {
                        LogMessage($"特征维度 {actualDim} (预期 {FeatureDim}),继续...");
                    }

                    // The flattened tensor is row-major [N, D], which is exactly
                    // the memory layout of float[N, D], so one block copy suffices.
                    float[] flat = tensor.ToArray();
                    var features = new float[numPatches, actualDim];
                    Buffer.BlockCopy(flat, 0, features, 0, numPatches * actualDim * sizeof(float));
                    patchFeatures = features;

                    featuresReady = true;
                    LogMessage($"特征提取完成,耗时 {timer.Elapsed.TotalMilliseconds:F2} ms,Patch数: {numPatches},维度: {actualDim}");
                    MessageBox.Show("特征已就绪,请在左侧图片上单击选择查询点。", "提示", MessageBoxButtons.OK, MessageBoxIcon.Information);
                }
            }
            catch (Exception ex)
            {
                LogMessage($"特征提取失败: {ex.Message}");
                MessageBox.Show($"推理错误: {ex.Message}", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
            }
            finally
            {
                button2.Enabled = true;
            }
        }

        // ========== pictureBox1 单击选点 ==========
        /// <summary>
        /// Handles a click on the source image: maps the click to a patch index,
        /// computes the cosine-similarity map against that patch, shows the
        /// heatmap in <c>pictureBox2</c> and marks the clicked point in
        /// <c>pictureBox1</c>.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Mouse data; <c>X</c>/<c>Y</c> are relative to pictureBox1.</param>
        private void PictureBox1_MouseClick(object sender, MouseEventArgs e)
        {
            if (!featuresReady || patchFeatures == null)
            {
                // Inner quotes must be escaped, otherwise the literal ends early.
                MessageBox.Show("请先点击\"提取特征\"按钮。", "提示", MessageBoxButtons.OK, MessageBoxIcon.Information);
                return;
            }
            if (originalImage == null) return;

            // Click position in pictureBox1 client coordinates (pixels).
            int clickX = e.X;
            int clickY = e.Y;

            // Area of the control actually covered by the image (depends on SizeMode).
            Rectangle imgRect = GetImageRectangle(pictureBox1);
            if (!imgRect.Contains(clickX, clickY))
            {
                LogMessage("点击位置不在图像区域内。");
                return;
            }

            // Map the click into original-image pixel coordinates.
            float scaleX = (float)originalImage.Width / imgRect.Width;
            float scaleY = (float)originalImage.Height / imgRect.Height;
            int origX = (int)((clickX - imgRect.X) * scaleX);
            int origY = (int)((clickY - imgRect.Y) * scaleY);
            origX = Math.Max(0, Math.Min(origX, originalImage.Width - 1));
            origY = Math.Max(0, Math.Min(origY, originalImage.Height - 1));

            // The model input is the image resized to InputSize x InputSize, so
            // convert to that space before dividing by the patch size.
            float modelX = origX * (float)InputSize / originalImage.Width;
            float modelY = origY * (float)InputSize / originalImage.Height;
            int patchCol = (int)(modelX / PatchSize);
            int patchRow = (int)(modelY / PatchSize);
            patchCol = Math.Min(patchCol, GridSize - 1);
            patchRow = Math.Min(patchRow, GridSize - 1);
            int targetIdx = patchRow * GridSize + patchCol;

            LogMessage($"单击位置: 原图({origX},{origY}) -> patch ({patchRow},{patchCol}) 索引 {targetIdx}");

            // Similarity of every patch to the selected one.
            float[,] simMap = ComputeSimilarityMap(patchFeatures, targetIdx, GridSize);

            // Show the heatmap; dispose the one it replaces so repeated clicks
            // do not accumulate full-resolution GDI+ bitmaps.
            Bitmap heatmap = GenerateHeatmap(simMap, originalImage.Width, originalImage.Height);
            Image previousHeatmap = pictureBox2.Image;
            pictureBox2.Image = heatmap;
            previousHeatmap?.Dispose();

            // Redraw the source image with a red dot at the query point.
            Image markedImage = DrawMarkerOnImage(originalImage, new OpenCvSharp.Point(origX, origY));
            Image previousSource = pictureBox1.Image;
            pictureBox1.Image = markedImage;
            previousSource?.Dispose();
        }

        // ========== 辅助函数 ==========
        /// <summary>
        /// Converts a BGR image into the model input tensor: BGR to RGB, resize
        /// to InputSize x InputSize, scale to [0,1], then per-channel
        /// (value - mean) / std, laid out as NCHW.
        /// </summary>
        /// <param name="bgrImage">Source image in BGR channel order.</param>
        /// <returns>A tensor of shape [1, 3, InputSize, InputSize].</returns>
        /// <exception cref="ArgumentException">The image is null or empty.</exception>
        private DenseTensor<float> Preprocess(Mat bgrImage)
        {
            if (bgrImage == null || bgrImage.Empty())
            {
                throw new ArgumentException("Input image is null or empty.", nameof(bgrImage));
            }

            float[] mean = { 0.485f, 0.456f, 0.406f };
            float[] std = { 0.229f, 0.224f, 0.225f };

            int h = InputSize, w = InputSize;
            int plane = h * w;
            float[] inputData = new float[3 * plane];

            // using-blocks release the native buffers even if OpenCV throws.
            using (Mat rgb = new Mat())
            using (Mat resized = new Mat())
            {
                Cv2.CvtColor(bgrImage, rgb, ColorConversionCodes.BGR2RGB);
                Cv2.Resize(rgb, resized, new OpenCvSharp.Size(InputSize, InputSize));

                // 8-bit -> float in [0,1]
                resized.ConvertTo(resized, MatType.CV_32FC3, 1.0 / 255.0);

                for (int y = 0; y < h; y++)
                {
                    int rowOffset = y * w;
                    for (int x = 0; x < w; x++)
                    {
                        Vec3f pixel = resized.At<Vec3f>(y, x); // R, G, B after CvtColor
                        int i = rowOffset + x;
                        inputData[i] = (pixel.Item0 - mean[0]) / std[0];
                        inputData[plane + i] = (pixel.Item1 - mean[1]) / std[1];
                        inputData[2 * plane + i] = (pixel.Item2 - mean[2]) / std[2];
                    }
                }
            }

            return new DenseTensor<float>(inputData, new[] { 1, 3, h, w });
        }

        /// <summary>
        /// Computes the cosine similarity between one patch and every patch,
        /// arranged as a row-major grid (patch i goes to [i / gridSize, i % gridSize]).
        /// </summary>
        /// <param name="feats">Patch features, one row per patch.</param>
        /// <param name="targetIdx">Row index of the query patch in <paramref name="feats"/>.</param>
        /// <param name="gridSize">Side length of the square output grid.</param>
        /// <returns>A gridSize x gridSize map; cells beyond the patch count stay 0.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="feats"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="targetIdx"/> is not a valid row.</exception>
        /// <exception cref="ArgumentException">The grid cannot hold all patches.</exception>
        private float[,] ComputeSimilarityMap(float[,] feats, int targetIdx, int gridSize)
        {
            if (feats == null)
            {
                throw new ArgumentNullException(nameof(feats));
            }

            int numPatches = feats.GetLength(0);
            int dim = feats.GetLength(1);

            if (targetIdx < 0 || targetIdx >= numPatches)
            {
                throw new ArgumentOutOfRangeException(nameof(targetIdx));
            }
            if (gridSize <= 0 || numPatches > (long)gridSize * gridSize)
            {
                throw new ArgumentException("Grid is too small for the number of patches.", nameof(gridSize));
            }

            float[,] sim = new float[gridSize, gridSize];

            // Norm of the query vector, accumulated in double to limit rounding
            // error over long feature vectors.
            double targetSq = 0;
            for (int j = 0; j < dim; j++)
            {
                double t = feats[targetIdx, j];
                targetSq += t * t;
            }
            double targetNorm = Math.Sqrt(targetSq);

            const double eps = 1e-8; // keeps the division defined for all-zero vectors

            for (int idx = 0; idx < numPatches; idx++)
            {
                // Dot product and norm in a single pass, without copying the row.
                double dot = 0;
                double currSq = 0;
                for (int j = 0; j < dim; j++)
                {
                    double c = feats[idx, j];
                    dot += c * feats[targetIdx, j];
                    currSq += c * c;
                }

                double cos = dot / (Math.Sqrt(currSq) * targetNorm + eps);
                sim[idx / gridSize, idx % gridSize] = (float)cos;
            }

            return sim;
        }

        /// <summary>
        /// Renders a similarity map as a Viridis-coloured bitmap of the requested size.
        /// Values are multiplied by 255 and clamped to [0, 255] before the palette
        /// lookup, so values at or below 0 all map to the first palette entry.
        /// </summary>
        /// <param name="simMap">Similarity values on the patch grid.</param>
        /// <param name="outW">Output width in pixels.</param>
        /// <param name="outH">Output height in pixels.</param>
        /// <returns>A new 32bpp ARGB bitmap owned by the caller.</returns>
        private Bitmap GenerateHeatmap(float[,] simMap, int outW, int outH)
        {
            // Bilinear upsampling from the patch grid to the output resolution.
            float[,] upsampled = BilinearUpsample(simMap, outH, outW);
            Color[] palette = GetViridisColormap();

            Bitmap bmp = new Bitmap(outW, outH, PixelFormat.Format32bppArgb);
            try
            {
                // Write whole rows through LockBits; SetPixel locks and unlocks
                // the bitmap for every single pixel.
                BitmapData data = bmp.LockBits(
                    new Rectangle(0, 0, outW, outH),
                    ImageLockMode.WriteOnly,
                    PixelFormat.Format32bppArgb);
                try
                {
                    byte[] row = new byte[outW * 4];
                    for (int y = 0; y < outH; y++)
                    {
                        for (int x = 0; x < outW; x++)
                        {
                            int idx = (int)(upsampled[y, x] * 255);
                            idx = Math.Max(0, Math.Min(255, idx));
                            Color c = palette[idx];

                            // Format32bppArgb is stored as B, G, R, A in memory.
                            int o = x * 4;
                            row[o] = c.B;
                            row[o + 1] = c.G;
                            row[o + 2] = c.R;
                            row[o + 3] = c.A;
                        }

                        IntPtr dest = new IntPtr(data.Scan0.ToInt64() + (long)y * data.Stride);
                        System.Runtime.InteropServices.Marshal.Copy(row, 0, dest, row.Length);
                    }
                }
                finally
                {
                    bmp.UnlockBits(data);
                }

                return bmp;
            }
            catch
            {
                // Do not leak the GDI+ bitmap if rendering fails part-way.
                bmp.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Resizes a 2-D map with bilinear interpolation, mapping the first and
        /// last output samples exactly onto the first and last source samples.
        /// </summary>
        /// <param name="src">Source map indexed [row, column]; must be non-empty.</param>
        /// <param name="newH">Number of output rows.</param>
        /// <param name="newW">Number of output columns.</param>
        /// <returns>A new [newH, newW] array.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="src"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="src"/> has a zero-length dimension.</exception>
        private float[,] BilinearUpsample(float[,] src, int newH, int newW)
        {
            if (src == null)
            {
                throw new ArgumentNullException(nameof(src));
            }

            int srcH = src.GetLength(0);
            int srcW = src.GetLength(1);
            if (srcH == 0 || srcW == 0)
            {
                throw new ArgumentException("Source map must not be empty.", nameof(src));
            }

            float[,] dst = new float[newH, newW];

            // Dividing by (new - 1) makes the last output sample land on the last
            // source sample; dividing by `new` never reaches the final row/column.
            float scaleX = newW > 1 ? (float)(srcW - 1) / (newW - 1) : 0f;
            float scaleY = newH > 1 ? (float)(srcH - 1) / (newH - 1) : 0f;

            // Column coordinates do not depend on the row, so compute them once.
            int[] colLo = new int[newW];
            int[] colHi = new int[newW];
            float[] colFrac = new float[newW];
            for (int x = 0; x < newW; x++)
            {
                float fx = x * scaleX;
                int x0 = Math.Min((int)Math.Floor(fx), srcW - 1);
                colLo[x] = x0;
                colHi[x] = Math.Min(x0 + 1, srcW - 1);
                colFrac[x] = fx - x0;
            }

            for (int y = 0; y < newH; y++)
            {
                float fy = y * scaleY;
                int y0 = Math.Min((int)Math.Floor(fy), srcH - 1);
                int y1 = Math.Min(y0 + 1, srcH - 1);
                float dy = fy - y0;

                for (int x = 0; x < newW; x++)
                {
                    int x0 = colLo[x];
                    int x1 = colHi[x];
                    float dx = colFrac[x];

                    float top = src[y0, x0] * (1 - dx) + src[y0, x1] * dx;
                    float bottom = src[y1, x0] * (1 - dx) + src[y1, x1] * dx;
                    dst[y, x] = top * (1 - dy) + bottom * dy;
                }
            }

            return dst;
        }

        /// <summary>
        /// Builds a 256-entry Viridis palette by running OpenCV's colormap over a
        /// 0..255 grey ramp.
        /// </summary>
        /// <returns>Palette where entry i is the colour for intensity i.</returns>
        private Color[] GetViridisColormap()
        {
            Color[] colors = new Color[256];

            // using-blocks free the native Mats even if ApplyColorMap throws.
            using (Mat ramp = new Mat(1, 256, MatType.CV_8UC1))
            using (Mat mapped = new Mat())
            {
                for (int i = 0; i < 256; i++)
                {
                    ramp.Set<byte>(0, i, (byte)i);
                }

                Cv2.ApplyColorMap(ramp, mapped, ColormapTypes.Viridis);

                for (int i = 0; i < 256; i++)
                {
                    // OpenCV stores pixels as B, G, R; Color.FromArgb takes R, G, B.
                    Vec3b bgr = mapped.At<Vec3b>(0, i);
                    colors[i] = Color.FromArgb(bgr.Item2, bgr.Item1, bgr.Item0);
                }
            }

            return colors;
        }

        /// <summary>
        /// Returns a copy of the image with a filled red dot (radius 8) drawn at
        /// the given pixel. The input Mat is not modified.
        /// </summary>
        /// <param name="bgrImg">Source image in BGR channel order.</param>
        /// <param name="pixel">Centre of the marker in image coordinates.</param>
        /// <returns>A new bitmap owned by the caller.</returns>
        private Image DrawMarkerOnImage(Mat bgrImg, OpenCvSharp.Point pixel)
        {
            using (Mat marked = bgrImg.Clone())
            {
                // Scalar is B, G, R: (0, 0, 255) is red; thickness -1 fills the circle.
                Cv2.Circle(marked, pixel, 8, new Scalar(0, 0, 255), -1);

                // A Bitmap built from a stream needs that stream kept open for its
                // whole lifetime, so copy into a stand-alone bitmap and release
                // the stream and intermediate bitmap here.
                using (var encoded = marked.ToMemoryStream())
                using (var decoded = new Bitmap(encoded))
                {
                    return new Bitmap(decoded);
                }
            }
        }

        /// <summary>
        /// Computes the rectangle, in the control's client coordinates, in which
        /// the PictureBox draws its image for the current <c>SizeMode</c>.
        /// </summary>
        /// <param name="picBox">The picture box to inspect.</param>
        /// <returns>
        /// The drawn image area, or <see cref="Rectangle.Empty"/> when no image is set.
        /// For <c>CenterImage</c> with an image larger than the control the
        /// rectangle extends beyond the client area.
        /// </returns>
        private Rectangle GetImageRectangle(PictureBox picBox)
        {
            if (picBox.Image == null) return Rectangle.Empty;

            int imgW = picBox.Image.Width;
            int imgH = picBox.Image.Height;
            int ctrlW = picBox.Width;
            int ctrlH = picBox.Height;

            switch (picBox.SizeMode)
            {
                case PictureBoxSizeMode.Zoom:
                {
                    // Uniform scale that fits the image inside the control, centred.
                    float scale = Math.Min((float)ctrlW / imgW, (float)ctrlH / imgH);
                    int drawW = (int)(imgW * scale);
                    int drawH = (int)(imgH * scale);
                    return new Rectangle((ctrlW - drawW) / 2, (ctrlH - drawH) / 2, drawW, drawH);
                }

                case PictureBoxSizeMode.Normal:
                case PictureBoxSizeMode.AutoSize:
                    // Drawn unscaled from the top-left corner.
                    return new Rectangle(0, 0, imgW, imgH);

                case PictureBoxSizeMode.CenterImage:
                    // Drawn unscaled and centred.
                    return new Rectangle((ctrlW - imgW) / 2, (ctrlH - imgH) / 2, imgW, imgH);

                default: // StretchImage
                    return new Rectangle(0, 0, ctrlW, ctrlH);
            }
        }

        /// <summary>
        /// Saves the heatmap currently shown in <c>pictureBox2</c> to a file chosen
        /// by the user, as PNG, JPEG or BMP depending on the file extension
        /// (PNG when the extension is anything else).
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button3_Click(object sender, EventArgs e)
        {
            if (pictureBox2.Image == null)
            {
                MessageBox.Show("请先进行推理!", "提示", MessageBoxButtons.OK, MessageBoxIcon.Information);
                return;
            }

            using (SaveFileDialog sfd = new SaveFileDialog())
            {
                sfd.Title = "保存图像";
                sfd.Filter = "PNG图片 (*.png)|*.png|JPEG图片 (*.jpg)|*.jpg|BMP图片 (*.bmp)|*.bmp";
                sfd.FilterIndex = 1;
                if (sfd.ShowDialog() != DialogResult.OK)
                {
                    return;
                }

                // Invariant lowering: the extension is an identifier, not UI text.
                string ext = System.IO.Path.GetExtension(sfd.FileName).ToLowerInvariant();
                ImageFormat format = ImageFormat.Png;
                if (ext == ".jpg" || ext == ".jpeg")
                {
                    format = ImageFormat.Jpeg;
                }
                else if (ext == ".bmp")
                {
                    format = ImageFormat.Bmp;
                }

                try
                {
                    // Save from a private copy. The image held by pictureBox2 must
                    // NOT be disposed here: the control keeps painting it.
                    using (var bitmap = new Bitmap(pictureBox2.Image))
                    {
                        bitmap.Save(sfd.FileName, format);
                    }
                }
                catch (Exception ex)
                {
                    MessageBox.Show($"保存失败: {ex.Message}", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                    return;
                }

                MessageBox.Show($"保存成功!\n位置: {sfd.FileName}", "完成", MessageBoxButtons.OK, MessageBoxIcon.Information);
            }
        }
    }
}
相关推荐
byoass1 小时前
自动化任务系列之五:PDF批量转换+自动清理——文件格式规范化工作流
网络·人工智能·安全·云计算
nix.gnehc1 小时前
读懂 OpenSpec:AI 编码时代的规范驱动开发新范式
人工智能·驱动开发·sdd·openspec
咚咚王者2 小时前
人工智能之大模型应用 基础入门第三章 大模型赋能行业与未来展望
人工智能
阿杰学AI2 小时前
AI核心知识141—大语言模型之 对齐难题(简洁且通俗易懂版)
人工智能·安全·ai·语言模型·自然语言处理·aigc·ai对齐
AI医影跨模态组学2 小时前
如何将机器学习模型评分与肿瘤微环境中的去乙酰化修饰及免疫细胞组成建立关联,并进一步解释其与NSCLC免疫治疗预后的机制联系
人工智能·机器学习·论文·医学·医学影像
流年似水~2 小时前
2026 年跨平台开发全貌:Flutter、KMP、React Native 怎么选?
人工智能·程序人生·语言模型·ai编程
墨染天姬2 小时前
【AI】MCP和SKILLS区别
人工智能
rpa研究爱好者2 小时前
基于 DeerFlow 二次开发:AgentFlow 如何让超级智能体“零门槛”落地?
人工智能·ai
sali-tec2 小时前
C# 基于OpenCv的视觉工作流-章53-QR二维码1
图像处理·人工智能·opencv·算法·计算机视觉