Using RVM in Unity for Real-Time Person Video Matting (No Green Screen)

Result screenshot

Project repository
https://github.com/PeterL1n/RobustVideoMatting

Note: Unity Sentis has since been renamed Inference Engine

https://docs.unity3d.com/Packages/com.unity.ai.inference@2.2/manual/index.html
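For anyone migrating older Sentis samples, the practical change is the package and namespace name; the worker API is close to what it was. A minimal loading sketch, assuming the Unity.InferenceEngine API linked above (the modelAsset field is a placeholder you assign in the inspector):

csharp
using Unity.InferenceEngine; // formerly Unity.Sentis
using UnityEngine;

public class LoadRvmModel : MonoBehaviour
{
    [SerializeField] private ModelAsset modelAsset; // e.g. the imported RVM ONNX asset

    private Worker _worker;

    void Awake()
    {
        // Load the serialized model and run it on the GPU compute backend.
        Model model = ModelLoader.Load(modelAsset);
        _worker = new Worker(model, BackendType.GPUCompute);
    }

    void OnDestroy() => _worker?.Dispose();
}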

A detour

I first tried the C# onnxruntime bindings. Even with the CUDA execution provider appended, the program did not appear to run on the GPU at all — only about 2 fps...
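In hindsight, two things are worth ruling out before blaming the model: whether the CUDA execution provider actually registered (with the CPU-only package, or with missing CUDA/cuDNN native libraries, session creation either throws or everything runs on the CPU), and how much of the frame time the managed per-pixel pre/post-processing below costs. A hedged sketch of the first check, using only onnxruntime calls that also appear in the full script:

csharp
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using UnityEngine;

public static class OrtProviderCheck
{
    // Builds session options, appending CUDA only when the provider is reported
    // as available; otherwise logs a warning and stays on the CPU provider.
    public static SessionOptions Create()
    {
        var options = new SessionOptions();
        var providers = OrtEnv.Instance().GetAvailableProviders();
        foreach (var p in providers) Debug.Log($"Available provider: {p}");

        if (providers.Contains("CUDAExecutionProvider"))
            options.AppendExecutionProvider_CUDA();
        else
            Debug.LogWarning("CUDAExecutionProvider not available; inference will run on the CPU.");
        return options;
    }
}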

Main code

csharp
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using UnityEngine;
using UnityEngine.UI;
using Debug = UnityEngine.Debug;

public class RobustVideoMatting : MonoBehaviour
{
    [Header("模型设置")]
    public string modelPath = "rvm_mobilenetv3_fp32.onnx";
    public float downsampleRatio = 0.25f;

    [Header("输入设置")]
    public Texture2D inputTexture;
    public bool useWebcam = false;
    public int webcamWidth = 640;
    public int webcamHeight = 480;

    [Header("输出显示")]
    public Texture2D outputTexture;
    public RawImage rawImage;

    private InferenceSession session;
    private WebCamTexture webCamTexture;

    // Recurrent (hidden) states
    private Tensor<float>[] recurrentStates;
    private Tensor<float> downsampleRatioTensor;

    // Output names
    private readonly string[] outputNames = { "fgr", "pha", "r1o", "r2o", "r3o", "r4o" };
    private readonly string[] recurrentInputNames = { "r1i", "r2i", "r3i", "r4i" };
    private readonly string[] recurrentOutputNames = { "r1o", "r2o", "r3o", "r4o" };

    // Current frame results
    private Texture2D foregroundTexture;
    private Texture2D alphaTexture;
    private Texture2D resultTexture;

    void Start()
    {
        InitializeModel();
        InitializeRecurrentStates();

        if (useWebcam)
        {
            InitializeWebcam();
        } 
    }

    void InitializeModel()
    {
        try
        {
            // Load the model
            var modelFullPath = Path.Combine(Application.streamingAssetsPath, modelPath);

            // Create session options
            var sessionOptions = new SessionOptions();
            var aps = OrtEnv.Instance().GetAvailableProviders();
            foreach (var ap in aps)
            {
                Debug.Log(ap);
            }
            // Set thread counts
            sessionOptions.IntraOpNumThreads = 6;
            sessionOptions.InterOpNumThreads = 12;

            sessionOptions.EnableProfiling = true;
            sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE;
            sessionOptions.ProfileOutputPathPrefix = "onnxruntime_profile_" + DateTime.Now.ToString("yyyyMMdd_HHmmss");
            // Set graph optimization level
            sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;

            //sessionOptions.AppendExecutionProvider_DML();
            //sessionOptions.AppendExecutionProvider_CPU(); 
            sessionOptions.AppendExecutionProvider_CUDA();

            session = new InferenceSession(modelFullPath, sessionOptions);

            Debug.Log($"Robust Video Matting模型加载成功: {modelPath}");
        }
        catch (Exception e)
        {
            Debug.LogError($"模型初始化失败: {e.Message}");
        }
    }

    void InitializeRecurrentStates()
    {
        try
        {
            // Initialize the recurrent states, shape [1, 1, 1, 1]
            recurrentStates = new Tensor<float>[4];
            var zeroData = new float[1] { 0f };
            var shape = new int[] { 1, 1, 1, 1 };

            for (int i = 0; i < 4; i++)
            {
                recurrentStates[i] = new DenseTensor<float>(zeroData, shape);
            }

            // Initialize downsample_ratio
            var ratioData = new float[] { downsampleRatio };
            downsampleRatioTensor = new DenseTensor<float>(ratioData, new int[] { 1 });

            Debug.Log("递归状态初始化完成");
        }
        catch (Exception e)
        {
            Debug.LogError($"递归状态初始化失败: {e.Message}");
        }
    }

    void InitializeWebcam()
    {
        WebCamDevice[] devices = WebCamTexture.devices;
        if (devices.Length > 0)
        {
            webCamTexture = new WebCamTexture(devices[0].name, webcamWidth, webcamHeight, 30);
            webCamTexture.Play();
            Debug.Log($"启动摄像头: {devices[0].name}");
        }
        else
        {
            Debug.LogWarning("未找到摄像头设备");
        }
    }

    Texture2D sourceTexture = null;
    void Update()
    {
        if (session == null) return;

        if (Input.GetMouseButtonDown(0))
        {
            TestFrame();
        }

        // Grab the input image
        if (useWebcam && webCamTexture != null && webCamTexture.isPlaying)
        {
            sourceTexture = WebCamTextureToTexture2D(webCamTexture);
        } 
        if (sourceTexture != null)
        {
            // Process the current frame
            ProcessFrame(sourceTexture);

            // Display the result
            if (resultTexture != null)
            {
                rawImage.texture = resultTexture;
            }

            // Clean up the temporary texture
            if (useWebcam && sourceTexture != null)
            {
                DestroyImmediate(sourceTexture);
            }
        }
    }

    void TestFrame()
    {
        if (inputTexture != null)
        {
            Stopwatch stopwatch = new Stopwatch();
            stopwatch.Start();
            // Process the current frame
            ProcessFrame(inputTexture);
            stopwatch.Stop();
            long lastInferenceTime = stopwatch.ElapsedMilliseconds;
            // Log timing
            Debug.Log($"Inference finished, total time: {lastInferenceTime}ms");

            // Display the result
            if (resultTexture != null)
            {
                rawImage.texture = resultTexture;
            }
        }
    }

    void ProcessFrame(Texture2D sourceTexture)
    {
        try
        {
            // Prepare the input tensor
            var inputTensor = PrepareInputTensor(sourceTexture);

            // Build the input list
            var inputs = new List<NamedOnnxValue>
            {
                NamedOnnxValue.CreateFromTensor("src", inputTensor),
                NamedOnnxValue.CreateFromTensor("downsample_ratio", downsampleRatioTensor)
            };

            // Add the recurrent state inputs
            for (int i = 0; i < recurrentInputNames.Length; i++)
            {
                inputs.Add(NamedOnnxValue.CreateFromTensor(recurrentInputNames[i], recurrentStates[i]));
            }

            // Run inference
            using (var results = session.Run(inputs))
            {
                // Decode the outputs into textures (ProcessOutputs returns a success flag)
                ProcessOutputs(results, sourceTexture.width, sourceTexture.height);

                // Carry the recurrent states into the next frame
                UpdateRecurrentStates(results);
            }
        }
        catch (Exception e)
        {
            Debug.LogError($"帧处理失败: {e.Message}");
        }
    }

    Tensor<float> PrepareInputTensor(Texture2D texture)
    {
        // Resize to the input size the model expects
        int targetWidth = 512;  // adjust to your model/export
        int targetHeight = 512; // adjust to your model/export

        var resizedTexture = ResizeTexture(texture, targetWidth, targetHeight);
        Color32[] pixels = resizedTexture.GetPixels32();

        // Build a tensor of shape [1, 3, H, W]
        float[] dataArray = new float[1 * 3 * targetHeight * targetWidth];
        int[] shapeArray = new int[] { 1, 3, targetHeight, targetWidth };

        for (int y = 0; y < targetHeight; y++)
        {
            for (int x = 0; x < targetWidth; x++)
            {
                int index = y * targetWidth + x;
                var pixel = pixels[index];

                // Normalize to [0, 1]
                int rIndex = 0 * targetHeight * targetWidth + y * targetWidth + x;
                int gIndex = 1 * targetHeight * targetWidth + y * targetWidth + x;
                int bIndex = 2 * targetHeight * targetWidth + y * targetWidth + x;

                dataArray[rIndex] = pixel.r / 255.0f;
                dataArray[gIndex] = pixel.g / 255.0f;
                dataArray[bIndex] = pixel.b / 255.0f;
            }
        }

        // Clean up the temporary texture
        DestroyImmediate(resizedTexture);

        return new DenseTensor<float>(dataArray, shapeArray, false);
    }

    bool ProcessOutputs(IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results, int originalWidth, int originalHeight)
    {
        try
        {
            // Fetch the foreground (fgr) and alpha (pha) outputs
            var fgrValue = results.FirstOrDefault(r => r.Name == "fgr");
            var phaValue = results.FirstOrDefault(r => r.Name == "pha");

            if (fgrValue == null || phaValue == null)
            {
                Debug.LogError("缺少必要的输出: fgr 或 pha");
                return false;
            }

            var fgrTensor = fgrValue.AsTensor<float>();
            var phaTensor = phaValue.AsTensor<float>();

            // Destroy last frame's textures before building new ones (avoids leaking)
            if (foregroundTexture != null) DestroyImmediate(foregroundTexture);
            if (alphaTexture != null) DestroyImmediate(alphaTexture);
            if (resultTexture != null) DestroyImmediate(resultTexture);

            // Foreground texture
            foregroundTexture = TensorToTexture(fgrTensor, originalWidth, originalHeight, false);

            // Alpha texture
            alphaTexture = TensorToTexture(phaTensor, originalWidth, originalHeight, true);

            // Composite the final result
            resultTexture = ComposeResult(foregroundTexture, alphaTexture);

            return true;
        }
        catch (Exception e)
        {
            Debug.LogError($"输出处理失败: {e.Message}");
            return false;
        }
    }

    void UpdateRecurrentStates(IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results)
    {
        // Replace the recurrent states with this frame's outputs.
        // Deep-copy them: the tensors inside `results` are backed by native
        // memory that is freed when the using block disposes the collection.
        for (int i = 0; i < recurrentOutputNames.Length; i++)
        {
            var stateValue = results.FirstOrDefault(r => r.Name == recurrentOutputNames[i]);
            if (stateValue != null)
            {
                recurrentStates[i] = stateValue.AsTensor<float>().ToDenseTensor();
            }
        }
    }

    Texture2D TensorToTexture(Tensor<float> tensor, int targetWidth, int targetHeight, bool isAlpha)
    {
        var dimensions = tensor.Dimensions.ToArray();
        if (dimensions.Length < 4)
        {
            Debug.LogError($"张量维度不足: {dimensions.Length}");
            return null;
        }

        int tensorHeight = dimensions[2];
        int tensorWidth = dimensions[3];
        int channels = dimensions[1];

        // Create a temporary texture
        var tempTexture = new Texture2D(tensorWidth, tensorHeight,
            isAlpha ? TextureFormat.RFloat : TextureFormat.RGB24, false);

        if (isAlpha && channels == 1)
        {
            // Single-channel alpha
            for (int y = 0; y < tensorHeight; y++)
            {
                for (int x = 0; x < tensorWidth; x++)
                {
                    float alpha = Mathf.Clamp(tensor[0, 0, y, x], 0f, 1f);
                    tempTexture.SetPixel(x, y, new Color(alpha, alpha, alpha));
                }
            }
        }
        else if (!isAlpha && channels == 3)
        {
            // Three-channel RGB
            for (int y = 0; y < tensorHeight; y++)
            {
                for (int x = 0; x < tensorWidth; x++)
                {
                    float r = Mathf.Clamp(tensor[0, 0, y, x], 0f, 1f);
                    float g = Mathf.Clamp(tensor[0, 1, y, x], 0f, 1f);
                    float b = Mathf.Clamp(tensor[0, 2, y, x], 0f, 1f);
                    tempTexture.SetPixel(x, y, new Color(r, g, b));
                }
            }
        }

        tempTexture.Apply();

        // Resize to the target size
        var finalTexture = ResizeTexture(tempTexture, targetWidth, targetHeight);
        DestroyImmediate(tempTexture);

        return finalTexture;
    }

    Texture2D ComposeResult(Texture2D foreground, Texture2D alpha)
    {
        if (foreground == null || alpha == null ||
            foreground.width != alpha.width || foreground.height != alpha.height)
        {
            Debug.LogError("前景和Alpha纹理尺寸不匹配");
            return null;
        }

        var result = new Texture2D(foreground.width, foreground.height, TextureFormat.RGBA32, false);
        var fgPixels = foreground.GetPixels();
        var alphaPixels = alpha.GetPixels();

        for (int i = 0; i < fgPixels.Length; i++)
        {
            Color fgColor = fgPixels[i];
            Color alphaColor = alphaPixels[i];

            // Use the alpha value as transparency
            fgColor.a = alphaColor.r; // the alpha texture is grayscale, so r carries the matte
            result.SetPixel(i % result.width, i / result.width, fgColor);
        }

        result.Apply();
        return result;
    }

    Texture2D ResizeTexture(Texture2D source, int newWidth, int newHeight)
    {
        var rt = RenderTexture.GetTemporary(newWidth, newHeight);
        Graphics.Blit(source, rt);

        var result = new Texture2D(newWidth, newHeight, TextureFormat.RGBA32, false);
        RenderTexture.active = rt;
        result.ReadPixels(new Rect(0, 0, newWidth, newHeight), 0, 0);
        result.Apply();

        RenderTexture.ReleaseTemporary(rt);
        return result;
    }

    Texture2D WebCamTextureToTexture2D(WebCamTexture webCamTexture)
    {
        Texture2D tex = new Texture2D(webCamTexture.width, webCamTexture.height, TextureFormat.RGBA32, false);
        tex.SetPixels32(webCamTexture.GetPixels32());
        tex.Apply();
        return tex;
    }

    /// <summary>
    /// 重置递归状态(开始新的视频序列时调用)
    /// </summary>
    public void ResetRecurrentStates()
    {
        InitializeRecurrentStates();
        Debug.Log("递归状态已重置");
    }

    /// <summary>
    /// 设置下采样比例
    /// </summary>
    public void SetDownsampleRatio(float ratio)
    {
        downsampleRatio = Mathf.Clamp(ratio, 0.1f, 1.0f);

        // Update the downsample_ratio tensor
        var ratioData = new float[] { downsampleRatio };
        downsampleRatioTensor = new DenseTensor<float>(ratioData, new int[] { 1 });

        Debug.Log($"下采样比例设置为: {downsampleRatio}");
    }

    /// <summary>
    /// 获取Alpha遮罩纹理
    /// </summary>
    public Texture2D GetAlphaTexture()
    {
        return alphaTexture;
    }

    /// <summary>
    /// 获取前景纹理
    /// </summary>
    public Texture2D GetForegroundTexture()
    {
        return foregroundTexture;
    }

    /// <summary>
    /// 获取结果纹理
    /// </summary>
    public Texture2D GetResultTexture()
    {
        return resultTexture;
    }

    void OnDestroy()
    {
        session?.Dispose();

        if (webCamTexture != null && webCamTexture.isPlaying)
        {
            webCamTexture.Stop();
        }
    }
}
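Even when inference does run on the GPU, the per-pixel SetPixel/GetPixels loops above (PrepareInputTensor, TensorToTexture, ComposeResult) execute on the CPU every frame and can cap the frame rate on their own. As one illustration, here is a hedged alternative to ComposeResult that fills a Color32 array once and uploads it in bulk; it assumes the same using directives as the script above:

csharp
// Sketch: composite foreground + alpha with a single bulk upload
// instead of one SetPixel call per pixel.
Texture2D ComposeResultFast(Texture2D foreground, Texture2D alpha)
{
    var result = new Texture2D(foreground.width, foreground.height, TextureFormat.RGBA32, false);
    Color32[] fg = foreground.GetPixels32();
    Color32[] matte = alpha.GetPixels32();

    for (int i = 0; i < fg.Length; i++)
    {
        fg[i].a = matte[i].r; // the alpha texture is grayscale; r carries the matte
    }

    result.SetPixels32(fg);
    result.Apply();
    return result;
}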

Result in Unity

Main code

csharp
using System.Collections;

using UnityEngine;
using UnityEngine.UI;

//References
//reference1: https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference.md
//reference2: https://docs.unity3d.com/Packages/com.unity.sentis@2.1/manual/use-model-output.html
public class VideoMatting : MonoBehaviour
{
    public RenderTexture OutputTexture => outputCamera?.targetTexture;

    [SerializeField] private RenderTexture sourceTexture; // feed the original frames into this texture
    [SerializeField] private Unity.InferenceEngine.ModelAsset modelAsset;
    [SerializeField] private Material alphaMaterial;
    [SerializeField] private RawImage sketchRawImage;
    [SerializeField] private Camera outputCamera;
    [SerializeField] private Vector2 frameResolution = new Vector2(1920, 1080);
    [SerializeField] private RawImage debugRawImage;

    private RenderTexture _foregroundTexture;
    private RenderTexture _alphaTexture;
    private Unity.InferenceEngine.Worker _worker;
    private Unity.InferenceEngine.Model _runtimeModel;
    private RenderTexture _resultRenderTexture;
    private Unity.InferenceEngine.Tensor<float> _r1, _r2, _r3, _r4, _inputTensor, _downsampleRatioTensor;
    private Vector2 _previousResolution;

    void Awake()
    {
        //initialize model
        _runtimeModel = Unity.InferenceEngine.ModelLoader.Load(modelAsset);
        _worker = new Unity.InferenceEngine.Worker(_runtimeModel, Unity.InferenceEngine.BackendType.GPUCompute);
        _r1 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
        _r2 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
        _r3 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
        _r4 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
        _inputTensor = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 3, 1, 1));
        _downsampleRatioTensor = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1), new float[] { 1.0f });
        outputCamera.backgroundColor = new Color(0, 0, 0, 0);
        sketchRawImage.material = alphaMaterial;
    }

    void Start()
    {
        StartCoroutine(ProcessVideoMatting());
    }
    
    void Update()
    {
        UpdateResultRenderTexture();
        UpdateDebugRawImage();
    }

    public void SetSourceTexture(RenderTexture sourceTexture)
    {
        this.sourceTexture = sourceTexture;
    }

    void UpdateResultRenderTexture()
    {
        bool changedResolution = _previousResolution.x != frameResolution.x || _previousResolution.y != frameResolution.y;
        GetOrCreateRenderTexture(ref _resultRenderTexture, (int)frameResolution.x, (int)frameResolution.y, "ResultRT", changedResolution);
        if (outputCamera != null && outputCamera.targetTexture == null)
            outputCamera.targetTexture = _resultRenderTexture;
        _previousResolution = frameResolution;
    }

    void UpdateDebugRawImage()
    {
        if(debugRawImage != null && OutputTexture != null)
            debugRawImage.texture = OutputTexture;
    }

    IEnumerator ProcessVideoMatting()
    {
        while (true)
        {
            if (sourceTexture == null)
            {
                yield return null;
                continue;
            }

            int textureWidth = sourceTexture.width;
            int textureHeight = sourceTexture.height;


            float optimalRatio = CalculateOptimalDownsampleRatio(textureWidth, textureHeight); // pick a downsample ratio for this resolution
            var inputShape = new Unity.InferenceEngine.TensorShape(1, 3, textureHeight, textureWidth); // batch, channel, height, width
            if (_inputTensor == null || !_inputTensor.shape.Equals(inputShape))
            {
                _inputTensor?.Dispose();
                _inputTensor = new Unity.InferenceEngine.Tensor<float>(inputShape);
            }
            Unity.InferenceEngine.TextureConverter.ToTensor(sourceTexture, _inputTensor, new Unity.InferenceEngine.TextureTransform());
            _downsampleRatioTensor[0] = optimalRatio;

            _worker.SetInput("src", _inputTensor);
            _worker.SetInput("r1i", _r1);
            _worker.SetInput("r2i", _r2);
            _worker.SetInput("r3i", _r3);
            _worker.SetInput("r4i", _r4);
            _worker.SetInput("downsample_ratio", _downsampleRatioTensor);
            _worker.Schedule();

            yield return null;

            var foregroundTensor = _worker.PeekOutput("fgr") as Unity.InferenceEngine.Tensor<float>;
            var alphaTensor = _worker.PeekOutput("pha") as Unity.InferenceEngine.Tensor<float>;

            GetOrCreateRenderTexture(ref _foregroundTexture, textureWidth, textureHeight, "ForegroundRT");
            GetOrCreateRenderTexture(ref _alphaTexture, textureWidth, textureHeight, "AlphaRT");

            // Kick off async readbacks of both outputs.
            var fgrAwaiter = foregroundTensor.ReadbackAndCloneAsync().GetAwaiter();
            var alphaAwaiter = alphaTensor.ReadbackAndCloneAsync().GetAwaiter();

            // Wait for the readbacks without blocking the main thread.
            while (!fgrAwaiter.IsCompleted || !alphaAwaiter.IsCompleted)
            {
                yield return null;
            }

            // The CPU clones are disposed unused; the readback only serves to wait
            // until the GPU work is done. RenderToTexture reads the GPU tensors directly.
            using (var foregroundOut = fgrAwaiter.GetResult())
            using (var alphaOut = alphaAwaiter.GetResult())
            {
                Unity.InferenceEngine.TextureConverter.RenderToTexture(foregroundTensor, _foregroundTexture);
                Unity.InferenceEngine.TextureConverter.RenderToTexture(alphaTensor, _alphaTexture);
            }
            
            try
            {
                if(sketchRawImage != null)
                {
                    sketchRawImage.material.SetTexture("_FgrTex", _foregroundTexture);
                    sketchRawImage.material.SetTexture("_PhaTex", _alphaTexture);
                }
            }
            catch (System.Exception e)
            {
                Debug.LogError("NOTE: Please make sure the RawImage has a material using the VideoMatting shader. Exception: " + e.Message);
            }
        }
    }

    private RenderTexture GetOrCreateRenderTexture(ref RenderTexture renderTexture, int width, int height, string name, bool forceCreate = false)
    {
        if (renderTexture == null || renderTexture.width != width || renderTexture.height != height || forceCreate)
        {
            if (renderTexture != null)
            {
                renderTexture.Release();
                DestroyImmediate(renderTexture);
            }

            renderTexture = new RenderTexture(width, height, 24, RenderTextureFormat.ARGB32);
            renderTexture.name = name;
            renderTexture.Create();
        }

        return renderTexture;
    }


    // | Resolution    | Portrait      | Full-Body      |
    // | ------------- | ------------- | -------------- |
    // | <= 512x512    | 1             | 1              |
    // | 1280x720      | 0.375         | 0.6            |
    // | 1920x1080     | 0.25          | 0.4            |
    // | 3840x2160     | 0.125         | 0.2            |
    // The model downsamples five times internally; with the ratios below, none of
    // those five halvings should ever produce an odd dimension.
    // Keep this in mind when changing the width/height.
    private float CalculateOptimalDownsampleRatio(int width, int height)
    {
        int imagePixelCount = width * height;

        if (imagePixelCount <= 512 * 512)
        {
            return 1.0f;     // keep the original size
        }
        else if (imagePixelCount <= 1280 * 720)
        {
            return 0.6f;
        }
        else if (imagePixelCount <= 1920 * 1080)
        {
            return 0.4f;
        }
        else if (imagePixelCount <= 3840 * 2160)
        {
            return 0.2f;
        }
        else
        {
            return 0.1f;
        }
    }
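    // Hypothetical sanity check (not in the original project): the model halves
    // the downsampled frame five times internally, so none of those halvings
    // should land on an odd dimension. Logs a warning when a candidate ratio
    // violates that; a sketch, not a drop-in guarantee.
    private void WarnIfOddDownsample(int width, int height, float ratio)
    {
        int w = Mathf.RoundToInt(width * ratio);
        int h = Mathf.RoundToInt(height * ratio);
        for (int level = 0; level < 5; level++)
        {
            if (w % 2 == 1 || h % 2 == 1)
            {
                Debug.LogWarning($"downsample_ratio {ratio} reaches odd size {w}x{h} at halving {level}");
                return;
            }
            w /= 2;
            h /= 2;
        }
    }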

    void OnDestroy()
    {
        _r1?.Dispose();
        _r2?.Dispose();
        _r3?.Dispose();
        _r4?.Dispose();
        _inputTensor?.Dispose();
        _downsampleRatioTensor?.Dispose();
        _worker?.Dispose();
    }
}
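The component expects sourceTexture to be supplied from outside, either in the inspector or via SetSourceTexture. A minimal, hypothetical driver that blits a webcam feed into a RenderTexture each frame could look like this (the resolution and component wiring are assumptions):

csharp
using UnityEngine;

// Hypothetical driver: copies a webcam feed into a RenderTexture
// that VideoMatting consumes via SetSourceTexture.
public class WebcamSource : MonoBehaviour
{
    [SerializeField] private VideoMatting matting; // the component above

    private WebCamTexture _webcam;
    private RenderTexture _frame;

    void Start()
    {
        _webcam = new WebCamTexture(1280, 720, 30);
        _webcam.Play();
        _frame = new RenderTexture(1280, 720, 0, RenderTextureFormat.ARGB32);
        matting.SetSourceTexture(_frame);
    }

    void Update()
    {
        if (_webcam.didUpdateThisFrame)
            Graphics.Blit(_webcam, _frame); // GPU copy, no CPU readback
    }

    void OnDestroy()
    {
        if (_webcam != null) _webcam.Stop();
        if (_frame != null) _frame.Release();
    }
}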

The detour (onnxruntime version): https://github.com/xue-fei/RobustVideoMatting-unity.git

Project repository

https://github.com/xue-fei/robustvideomatting-unity-sentis.git

Forked from
https://github.com/realmorm/unity-sentis-robustvideomatting-example.git
