C# TensorRT部署RF-DETR目标检测&分割模型

Form1.cs

cs 复制代码
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.IO;
using System.Windows.Forms;
using OpenCvSharp;
using RFDETRDetector.TensorRT;
using RFDETRSegmentor.TensorRT;
using Point = OpenCvSharp.Point;
using Size = OpenCvSharp.Size;

namespace DML
{
    /// <summary>
    /// Demo window: pick an image, run RF-DETR object detection or instance
    /// segmentation through TensorRT, and show the annotated result beside the original.
    /// </summary>
    /// <remarks>
    /// Inference runs synchronously on the UI thread. The model wrappers reuse
    /// their GPU buffers across calls and are released when the form closes, so
    /// moving inference to a background thread would need extra coordination.
    /// </remarks>
    public partial class Form1 : Form
    {
        // Model instances; null when the engine file is missing or failed to load.
        private RfDetrModelTensorRt _detector;
        private RFDETRSEGTensorRt _segmentor;

        // Path of the currently selected image ("" until one has been chosen).
        private string _imagePath = "";

        // Reused timer around the Predict calls.
        private readonly Stopwatch _stopwatch = new Stopwatch();

        // Model file paths, relative to the working directory (adjust to your deployment).
        private readonly string _detectorEnginePath = @"models/rfdetr_det.engine";
        private readonly string _segmentorEnginePath = @"models/rfdetr_seg.engine";
        private readonly string _classesPath = @"models/classes.txt";

        public Form1()
        {
            InitializeComponent();

            // Load whichever models are present on disk.
            LoadModels();
        }

        /// <summary>
        /// Loads the detection and segmentation engines independently, so a failure
        /// in one does not prevent the other from loading.
        /// </summary>
        private void LoadModels()
        {
            _detector = TryLoadModel(_detectorEnginePath, toolStripStatusLabel1,
                "检测模型已加载", "检测模型文件不存在",
                () => new RfDetrModelTensorRt(_detectorEnginePath, _classesPath));

            _segmentor = TryLoadModel(_segmentorEnginePath, toolStripStatusLabel2,
                "分割模型已加载", "分割模型文件不存在",
                () => new RFDETRSEGTensorRt(_segmentorEnginePath, _classesPath));
        }

        /// <summary>
        /// Creates one model if its engine file exists and reports the outcome in a status label.
        /// </summary>
        /// <param name="enginePath">Engine file whose presence is checked first.</param>
        /// <param name="statusLabel">Label updated with <paramref name="loadedText"/> or <paramref name="missingText"/>.</param>
        /// <param name="loadedText">Status text after a successful load.</param>
        /// <param name="missingText">Status text when the engine file does not exist.</param>
        /// <param name="create">Factory that constructs the model.</param>
        /// <returns>The loaded model, or null when the file is missing or the factory throws.</returns>
        private static T TryLoadModel<T>(string enginePath, ToolStripStatusLabel statusLabel,
            string loadedText, string missingText, Func<T> create) where T : class
        {
            if (!File.Exists(enginePath))
            {
                statusLabel.Text = missingText;
                return null;
            }

            try
            {
                T model = create();
                statusLabel.Text = loadedText;
                return model;
            }
            catch (Exception ex)
            {
                MessageBox.Show($"模型加载失败: {ex.Message}", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                return null;
            }
        }

        /// <summary>
        /// "Open image" button: shows the chosen image in the left pane and clears previous results.
        /// </summary>
        private void BtnOpenImage_Click(object sender, EventArgs e)
        {
            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = "图片文件|*.bmp;*.jpg;*.jpeg;*.png;*.tiff";
                ofd.Title = "选择图片";
                if (ofd.ShowDialog() != DialogResult.OK) return;

                Bitmap preview;
                try
                {
                    preview = LoadBitmapWithoutLock(ofd.FileName);
                }
                catch (Exception ex)
                {
                    // GDI+ could not decode the file; keep the previous selection untouched.
                    MessageBox.Show($"无法读取图片: {ex.Message}", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                    return;
                }

                _imagePath = ofd.FileName;
                ReplaceImage(pictureBox1, preview);
                ReplaceImage(pictureBox2, null);
                textBox1.Clear();
            }
        }

        /// <summary>
        /// Object-detection button.
        /// </summary>
        private void BtnDetect_Click(object sender, EventArgs e)
        {
            if (string.IsNullOrEmpty(_imagePath))
            {
                MessageBox.Show("请先选择图片", "提示", MessageBoxButtons.OK, MessageBoxIcon.Information);
                return;
            }

            if (_detector == null)
            {
                MessageBox.Show("检测模型未加载", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                return;
            }

            RunInference(img =>
            {
                // Only the Predict call is timed, not drawing or display.
                _stopwatch.Restart();
                var results = _detector.Predict(img, 0.5f);
                _stopwatch.Stop();

                Mat rendered = DrawDetectionResults(img, results);
                string summary = $"检测耗时: {_stopwatch.ElapsedMilliseconds} ms{Environment.NewLine}检测到 {results.Count} 个目标";
                return (rendered, summary);
            });
        }

        /// <summary>
        /// Instance-segmentation button.
        /// </summary>
        private void BtnSegment_Click(object sender, EventArgs e)
        {
            if (string.IsNullOrEmpty(_imagePath))
            {
                MessageBox.Show("请先选择图片", "提示", MessageBoxButtons.OK, MessageBoxIcon.Information);
                return;
            }

            if (_segmentor == null)
            {
                MessageBox.Show("分割模型未加载", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                return;
            }

            RunInference(img =>
            {
                _stopwatch.Restart();
                var results = _segmentor.Predict(img, 0.5f, 0.3f);
                _stopwatch.Stop();

                try
                {
                    Mat rendered = DrawSegmentationResults(img, results);
                    string summary = $"分割耗时: {_stopwatch.ElapsedMilliseconds} ms{Environment.NewLine}检测到 {results.Count} 个实例";
                    return (rendered, summary);
                }
                finally
                {
                    // Masks are native Mats; release them once drawn instead of waiting for finalizers.
                    // NOTE(review): assumes Predict allocates fresh Mask Mats on every call — TODO confirm.
                    foreach (var res in results)
                    {
                        res.Mask?.Dispose();
                    }
                }
            });
        }

        /// <summary>
        /// Shared driver for both inference buttons.
        /// </summary>
        /// <remarks>
        /// It reads the current image and runs <paramref name="infer"/> on it. The annotated Mat
        /// goes to the right pane and the summary text goes to the text box. Any exception is
        /// reported in a message box.
        /// </remarks>
        /// <param name="infer">
        /// Runs a model on the decoded BGR image and returns a newly allocated annotated Mat
        /// (disposed here) together with the text to display.
        /// </param>
        private void RunInference(Func<Mat, (Mat Rendered, string Summary)> infer)
        {
            btnDetect.Enabled = false;
            btnSegment.Enabled = false;
            ReplaceImage(pictureBox2, null);
            textBox1.Clear();

            try
            {
                using (Mat img = ReadImage(_imagePath))
                {
                    if (img.Empty())
                    {
                        MessageBox.Show("无法读取图片", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
                        return;
                    }

                    var (rendered, summary) = infer(img);
                    using (rendered)
                    {
                        ReplaceImage(pictureBox2, MatToBitmap(rendered));
                    }

                    textBox1.Text = summary;
                }
            }
            catch (Exception ex)
            {
                MessageBox.Show($"推理出错: {ex.Message}", "错误", MessageBoxButtons.OK, MessageBoxIcon.Error);
            }
            finally
            {
                btnDetect.Enabled = true;
                btnSegment.Enabled = true;
            }
        }

        /// <summary>
        /// Decodes an image file as 3-channel BGR.
        /// </summary>
        /// <remarks>
        /// The file is read into a byte buffer first because Cv2.ImRead cannot open
        /// non-ASCII (e.g. Chinese) paths on Windows.
        /// </remarks>
        /// <returns>The decoded image, or an empty Mat when the file is empty or not a supported image.</returns>
        private static Mat ReadImage(string path)
        {
            byte[] bytes = File.ReadAllBytes(path);
            if (bytes.Length == 0)
            {
                return new Mat();
            }
            return Cv2.ImDecode(bytes, ImreadModes.Color);
        }

        /// <summary>
        /// Loads a file into a Bitmap that does not keep the file locked.
        /// </summary>
        private static Bitmap LoadBitmapWithoutLock(string path)
        {
            // A file-backed Bitmap locks its file until disposed; copying the pixels releases it.
            using (var fileBacked = new Bitmap(path))
            {
                return new Bitmap(fileBacked);
            }
        }

        /// <summary>
        /// Converts a Mat to a Bitmap that does not depend on any stream.
        /// </summary>
        private static Bitmap MatToBitmap(Mat mat)
        {
            // GDI+ requires the source stream to stay open for the lifetime of a stream-backed
            // Bitmap, so copy the pixels into an independent Bitmap before disposing the stream.
            using (var ms = mat.ToMemoryStream())
            using (var streamBacked = new Bitmap(ms))
            {
                return new Bitmap(streamBacked);
            }
        }

        /// <summary>
        /// Shows <paramref name="image"/> in <paramref name="box"/> and disposes the image it replaces.
        /// </summary>
        private static void ReplaceImage(PictureBox box, System.Drawing.Image image)
        {
            System.Drawing.Image previous = box.Image;
            box.Image = image;
            if (previous != null && !ReferenceEquals(previous, image))
            {
                previous.Dispose();
            }
        }

        /// <summary>
        /// Draws detection results (green boxes + labels) on a copy of <paramref name="image"/>.
        /// </summary>
        /// <returns>A new Mat owned by the caller.</returns>
        private Mat DrawDetectionResults(Mat image, List<RfDetrModelTensorRt.DetectionResult> results)
        {
            Mat output = image.Clone();
            var green = new Scalar(0, 255, 0);
            foreach (var res in results)
            {
                Cv2.Rectangle(output, res.BoundingBox, green, 2);
                DrawLabel(output, res.BoundingBox, $"{res.Label}: {res.Confidence:F2}", green);
            }
            return output;
        }

        /// <summary>
        /// Draws segmentation results on a copy of <paramref name="image"/>.
        /// </summary>
        /// <remarks>
        /// Each instance gets a random color, a semi-transparent mask overlay, a box and a label.
        /// </remarks>
        /// <returns>A new Mat owned by the caller.</returns>
        private Mat DrawSegmentationResults(Mat image, List<RFDETRSEGTensorRt.SegmentationResult> results)
        {
            Mat output = image.Clone();
            Random rand = new Random();

            foreach (var res in results)
            {
                // Random per-instance color; the upper bound of Next is exclusive, so 256 allows 255.
                Scalar color = new Scalar(rand.Next(0, 256), rand.Next(0, 256), rand.Next(0, 256));

                if (res.Mask != null && !res.Mask.Empty())
                {
                    BlendMask(output, res.Mask, color, 0.4);
                }

                Cv2.Rectangle(output, res.BoundingBox, color, 2);
                DrawLabel(output, res.BoundingBox, $"{res.Label}: {res.Confidence:F2}", color);
            }
            return output;
        }

        /// <summary>
        /// Tints only the pixels selected by <paramref name="mask"/>.
        /// </summary>
        /// <remarks>
        /// Each selected pixel becomes (1 - alpha) * canvas + alpha * color; all other pixels are
        /// left untouched. Blending the whole image instead would darken the background once per
        /// instance.
        /// NOTE(review): the mask is expected to be CV_8UC1 (non-zero = inside). It is resized with
        /// nearest-neighbour when its size differs from the canvas, since the segmentor's mask
        /// resolution is not visible here.
        /// </remarks>
        private static void BlendMask(Mat canvas, Mat mask, Scalar color, double alpha)
        {
            Mat resizedMask = null;
            try
            {
                Mat region = mask;
                if (mask.Width != canvas.Width || mask.Height != canvas.Height)
                {
                    resizedMask = new Mat();
                    Cv2.Resize(mask, resizedMask, canvas.Size(), 0, 0, InterpolationFlags.Nearest);
                    region = resizedMask;
                }

                using (Mat tint = new Mat(canvas.Size(), canvas.Type(), color))
                using (Mat blended = new Mat())
                {
                    Cv2.AddWeighted(canvas, 1.0 - alpha, tint, alpha, 0, blended);
                    blended.CopyTo(canvas, region);
                }
            }
            finally
            {
                resizedMask?.Dispose();
            }
        }

        /// <summary>
        /// Draws <paramref name="label"/> in black on a filled background of <paramref name="background"/>.
        /// </summary>
        /// <remarks>
        /// The label sits just above <paramref name="box"/>. When there is not enough room above
        /// for the full text height, it moves just inside the box's top edge.
        /// </remarks>
        private static void DrawLabel(Mat canvas, Rect box, string label, Scalar background)
        {
            Size textSize = Cv2.GetTextSize(label, HersheyFonts.HersheySimplex, 0.5, 1, out int baseline);

            // PutText's origin is the text baseline; the glyphs extend textSize.Height above it.
            int textY = box.Y - 5;
            if (textY - textSize.Height < 0)
            {
                textY = box.Y + textSize.Height + 5;
            }

            var origin = new Point(box.X, textY);
            Cv2.Rectangle(canvas, new Rect(origin.X, origin.Y - textSize.Height, textSize.Width, textSize.Height + baseline),
                background, -1);
            Cv2.PutText(canvas, label, origin, HersheyFonts.HersheySimplex, 0.5, new Scalar(0, 0, 0), 1);
        }

        /// <summary>
        /// Releases the models and the displayed bitmaps once the form has actually closed.
        /// </summary>
        /// <remarks>
        /// This runs in OnFormClosed rather than OnFormClosing because a FormClosing handler can
        /// still cancel the close. That would leave the window running with disposed models.
        /// </remarks>
        protected override void OnFormClosed(FormClosedEventArgs e)
        {
            base.OnFormClosed(e);

            _detector?.Dispose();
            _detector = null;
            _segmentor?.Dispose();
            _segmentor = null;

            ReplaceImage(pictureBox1, null);
            ReplaceImage(pictureBox2, null);
        }
    }
}

Form1.Designer.cs

cs 复制代码
namespace DML
{
    /// <summary>
    /// Control declarations and layout for <see cref="Form1"/>.
    /// The form has two 480x480 image panes, three action buttons, a read-only
    /// multiline text box and a status strip with two labels.
    /// </summary>
    partial class Form1
    {
        // Container for designer-managed components. InitializeComponent below never
        // assigns it, so it stays null and the check in Dispose skips it.
        private System.ComponentModel.IContainer components = null;

        /// <summary>
        /// Disposes the component container (when present), then the form itself.
        /// The TensorRT model objects held by Form1 are not released here.
        /// </summary>
        /// <param name="disposing">true when called from Dispose(); false when called from a finalizer.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing && (components != null))
            {
                components.Dispose();
            }
            base.Dispose(disposing);
        }

        /// <summary>
        /// Creates every child control, sets its position, size and properties, and hooks the
        /// three button Click events to their handlers. It then adds the controls to the form.
        /// </summary>
        /// <remarks>
        /// The property assignments are bracketed by BeginInit/EndInit on the picture boxes and by
        /// SuspendLayout/ResumeLayout on the status strip and the form. Keep that order when
        /// editing by hand.
        /// </remarks>
        private void InitializeComponent()
        {
            this.pictureBox1 = new System.Windows.Forms.PictureBox();
            this.pictureBox2 = new System.Windows.Forms.PictureBox();
            this.btnOpenImage = new System.Windows.Forms.Button();
            this.btnDetect = new System.Windows.Forms.Button();
            this.btnSegment = new System.Windows.Forms.Button();
            this.textBox1 = new System.Windows.Forms.TextBox();
            this.statusStrip1 = new System.Windows.Forms.StatusStrip();
            this.toolStripStatusLabel1 = new System.Windows.Forms.ToolStripStatusLabel();
            this.toolStripStatusLabel2 = new System.Windows.Forms.ToolStripStatusLabel();
            ((System.ComponentModel.ISupportInitialize)(this.pictureBox1)).BeginInit();
            ((System.ComponentModel.ISupportInitialize)(this.pictureBox2)).BeginInit();
            this.statusStrip1.SuspendLayout();
            this.SuspendLayout();

            // pictureBox1 — left pane; Form1 shows the opened image here (scaled with Zoom)
            this.pictureBox1.BorderStyle = System.Windows.Forms.BorderStyle.FixedSingle;
            this.pictureBox1.Location = new System.Drawing.Point(12, 12);
            this.pictureBox1.Name = "pictureBox1";
            this.pictureBox1.Size = new System.Drawing.Size(480, 480);
            this.pictureBox1.SizeMode = System.Windows.Forms.PictureBoxSizeMode.Zoom;
            this.pictureBox1.TabIndex = 0;
            this.pictureBox1.TabStop = false;

            // pictureBox2 — right pane; Form1 shows the annotated inference result here
            this.pictureBox2.BorderStyle = System.Windows.Forms.BorderStyle.FixedSingle;
            this.pictureBox2.Location = new System.Drawing.Point(508, 12);
            this.pictureBox2.Name = "pictureBox2";
            this.pictureBox2.Size = new System.Drawing.Size(480, 480);
            this.pictureBox2.SizeMode = System.Windows.Forms.PictureBoxSizeMode.Zoom;
            this.pictureBox2.TabIndex = 1;
            this.pictureBox2.TabStop = false;

            // btnOpenImage — "open image", handled by BtnOpenImage_Click
            this.btnOpenImage.Location = new System.Drawing.Point(12, 510);
            this.btnOpenImage.Name = "btnOpenImage";
            this.btnOpenImage.Size = new System.Drawing.Size(100, 40);
            this.btnOpenImage.TabIndex = 2;
            this.btnOpenImage.Text = "打开图片";
            this.btnOpenImage.UseVisualStyleBackColor = true;
            this.btnOpenImage.Click += new System.EventHandler(this.BtnOpenImage_Click);

            // btnDetect — "object detection", handled by BtnDetect_Click
            this.btnDetect.Location = new System.Drawing.Point(130, 510);
            this.btnDetect.Name = "btnDetect";
            this.btnDetect.Size = new System.Drawing.Size(100, 40);
            this.btnDetect.TabIndex = 3;
            this.btnDetect.Text = "目标检测";
            this.btnDetect.UseVisualStyleBackColor = true;
            this.btnDetect.Click += new System.EventHandler(this.BtnDetect_Click);

            // btnSegment — "instance segmentation", handled by BtnSegment_Click
            this.btnSegment.Location = new System.Drawing.Point(248, 510);
            this.btnSegment.Name = "btnSegment";
            this.btnSegment.Size = new System.Drawing.Size(100, 40);
            this.btnSegment.TabIndex = 4;
            this.btnSegment.Text = "实例分割";
            this.btnSegment.UseVisualStyleBackColor = true;
            this.btnSegment.Click += new System.EventHandler(this.BtnSegment_Click);

            // textBox1 — read-only multiline box for timing / result-count text
            this.textBox1.Location = new System.Drawing.Point(12, 560);
            this.textBox1.Multiline = true;
            this.textBox1.Name = "textBox1";
            this.textBox1.ReadOnly = true;
            this.textBox1.Size = new System.Drawing.Size(976, 80);
            this.textBox1.TabIndex = 5;

            // statusStrip1 — hosts the detector and segmentor status labels
            this.statusStrip1.Items.AddRange(new System.Windows.Forms.ToolStripItem[] {
            this.toolStripStatusLabel1,
            this.toolStripStatusLabel2});
            this.statusStrip1.Location = new System.Drawing.Point(0, 649);
            this.statusStrip1.Name = "statusStrip1";
            this.statusStrip1.Size = new System.Drawing.Size(1004, 22);
            this.statusStrip1.TabIndex = 6;
            this.statusStrip1.Text = "statusStrip1";

            // toolStripStatusLabel1 — detection-model status (initially "not loaded")
            this.toolStripStatusLabel1.Name = "toolStripStatusLabel1";
            this.toolStripStatusLabel1.Size = new System.Drawing.Size(120, 17);
            this.toolStripStatusLabel1.Text = "检测模型: 未加载";

            // toolStripStatusLabel2 — segmentation-model status (initially "not loaded")
            this.toolStripStatusLabel2.Name = "toolStripStatusLabel2";
            this.toolStripStatusLabel2.Size = new System.Drawing.Size(120, 17);
            this.toolStripStatusLabel2.Text = "分割模型: 未加载";

            // Form1 — 1004x671 client area, font-based auto-scaling
            this.AutoScaleDimensions = new System.Drawing.SizeF(7F, 17F);
            this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font;
            this.ClientSize = new System.Drawing.Size(1004, 671);
            this.Controls.Add(this.statusStrip1);
            this.Controls.Add(this.textBox1);
            this.Controls.Add(this.btnSegment);
            this.Controls.Add(this.btnDetect);
            this.Controls.Add(this.btnOpenImage);
            this.Controls.Add(this.pictureBox2);
            this.Controls.Add(this.pictureBox1);
            this.Name = "Form1";
            this.Text = "RF-DETR TensorRT 演示";
            ((System.ComponentModel.ISupportInitialize)(this.pictureBox1)).EndInit();
            ((System.ComponentModel.ISupportInitialize)(this.pictureBox2)).EndInit();
            this.statusStrip1.ResumeLayout(false);
            this.statusStrip1.PerformLayout();
            this.ResumeLayout(false);
            this.PerformLayout();
        }

        // Child controls created in InitializeComponent.
        private System.Windows.Forms.PictureBox pictureBox1;
        private System.Windows.Forms.PictureBox pictureBox2;
        private System.Windows.Forms.Button btnOpenImage;
        private System.Windows.Forms.Button btnDetect;
        private System.Windows.Forms.Button btnSegment;
        private System.Windows.Forms.TextBox textBox1;
        private System.Windows.Forms.StatusStrip statusStrip1;
        private System.Windows.Forms.ToolStripStatusLabel toolStripStatusLabel1;
        private System.Windows.Forms.ToolStripStatusLabel toolStripStatusLabel2;
    }
}

RfDetrModelTensorRt.cs

cs 复制代码
using OpenCvSharp;
using JYPPX.TensorRtSharp.Cuda;
using JYPPX.TensorRtSharp.Nvinfer;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Diagnostics;
using Size = OpenCvSharp.Size;

namespace RFDETRDetector.TensorRT
{
    /// <summary>
    /// RF-DETR object detector running on a serialized TensorRT engine.
    /// </summary>
    /// <remarks>
    /// One set of device buffers and one CUDA stream are shared by every
    /// <see cref="Predict"/> call, so an instance must not be used from several
    /// threads at the same time.
    /// </remarks>
    public class RfDetrModelTensorRt : IDisposable
    {
        private Runtime _runtime;
        private CudaEngine _engine;
        private JYPPX.TensorRtSharp.Nvinfer.ExecutionContext _context;
        private CudaStream _stream;
        private Cuda1DMemory<float> _inputMemory;
        private Cuda1DMemory<float> _outputDetsMemory;
        private Cuda1DMemory<float> _outputLabelsMemory;
        private bool _disposed;

        private readonly int _inputWidth;
        private readonly int _inputHeight;
        private readonly int _maxDetections;
        private readonly int _numClasses;
        private readonly List<string> _classes;

        // ImageNet mean/std applied in RGB order after scaling to [0, 1] (same as the ONNX version).
        private readonly Scalar _mean = new Scalar(0.485f, 0.456f, 0.406f);
        private readonly Scalar _std = new Scalar(0.229f, 0.224f, 0.225f);

        /// <summary>One detection, in original-image pixel coordinates.</summary>
        public class DetectionResult
        {
            public Rect BoundingBox { get; set; }
            public string Label { get; set; }
            public float Confidence { get; set; }
            public int ClassId { get; set; }
        }

        /// <summary>
        /// Deserializes a TensorRT engine, allocates its I/O buffers and loads the class names.
        /// </summary>
        /// <param name="enginePath">Path to the .engine file.</param>
        /// <param name="classesPath">Text file with one class name per line.</param>
        /// <param name="inputName">Name of the image input tensor; dims [1] (channels), [2] (H), [3] (W) are read.</param>
        /// <param name="detsName">Name of the box output tensor; dim [1] is the number of queries.</param>
        /// <param name="labelsName">Name of the class-logit output tensor; dim [2] is the number of classes.</param>
        /// <exception cref="FileNotFoundException">Either file does not exist.</exception>
        /// <exception cref="InvalidOperationException">
        /// The engine reports a non-positive dimension (e.g. an unset dynamic shape) or a non-3-channel input.
        /// </exception>
        public RfDetrModelTensorRt(string enginePath, string classesPath,
            string inputName = "input", string detsName = "dets", string labelsName = "labels")
        {
            if (!File.Exists(enginePath))
                throw new FileNotFoundException($"Engine 文件不存在: {enginePath}");
            if (!File.Exists(classesPath))
                throw new FileNotFoundException($"类别文件不存在: {classesPath}");

            // Managed-only work first, so a failure here cannot leak any GPU resource.
            _classes = File.ReadAllLines(classesPath).Select(l => l.Trim()).ToList();
            byte[] engineData = File.ReadAllBytes(enginePath);

            try
            {
                // Runtime -> engine -> execution context, plus the stream used for every call.
                _runtime = new Runtime();
                _engine = _runtime.deserializeCudaEngineByBlob(engineData, (ulong)engineData.Length);
                _context = _engine.createExecutionContext(TrtExecutionContextAllocationStrategy.kSTATIC);
                _stream = new CudaStream();

                var inputShape = _context.getTensorShape(inputName);
                int inputChannels = (int)inputShape.d[1];
                _inputHeight = (int)inputShape.d[2];
                _inputWidth = (int)inputShape.d[3];

                var detsShape = _context.getTensorShape(detsName);
                _maxDetections = (int)detsShape.d[1];
                var labelsShape = _context.getTensorShape(labelsName);
                _numClasses = (int)labelsShape.d[2];

                // TensorRT reports unset dynamic dimensions as -1. Casting such a value to ulong
                // below would request an enormous allocation, so fail early with a clear message.
                // Preprocess always produces exactly 3 planes.
                if (inputChannels != 3 || _inputHeight <= 0 || _inputWidth <= 0 || _maxDetections <= 0 || _numClasses <= 0)
                {
                    throw new InvalidOperationException(
                        $"引擎张量形状无效(不支持动态 shape) | 输入通道: {inputChannels} | 输入: {_inputWidth}x{_inputHeight} | 检测数: {_maxDetections} | 类别数: {_numClasses}");
                }

                // Element counts, computed in ulong so large inputs cannot overflow int.
                ulong inputSize = 3UL * (ulong)_inputHeight * (ulong)_inputWidth;   // [1, 3, H, W]
                ulong detsSize = (ulong)_maxDetections * 4UL;                       // [1, maxDet, 4]
                ulong labelsSize = (ulong)_maxDetections * (ulong)_numClasses;      // [1, maxDet, numClasses]

                _inputMemory = new Cuda1DMemory<float>(inputSize);
                _outputDetsMemory = new Cuda1DMemory<float>(detsSize);
                _outputLabelsMemory = new Cuda1DMemory<float>(labelsSize);

                // Bind the device buffers once; Predict only copies data in and out of them.
                _context.setInputTensorAddress(inputName, _inputMemory.get());
                _context.setOutputTensorAddress(detsName, _outputDetsMemory.get());
                _context.setOutputTensorAddress(labelsName, _outputLabelsMemory.get());
            }
            catch
            {
                // Release whatever native objects were created before the failure.
                Dispose();
                throw;
            }

            Console.WriteLine($"RF-DETR TensorRT 模型加载成功 | 输入: {_inputWidth}x{_inputHeight} | 检测数: {_maxDetections} | 类别数: {_numClasses}");
        }

        /// <summary>
        /// Runs detection on one image.
        /// </summary>
        /// <param name="image">BGR (3-channel), BGRA (4-channel) or single-channel 8-bit image.</param>
        /// <param name="confidenceThreshold">Minimum sigmoid score for a detection to be kept.</param>
        /// <returns>Detections in original-image pixels, highest confidence first.</returns>
        /// <exception cref="ObjectDisposedException">The model has been disposed.</exception>
        /// <exception cref="ArgumentException">The image is null, empty or has an unsupported channel count.</exception>
        public List<DetectionResult> Predict(Mat image, float confidenceThreshold = 0.5f)
        {
            if (_disposed)
                throw new ObjectDisposedException(nameof(RfDetrModelTensorRt));
            if (image == null || image.Empty())
                throw new ArgumentException("输入图像无效");

            int origW = image.Width;
            int origH = image.Height;

            // 1. Preprocess on the CPU.
            float[] inputData = Preprocess(image);

            // 2. Upload and run on the shared stream.
            _inputMemory.copyFromHostAsync(inputData, _stream);
            _context.executeV3(_stream);
            _stream.Synchronize();

            // 3. Download both outputs.
            float[] detsData = new float[_maxDetections * 4];
            float[] labelsData = new float[_maxDetections * _numClasses];
            _outputDetsMemory.copyToHostAsync(detsData, _stream);
            _outputLabelsMemory.copyToHostAsync(labelsData, _stream);
            _stream.Synchronize();

            // 4. Decode boxes and classes.
            return ParseOutputs(detsData, labelsData, confidenceThreshold, origW, origH);
        }

        /// <summary>
        /// Builds the planar CHW float input for the engine.
        /// </summary>
        /// <remarks>
        /// Steps: stretch-resize to the engine input size (no letterboxing), convert to RGB,
        /// scale to [0, 1], normalize with mean/std, then lay the channels out as planar CHW floats.
        /// </remarks>
        private float[] Preprocess(Mat image)
        {
            int channelSize = _inputHeight * _inputWidth;
            float[] result = new float[3 * channelSize];

            using (Mat resized = new Mat())
            using (Mat rgb = new Mat())
            using (Mat normalized = new Mat())
            {
                Cv2.Resize(image, resized, new Size(_inputWidth, _inputHeight));
                Cv2.CvtColor(resized, rgb, GetRgbConversion(resized.Channels()));

                rgb.ConvertTo(normalized, MatType.CV_32FC3, 1.0 / 255.0);
                Cv2.Subtract(normalized, _mean, normalized);
                Cv2.Divide(normalized, _std, normalized);

                // HWC -> CHW: Split yields one freshly allocated, continuous H*W float plane per channel.
                Mat[] planes = Cv2.Split(normalized);
                try
                {
                    for (int c = 0; c < 3; c++)
                    {
                        Marshal.Copy(planes[c].Data, result, c * channelSize, channelSize);
                    }
                }
                finally
                {
                    foreach (Mat plane in planes)
                    {
                        plane.Dispose();
                    }
                }
            }

            return result;
        }

        /// <summary>
        /// Picks the color conversion that yields a 3-channel RGB image from the given channel count.
        /// </summary>
        /// <exception cref="ArgumentException">The channel count is not 1, 3 or 4.</exception>
        private static ColorConversionCodes GetRgbConversion(int channels)
        {
            switch (channels)
            {
                case 1:
                    return ColorConversionCodes.GRAY2RGB;
                case 3:
                    return ColorConversionCodes.BGR2RGB;
                case 4:
                    return ColorConversionCodes.BGRA2RGB;
                default:
                    throw new ArgumentException($"不支持的图像通道数: {channels}");
            }
        }

        /// <summary>
        /// Decodes raw engine outputs into detections.
        /// </summary>
        /// <remarks>
        /// Each query's box is read as normalized (cx, cy, w, h) and scaled to
        /// <paramref name="origW"/> x <paramref name="origH"/>. Its best class is scored with a
        /// sigmoid and kept when the score reaches the threshold.
        /// </remarks>
        /// <returns>Detections sorted by descending confidence (stable for ties).</returns>
        private List<DetectionResult> ParseOutputs(float[] detsData, float[] labelsData,
            float confidenceThreshold, int origW, int origH)
        {
            var results = new List<DetectionResult>();

            for (int i = 0; i < _maxDetections; i++)
            {
                int boxOffset = i * 4;
                float cx = detsData[boxOffset];
                float cy = detsData[boxOffset + 1];
                float w = detsData[boxOffset + 2];
                float h = detsData[boxOffset + 3];
                if (w <= 0 || h <= 0) continue;

                // Sigmoid is monotonic, so the arg-max over raw logits is the arg-max over scores;
                // apply it once, to the winner only.
                int logitOffset = i * _numClasses;
                int bestClass = -1;
                float bestLogit = float.NegativeInfinity;
                for (int c = 0; c < _numClasses; c++)
                {
                    float logit = labelsData[logitOffset + c];
                    if (logit > bestLogit)
                    {
                        bestLogit = logit;
                        bestClass = c;
                    }
                }
                if (bestClass < 0) continue;

                float confidence = Sigmoid(bestLogit);
                if (confidence < confidenceThreshold) continue;

                int px1 = Math.Clamp((int)((cx - w / 2) * origW), 0, origW);
                int py1 = Math.Clamp((int)((cy - h / 2) * origH), 0, origH);
                int px2 = Math.Clamp((int)((cx + w / 2) * origW), 0, origW);
                int py2 = Math.Clamp((int)((cy + h / 2) * origH), 0, origH);
                if (px2 <= px1 || py2 <= py1) continue;

                // Logit index k (k >= 1) maps to classes.txt line k-1; index 0 is folded onto line 0.
                // NOTE(review): presumably the engine reserves logit 0 (background / 1-based ids) — TODO confirm.
                int mappedId = Math.Max(bestClass - 1, 0);
                string label = mappedId < _classes.Count ? _classes[mappedId] : $"class_{bestClass}";

                results.Add(new DetectionResult
                {
                    BoundingBox = new Rect(px1, py1, px2 - px1, py2 - py1),
                    Label = label,
                    Confidence = confidence,
                    ClassId = mappedId
                });
            }

            return results.OrderByDescending(r => r.Confidence).ToList();
        }

        private static float Sigmoid(float x) => 1.0f / (1.0f + MathF.Exp(-x));

        /// <summary>
        /// Releases the GPU buffers, the stream and the TensorRT objects in reverse creation order.
        /// Safe to call more than once; <see cref="Predict"/> throws afterwards.
        /// </summary>
        public void Dispose()
        {
            if (_disposed) return;
            _disposed = true;

            _inputMemory?.Dispose();
            _inputMemory = null;
            _outputDetsMemory?.Dispose();
            _outputDetsMemory = null;
            _outputLabelsMemory?.Dispose();
            _outputLabelsMemory = null;
            _stream?.Dispose();
            _stream = null;
            _context?.Dispose();
            _context = null;
            _engine?.Dispose();
            _engine = null;
            _runtime?.Dispose();
            _runtime = null;
        }
    }
}

RFDETRSEGTensorRt.cs

cs 复制代码
using OpenCvSharp;
using JYPPX.TensorRtSharp.Cuda;
using JYPPX.TensorRtSharp.Nvinfer;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using Size = OpenCvSharp.Size;

namespace RFDETRSegmentor.TensorRT
{
    /// <summary>
    /// RF-DETR instance-segmentation inference on a serialized TensorRT engine.
    /// Owns the TensorRT runtime, engine and execution context, a CUDA stream, and the device
    /// buffers bound to the "input", "dets", "labels" and "masks" tensors.
    /// Call <see cref="Dispose"/> to release them.
    /// </summary>
    public class RFDETRSEGTensorRt : IDisposable
    {
        private Runtime _runtime;
        private CudaEngine _engine;
        private JYPPX.TensorRtSharp.Nvinfer.ExecutionContext _context;
        private CudaStream _stream;
        private Cuda1DMemory<float> _inputMemory;
        private Cuda1DMemory<float> _outputDetsMemory;
        private Cuda1DMemory<float> _outputLabelsMemory;
        private Cuda1DMemory<float> _outputMasksMemory;

        private readonly int _inputWidth;
        private readonly int _inputHeight;
        private readonly int _numQueries;
        // Number of class logits per query (may include one extra leading slot, see _hasBackground).
        private readonly int _numClasses;
        // Per-query mask resolution, read from dims 2 and 3 of the "masks" tensor.
        private readonly int _maskHeight;
        private readonly int _maskWidth;
        // Class names from the classes file (blank lines dropped, entries trimmed).
        private readonly List<string> _userClasses;
        // True when the model has exactly one more logit than there are class names:
        // logit 0 is then an extra slot and logit k corresponds to class name k-1.
        private readonly bool _hasBackground;
        private bool _disposed;

        // ImageNet normalization statistics, applied in RGB order after scaling pixels to [0, 1].
        private readonly Scalar _mean = new Scalar(0.485f, 0.456f, 0.406f);
        private readonly Scalar _std = new Scalar(0.229f, 0.224f, 0.225f);

        /// <summary>One detected instance.</summary>
        public class SegmentationResult
        {
            /// <summary>Bounding box in original-image pixel coordinates.</summary>
            public Rect BoundingBox { get; set; }
            /// <summary>Zero-based index into the class-name list loaded from the classes file.</summary>
            public int ClassId { get; set; }
            /// <summary>Class name for <see cref="ClassId"/>.</summary>
            public string Label { get; set; }
            /// <summary>Sigmoid score of the winning class.</summary>
            public float Confidence { get; set; }
            /// <summary>
            /// Full-image CV_8UC1 mask (values 0 or 255) at the original image resolution.
            /// The caller owns this Mat and should dispose it.
            /// </summary>
            public Mat Mask { get; set; }
        }

        /// <summary>
        /// Loads a TensorRT segmentation engine, binds its tensors and allocates device buffers.
        /// </summary>
        /// <param name="enginePath">Path to the serialized .engine file.</param>
        /// <param name="classesPath">Class-name file, one name per line (background not included).</param>
        /// <exception cref="FileNotFoundException">The engine or classes file does not exist.</exception>
        public RFDETRSEGTensorRt(string enginePath, string classesPath)
        {
            if (!File.Exists(enginePath))
                throw new FileNotFoundException($"Engine 文件不存在: {enginePath}");
            if (!File.Exists(classesPath))
                throw new FileNotFoundException($"类别文件不存在: {classesPath}");

            try
            {
                byte[] engineData = File.ReadAllBytes(enginePath);
                _runtime = new Runtime();
                _engine = _runtime.deserializeCudaEngineByBlob(engineData, (ulong)engineData.Length);
                _context = _engine.createExecutionContext(TrtExecutionContextAllocationStrategy.kSTATIC);
                _stream = new CudaStream();

                // Tensor names as exported in the ONNX model; change them if your export differs.
                const string inputName = "input";
                const string detsName = "dets";
                const string labelsName = "labels";
                const string masksName = "masks";

                // Input is read as N x C x H x W.
                var inputShape = _context.getTensorShape(inputName);
                _inputHeight = (int)inputShape.d[2];
                _inputWidth = (int)inputShape.d[3];

                _numQueries = (int)_context.getTensorShape(detsName).d[1];
                _numClasses = (int)_context.getTensorShape(labelsName).d[2];
                var masksShape = _context.getTensorShape(masksName);
                _maskHeight = (int)masksShape.d[2];
                _maskWidth = (int)masksShape.d[3];

                // Device buffers, sized in elements for batch size 1.
                ulong inputSize = 3UL * (ulong)_inputHeight * (ulong)_inputWidth;
                ulong detsSize = (ulong)_numQueries * 4UL;
                ulong labelsSize = (ulong)_numQueries * (ulong)_numClasses;
                ulong masksSize = (ulong)_numQueries * (ulong)_maskHeight * (ulong)_maskWidth;

                _inputMemory = new Cuda1DMemory<float>(inputSize);
                _outputDetsMemory = new Cuda1DMemory<float>(detsSize);
                _outputLabelsMemory = new Cuda1DMemory<float>(labelsSize);
                _outputMasksMemory = new Cuda1DMemory<float>(masksSize);

                _context.setInputTensorAddress(inputName, _inputMemory.get());
                _context.setOutputTensorAddress(detsName, _outputDetsMemory.get());
                _context.setOutputTensorAddress(labelsName, _outputLabelsMemory.get());
                _context.setOutputTensorAddress(masksName, _outputMasksMemory.get());

                _userClasses = File.ReadAllLines(classesPath)
                    .Where(l => !string.IsNullOrWhiteSpace(l))
                    .Select(l => l.Trim())
                    .ToList();
                _hasBackground = (_numClasses == _userClasses.Count + 1);
                Console.WriteLine($"分割模型加载成功 | 输入: {_inputWidth}x{_inputHeight} | Queries: {_numQueries} | 类别: {_numClasses} | 掩码: {_maskWidth}x{_maskHeight}");
            }
            catch
            {
                // Release whatever was created before the failure, then propagate the original error.
                Dispose();
                throw;
            }
        }

        /// <summary>Runs segmentation on a BGR, BGRA or single-channel image.</summary>
        /// <param name="image">Input image; must be non-null and non-empty.</param>
        /// <param name="confidenceThreshold">Minimum sigmoid class score for a query to be kept.</param>
        /// <param name="maskThreshold">Mask probability above which a pixel is set to 255.</param>
        /// <returns>Detections sorted by descending confidence.</returns>
        /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
        /// <exception cref="ArgumentException"><paramref name="image"/> is null or empty.</exception>
        public List<SegmentationResult> Predict(Mat image, float confidenceThreshold = 0.5f, float maskThreshold = 0.3f)
        {
            if (_disposed)
                throw new ObjectDisposedException(nameof(RFDETRSEGTensorRt));
            if (image == null || image.Empty())
                throw new ArgumentException("输入图像无效");

            int origW = image.Width, origH = image.Height;

            float[] inputData = Preprocess(image);
            _inputMemory.copyFromHostAsync(inputData, _stream);
            _context.executeV3(_stream);
            _stream.Synchronize();

            float[] dets = new float[_numQueries * 4];
            float[] labels = new float[_numQueries * _numClasses];
            float[] masks = new float[_numQueries * _maskHeight * _maskWidth];
            _outputDetsMemory.copyToHostAsync(dets, _stream);
            _outputLabelsMemory.copyToHostAsync(labels, _stream);
            _outputMasksMemory.copyToHostAsync(masks, _stream);
            _stream.Synchronize();

            return Postprocess(origW, origH, dets, labels, masks, confidenceThreshold, maskThreshold);
        }

        /// <summary>
        /// Resizes to the network input size (plain stretch, no letterbox) and converts to RGB.
        /// Scales to [0, 1], applies mean/std normalization and returns a planar CHW float buffer.
        /// </summary>
        private float[] Preprocess(Mat image)
        {
            int channelSize = _inputHeight * _inputWidth;
            float[] result = new float[3 * channelSize];

            using (Mat resized = new Mat())
            using (Mat rgb = new Mat())
            using (Mat normalized = new Mat())
            {
                Cv2.Resize(image, resized, new Size(_inputWidth, _inputHeight));

                // The network expects RGB; accept grayscale and BGRA inputs as well as BGR.
                ColorConversionCodes toRgb;
                switch (resized.Channels())
                {
                    case 1: toRgb = ColorConversionCodes.GRAY2RGB; break;
                    case 4: toRgb = ColorConversionCodes.BGRA2RGB; break;
                    default: toRgb = ColorConversionCodes.BGR2RGB; break;
                }
                Cv2.CvtColor(resized, rgb, toRgb);

                rgb.ConvertTo(normalized, MatType.CV_32FC3, 1.0 / 255.0);
                Cv2.Subtract(normalized, _mean, normalized);
                Cv2.Divide(normalized, _std, normalized);

                // Interleaved HWC -> planar CHW: split into channel planes and copy each one contiguously.
                Mat[] planes = Cv2.Split(normalized);
                try
                {
                    for (int c = 0; c < 3; c++)
                        Marshal.Copy(planes[c].Data, result, c * channelSize, channelSize);
                }
                finally
                {
                    foreach (Mat plane in planes)
                        plane.Dispose();
                }
            }
            return result;
        }

        /// <summary>
        /// Converts raw network outputs into detections in original-image coordinates.
        /// Each query keeps its best-scoring class if the score reaches <paramref name="confThresh"/>.
        /// The query is skipped if that logit maps to no class name, or if its box is degenerate.
        /// </summary>
        private List<SegmentationResult> Postprocess(int imgW, int imgH, float[] dets, float[] labels, float[] masks,
            float confThresh, float maskThresh)
        {
            // Preprocess stretches the image, so mapping back is an independent scale per axis.
            float scaleX = (float)imgW / _inputWidth;
            float scaleY = (float)imgH / _inputHeight;

            // With the extra leading logit, logit k maps to class name k-1; otherwise to name k.
            int classOffset = _hasBackground ? 1 : 0;
            int maskArea = _maskHeight * _maskWidth;
            float[] maskProbs = new float[maskArea]; // reused for every kept query
            var results = new List<SegmentationResult>();

            for (int q = 0; q < _numQueries; q++)
            {
                // Highest sigmoid score wins; on ties the lowest class index is kept.
                float maxScore = 0;
                int bestClass = -1;
                int logitBase = q * _numClasses;
                for (int c = 0; c < _numClasses; c++)
                {
                    float s = Sigmoid(labels[logitBase + c]);
                    if (s > maxScore) { maxScore = s; bestClass = c; }
                }
                if (bestClass < 0 || maxScore < confThresh)
                    continue;

                int classId = bestClass - classOffset;
                if (classId < 0 || classId >= _userClasses.Count)
                    continue; // the extra slot, or a logit without a class name

                // dets holds (cx, cy, w, h) normalized to the network input size.
                float cx = dets[q * 4];
                float cy = dets[q * 4 + 1];
                float w = dets[q * 4 + 2];
                float h = dets[q * 4 + 3];
                if (w <= 0.001f || h <= 0.001f)
                    continue;

                float absCx = cx * _inputWidth;
                float absCy = cy * _inputHeight;
                float absW = w * _inputWidth;
                float absH = h * _inputHeight;
                float x1 = absCx - absW / 2;
                float y1 = absCy - absH / 2;
                float x2 = absCx + absW / 2;
                float y2 = absCy + absH / 2;

                int rx1 = Math.Clamp((int)(x1 * scaleX), 0, imgW);
                int ry1 = Math.Clamp((int)(y1 * scaleY), 0, imgH);
                int rx2 = Math.Clamp((int)(x2 * scaleX), 0, imgW);
                int ry2 = Math.Clamp((int)(y2 * scaleY), 0, imgH);
                if (rx2 <= rx1 || ry2 <= ry1)
                    continue;

                // Mask: sigmoid on the low-res logits, upsample to the original size, then binarize.
                int offset = q * maskArea;
                for (int k = 0; k < maskArea; k++)
                    maskProbs[k] = Sigmoid(masks[offset + k]);

                Mat binary = new Mat();
                using (Mat rawMask = Mat.FromPixelData(_maskHeight, _maskWidth, MatType.CV_32FC1, maskProbs))
                using (Mat maskResized = new Mat())
                using (Mat thresholded = new Mat())
                {
                    Cv2.Resize(rawMask, maskResized, new Size(imgW, imgH));
                    Cv2.Threshold(maskResized, thresholded, maskThresh, 255, ThresholdTypes.Binary);
                    thresholded.ConvertTo(binary, MatType.CV_8UC1);
                }

                results.Add(new SegmentationResult
                {
                    BoundingBox = new Rect(rx1, ry1, rx2 - rx1, ry2 - ry1),
                    ClassId = classId,
                    Label = _userClasses[classId],
                    Confidence = maxScore,
                    Mask = binary
                });
            }
            return results.OrderByDescending(r => r.Confidence).ToList();
        }

        private static float Sigmoid(float x) => 1.0f / (1.0f + MathF.Exp(-x));

        /// <summary>
        /// Releases the device buffers, the stream, the execution context, the engine and the runtime.
        /// The order is the reverse of creation. Subsequent calls are no-ops.
        /// </summary>
        public void Dispose()
        {
            if (_disposed)
                return;
            _disposed = true;

            _inputMemory?.Dispose();
            _outputDetsMemory?.Dispose();
            _outputLabelsMemory?.Dispose();
            _outputMasksMemory?.Dispose();
            _stream?.Dispose();
            _context?.Dispose();
            _engine?.Dispose();
            _runtime?.Dispose();
        }
    }
}

ONNX→TensorRT转换脚本(cmd)

REM Build the TensorRT engines from the ONNX exports (FP16).
REM The output names must match what the application loads:
REM   rfdetr_seg.engine -> segmentation model (inference_model.sim-nano-seg.onnx)
REM   rfdetr_det.engine -> detection model    (inference_model.sim-nano.onnx)
trtexec.exe --onnx=inference_model.sim-nano-seg.onnx --saveEngine=rfdetr_seg.engine --fp16

pause

trtexec.exe --onnx=inference_model.sim-nano.onnx --saveEngine=rfdetr_det.engine --fp16

pause

Engine输入输出信息查看脚本(cmd)

REM Change into the TensorRT directory. /d is required so the drive also changes
REM when this script is started from a different drive (e.g. C:).
cd /d D:\TensorRT10.0\TensorRT-10.13.0.35\lib

REM Dump per-layer information of the detection engine to layer_info.json.
REM rfdetr_det.engine is resolved relative to the directory above.
trtexec.exe --loadEngine=rfdetr_det.engine --dumpLayerInfo --exportLayerInfo=layer_info.json

pause

GPU频率锁定脚本(cmd,需以管理员身份运行;测试结束后可执行 nvidia-smi -rgc 恢复默认频率)

REM Lock the GPU core clock (value in MHz) so benchmark timings are more stable.
REM --lock-gpu-clocks is the long form of -lgc.
nvidia-smi --lock-gpu-clocks=5001

pause

TensorRTSharp相关技术文档:https://mp.weixin.qq.com/s/D0c6j5MmraJO4Eza7tWm1A

相关推荐
小程故事多_801 小时前
[大模型面试系列] 深度解析ReAct框架,大模型Agent的“思考+行动”底层逻辑
人工智能·react.js·面试·职场和发展·智能体
逍遥德1 小时前
AI时代,计算机专业大学生学习指南
java·javascript·人工智能·学习·ai编程
蝎子莱莱爱打怪2 小时前
Claude Code 省 Token 小妙招:RTK + Caveman 组合拳
前端·人工智能·后端
tanis_32 小时前
从 PDF 中精准提取表格、图片与公式:MinerU 结构化元素抽取的 3 种方案
人工智能
sali-tec2 小时前
C# 基于OpenCv的视觉工作流-章63-点廓距离
图像处理·人工智能·opencv·计算机视觉
Maiko Star2 小时前
让 AI 开口说话:Spring AI Alibaba 语音合成(TTS)实战
java·人工智能·spring·springai
机器学习之心2 小时前
多工况车速数据集训练LSTM-Attention用于车速预测,输出未来多个时间步车速,MATLAB代码
人工智能·matlab·lstm·lstm-attention·车速预测
耀耀切克闹灬2 小时前
初识LlamaIndex (了解LlamaIndex 高层概念)
人工智能
机器之心2 小时前
马斯克官宣xAI解散,22万张GPU算力租给Anthropic
人工智能·openai