Result preview

Project repository
https://github.com/PeterL1n/RobustVideoMatting
Unity Sentis has since been renamed to Inference Engine

https://docs.unity3d.com/Packages/com.unity.ai.inference@2.2/manual/index.html
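In practice the rename is mostly a namespace change: code that used `using Unity.Sentis;` now targets `Unity.InferenceEngine`, while types such as `ModelAsset`, `ModelLoader`, and `Worker` keep their names. A minimal loading sketch under that assumption (the fully qualified names match the ones used in the second code listing below):

```csharp
using UnityEngine;

// Minimal model-loading sketch for Inference Engine (formerly Sentis).
public class ModelLoadExample : MonoBehaviour
{
    [SerializeField] private Unity.InferenceEngine.ModelAsset modelAsset;
    private Unity.InferenceEngine.Worker _worker;

    void Awake()
    {
        var model = Unity.InferenceEngine.ModelLoader.Load(modelAsset);
        _worker = new Unity.InferenceEngine.Worker(model, Unity.InferenceEngine.BackendType.GPUCompute);
    }

    void OnDestroy() => _worker?.Dispose();
}
```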
Detours along the way
First I tried this with C# onnxruntime. I enabled CUDA, but the program didn't seem to run on the GPU at all, only about 2 fps... (In hindsight, the per-pixel CPU conversion loops in the code below are likely at least as costly as the choice of execution provider.)
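Before blaming the model, it's worth probing whether the CUDA execution provider is even present: with the CPU-only Microsoft.ML.OnnxRuntime package (rather than Microsoft.ML.OnnxRuntime.Gpu plus matching CUDA/cuDNN native libraries), the session quietly runs on the CPU. A quick check, sketched on the same `OrtEnv` API the code below already calls:

```csharp
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using UnityEngine;

// Logs the execution providers this onnxruntime build actually ships with.
// If "CUDAExecutionProvider" is missing, GPU inference is not available no
// matter what the session options request.
public static class OrtProviderCheck
{
    public static bool HasCuda()
    {
        var providers = OrtEnv.Instance().GetAvailableProviders();
        Debug.Log("ORT providers: " + string.Join(", ", providers));
        return providers.Contains("CUDAExecutionProvider");
    }
}
```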
Main code
```csharp
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using UnityEngine;
using UnityEngine.UI;
using Debug = UnityEngine.Debug;
public class RobustVideoMatting : MonoBehaviour
{
[Header("模型设置")]
public string modelPath = "rvm_mobilenetv3_fp32.onnx";
public float downsampleRatio = 0.25f;
[Header("输入设置")]
public Texture2D inputTexture;
public bool useWebcam = false;
public int webcamWidth = 640;
public int webcamHeight = 480;
[Header("输出显示")]
public Texture2D outputTexture;
public RawImage rawImage;
private InferenceSession session;
private WebCamTexture webCamTexture;
// Recurrent (hidden) states
private Tensor<float>[] recurrentStates;
private Tensor<float> downsampleRatioTensor;
// Output names
private readonly string[] outputNames = { "fgr", "pha", "r1o", "r2o", "r3o", "r4o" };
private readonly string[] recurrentInputNames = { "r1i", "r2i", "r3i", "r4i" };
private readonly string[] recurrentOutputNames = { "r1o", "r2o", "r3o", "r4o" };
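// Per the RobustVideoMatting docs: "src" is the input frame, "fgr"/"pha" are
// the predicted foreground and alpha matte, and r1i..r4i / r1o..r4o are the
// ConvGRU recurrent states that carry temporal context between frames.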
// Current-frame results
private Texture2D foregroundTexture;
private Texture2D alphaTexture;
private Texture2D resultTexture;
void Start()
{
InitializeModel();
InitializeRecurrentStates();
if (useWebcam)
{
InitializeWebcam();
}
}
void InitializeModel()
{
try
{
// Load the model
var modelFullPath = Path.Combine(Application.streamingAssetsPath, modelPath);
// Create session options
var sessionOptions = new SessionOptions();
var aps = OrtEnv.Instance().GetAvailableProviders();
foreach (var ap in aps)
{
Debug.Log(ap);
}
// Set thread counts
sessionOptions.IntraOpNumThreads = 6;
sessionOptions.InterOpNumThreads = 12;
sessionOptions.EnableProfiling = true;
sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE;
sessionOptions.ProfileOutputPathPrefix = "onnxruntime_profile_" + DateTime.Now.ToString("yyyyMMdd_HHmmss");
// Set the graph optimization level
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
//sessionOptions.AppendExecutionProvider_DML();
//sessionOptions.AppendExecutionProvider_CPU();
sessionOptions.AppendExecutionProvider_CUDA();
session = new InferenceSession(modelFullPath, sessionOptions);
Debug.Log($"Robust Video Matting模型加载成功: {modelPath}");
}
catch (Exception e)
{
Debug.LogError($"模型初始化失败: {e.Message}");
}
}
void InitializeRecurrentStates()
{
try
{
// Initialize the recurrent states as zeros of shape [1, 1, 1, 1]
recurrentStates = new Tensor<float>[4];
var zeroData = new float[1] { 0f };
var shape = new int[] { 1, 1, 1, 1 };
for (int i = 0; i < 4; i++)
{
recurrentStates[i] = new DenseTensor<float>(zeroData, shape);
}
// Initialize downsample_ratio
var ratioData = new float[] { downsampleRatio };
downsampleRatioTensor = new DenseTensor<float>(ratioData, new int[] { 1 });
Debug.Log("递归状态初始化完成");
}
catch (Exception e)
{
Debug.LogError($"递归状态初始化失败: {e.Message}");
}
}
void InitializeWebcam()
{
WebCamDevice[] devices = WebCamTexture.devices;
if (devices.Length > 0)
{
webCamTexture = new WebCamTexture(devices[0].name, webcamWidth, webcamHeight, 30);
webCamTexture.Play();
Debug.Log($"启动摄像头: {devices[0].name}");
}
else
{
Debug.LogWarning("未找到摄像头设备");
}
}
Texture2D sourceTexture = null;
void Update()
{
if (session == null) return;
if (Input.GetMouseButtonDown(0))
{
TestFrame();
}
// Grab the input frame
if (useWebcam && webCamTexture != null && webCamTexture.isPlaying)
{
sourceTexture = WebCamTextureToTexture2D(webCamTexture);
}
if (sourceTexture != null)
{
// Process the current frame
ProcessFrame(sourceTexture);
// Display the result
if (resultTexture != null)
{
rawImage.texture = resultTexture;
}
// Clean up the temporary texture
if (useWebcam && sourceTexture != null)
{
DestroyImmediate(sourceTexture);
sourceTexture = null; // avoid processing a destroyed texture next frame
}
}
}
void TestFrame()
{
if (inputTexture != null)
{
Stopwatch stopwatch = new Stopwatch();
stopwatch.Start();
// Process the current frame
ProcessFrame(inputTexture);
stopwatch.Stop();
long lastInferenceTime = stopwatch.ElapsedMilliseconds;
// Log timing info
Debug.Log($"Inference finished, total time: {lastInferenceTime} ms");
// Display the result
if (resultTexture != null)
{
rawImage.texture = resultTexture;
}
}
}
void ProcessFrame(Texture2D sourceTexture)
{
try
{
// Prepare the input tensor
var inputTensor = PrepareInputTensor(sourceTexture);
// Build the input list
var inputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("src", inputTensor),
NamedOnnxValue.CreateFromTensor("downsample_ratio", downsampleRatioTensor)
};
// Add the recurrent state inputs
for (int i = 0; i < recurrentInputNames.Length; i++)
{
inputs.Add(NamedOnnxValue.CreateFromTensor(recurrentInputNames[i], recurrentStates[i]));
}
// Run inference
using (var results = session.Run(inputs))
{
// Convert the outputs to textures
ProcessOutputs(results, sourceTexture.width, sourceTexture.height);
// Feed the recurrent outputs back for the next frame
UpdateRecurrentStates(results);
}
}
catch (Exception e)
{
Debug.LogError($"帧处理失败: {e.Message}");
}
}
Tensor<float> PrepareInputTensor(Texture2D texture)
{
// Resize to the input size the model expects
int targetWidth = 512; // adjust to match the model
int targetHeight = 512; // adjust to match the model
var resizedTexture = ResizeTexture(texture, targetWidth, targetHeight);
Color32[] pixels = resizedTexture.GetPixels32();
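// Note: GetPixels32 returns rows bottom-up (Unity textures have a bottom-left
// origin), so the tensor filled below is vertically flipped relative to the
// usual top-left image convention. TensorToTexture reads pixels back in the
// same order, so the display comes out upright, but the network itself sees
// an upside-down frame; if matting quality suffers, flip the rows here and
// again on readback.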
// Create the tensor [1, 3, H, W]
float[] dataArray = new float[1 * 3 * targetHeight * targetWidth];
int[] shapeArray = new int[] { 1, 3, targetHeight, targetWidth };
for (int y = 0; y < targetHeight; y++)
{
for (int x = 0; x < targetWidth; x++)
{
int index = y * targetWidth + x;
var pixel = pixels[index];
// Normalize to [0, 1]
int rIndex = 0 * targetHeight * targetWidth + y * targetWidth + x;
int gIndex = 1 * targetHeight * targetWidth + y * targetWidth + x;
int bIndex = 2 * targetHeight * targetWidth + y * targetWidth + x;
dataArray[rIndex] = pixel.r / 255.0f;
dataArray[gIndex] = pixel.g / 255.0f;
dataArray[bIndex] = pixel.b / 255.0f;
}
}
// Clean up the temporary texture
DestroyImmediate(resizedTexture);
return new DenseTensor<float>(dataArray, shapeArray, false);
}
bool ProcessOutputs(IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results, int originalWidth, int originalHeight)
{
try
{
// Fetch the foreground (fgr) and alpha (pha) outputs
var fgrValue = results.FirstOrDefault(r => r.Name == "fgr");
var phaValue = results.FirstOrDefault(r => r.Name == "pha");
if (fgrValue == null || phaValue == null)
{
Debug.LogError("缺少必要的输出: fgr 或 pha");
return false;
}
var fgrTensor = fgrValue.AsTensor<float>();
var phaTensor = phaValue.AsTensor<float>();
// Destroy last frame's textures so repeated calls don't leak memory
if (foregroundTexture != null) Destroy(foregroundTexture);
if (alphaTexture != null) Destroy(alphaTexture);
if (resultTexture != null) Destroy(resultTexture);
// Build the foreground texture
foregroundTexture = TensorToTexture(fgrTensor, originalWidth, originalHeight, false);
// Build the alpha texture
alphaTexture = TensorToTexture(phaTensor, originalWidth, originalHeight, true);
// Composite the final result
resultTexture = ComposeResult(foregroundTexture, alphaTexture);
return true;
}
catch (Exception e)
{
Debug.LogError($"输出处理失败: {e.Message}");
return false;
}
}
void UpdateRecurrentStates(IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results)
{
// Replace the recurrent states with this frame's outputs
for (int i = 0; i < recurrentOutputNames.Length; i++)
{
var stateValue = results.FirstOrDefault(r => r.Name == recurrentOutputNames[i]);
if (stateValue != null)
{
// Clone: the results collection is disposed when Run()'s using block ends,
// freeing the tensor's native buffer, so take an owned copy here
recurrentStates[i] = stateValue.AsTensor<float>().Clone();
}
}
}
Texture2D TensorToTexture(Tensor<float> tensor, int targetWidth, int targetHeight, bool isAlpha)
{
var dimensions = tensor.Dimensions.ToArray();
if (dimensions.Length < 4)
{
Debug.LogError($"张量维度不足: {dimensions.Length}");
return null;
}
int tensorHeight = dimensions[2];
int tensorWidth = dimensions[3];
int channels = dimensions[1];
// Create a temporary texture
var tempTexture = new Texture2D(tensorWidth, tensorHeight,
isAlpha ? TextureFormat.RFloat : TextureFormat.RGB24, false);
if (isAlpha && channels == 1)
{
// Alpha channel
for (int y = 0; y < tensorHeight; y++)
{
for (int x = 0; x < tensorWidth; x++)
{
float alpha = Mathf.Clamp(tensor[0, 0, y, x], 0f, 1f);
tempTexture.SetPixel(x, y, new Color(alpha, alpha, alpha));
}
}
}
else if (!isAlpha && channels == 3)
{
// RGB image
for (int y = 0; y < tensorHeight; y++)
{
for (int x = 0; x < tensorWidth; x++)
{
float r = Mathf.Clamp(tensor[0, 0, y, x], 0f, 1f);
float g = Mathf.Clamp(tensor[0, 1, y, x], 0f, 1f);
float b = Mathf.Clamp(tensor[0, 2, y, x], 0f, 1f);
tempTexture.SetPixel(x, y, new Color(r, g, b));
}
}
}
tempTexture.Apply();
// Resize to the target size
var finalTexture = ResizeTexture(tempTexture, targetWidth, targetHeight);
DestroyImmediate(tempTexture);
return finalTexture;
}
Texture2D ComposeResult(Texture2D foreground, Texture2D alpha)
{
if (foreground == null || alpha == null ||
foreground.width != alpha.width || foreground.height != alpha.height)
{
Debug.LogError("前景和Alpha纹理尺寸不匹配");
return null;
}
var result = new Texture2D(foreground.width, foreground.height, TextureFormat.RGBA32, false);
var fgPixels = foreground.GetPixels();
var alphaPixels = alpha.GetPixels();
for (int i = 0; i < fgPixels.Length; i++)
{
Color fgColor = fgPixels[i];
Color alphaColor = alphaPixels[i];
// Use the alpha matte as transparency
fgColor.a = alphaColor.r; // the alpha texture is grayscale; the r channel holds the alpha value
result.SetPixel(i % result.width, i / result.width, fgColor);
}
result.Apply();
return result;
}
Texture2D ResizeTexture(Texture2D source, int newWidth, int newHeight)
{
var rt = RenderTexture.GetTemporary(newWidth, newHeight);
Graphics.Blit(source, rt);
var result = new Texture2D(newWidth, newHeight, TextureFormat.RGBA32, false);
var prevActive = RenderTexture.active;
RenderTexture.active = rt;
result.ReadPixels(new Rect(0, 0, newWidth, newHeight), 0, 0);
result.Apply();
// Restore the previously active render target before releasing the temporary
RenderTexture.active = prevActive;
RenderTexture.ReleaseTemporary(rt);
return result;
}
Texture2D WebCamTextureToTexture2D(WebCamTexture webCamTexture)
{
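// Allocates a new Texture2D on every call; Update() destroys it afterwards,
// but pooling/reusing one texture would avoid the per-frame allocation.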
Texture2D tex = new Texture2D(webCamTexture.width, webCamTexture.height, TextureFormat.RGBA32, false);
tex.SetPixels32(webCamTexture.GetPixels32());
tex.Apply();
return tex;
}
/// <summary>
/// Reset the recurrent states (call when starting a new video sequence)
/// </summary>
public void ResetRecurrentStates()
{
InitializeRecurrentStates();
Debug.Log("递归状态已重置");
}
/// <summary>
/// Set the downsample ratio
/// </summary>
public void SetDownsampleRatio(float ratio)
{
downsampleRatio = Mathf.Clamp(ratio, 0.1f, 1.0f);
// Update the downsample_ratio tensor
var ratioData = new float[] { downsampleRatio };
downsampleRatioTensor = new DenseTensor<float>(ratioData, new int[] { 1 });
Debug.Log($"下采样比例设置为: {downsampleRatio}");
}
/// <summary>
/// Get the alpha matte texture
/// </summary>
public Texture2D GetAlphaTexture()
{
return alphaTexture;
}
/// <summary>
/// Get the foreground texture
/// </summary>
public Texture2D GetForegroundTexture()
{
return foregroundTexture;
}
/// <summary>
/// Get the composited result texture
/// </summary>
public Texture2D GetResultTexture()
{
return resultTexture;
}
void OnDestroy()
{
session?.Dispose();
if (webCamTexture != null && webCamTexture.isPlaying)
{
webCamTexture.Stop();
}
}
}
```
Unity implementation result

Main code
```csharp
using System.Collections;
using UnityEngine;
using UnityEngine.UI;
// References
//reference1: https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference.md
//reference2: https://docs.unity3d.com/Packages/com.unity.sentis@2.1/manual/use-model-output.html
public class VideoMatting : MonoBehaviour
{
public RenderTexture OutputTexture => ouputCamera?.targetTexture;
[SerializeField] private RenderTexture sourceTexture; // feed the original frame data into this source texture
[SerializeField] private Unity.InferenceEngine.ModelAsset modelAsset;
[SerializeField] private Material alphaMaterial;
[SerializeField] private RawImage sketchRawImage;
[SerializeField] private Camera ouputCamera;
[SerializeField] private Vector2 frameResolution = new Vector2(1920, 1080);
[SerializeField] private RawImage debugRawImage;
private RenderTexture _foregroundTexture;
private RenderTexture _alphaTexture;
private Unity.InferenceEngine.Worker _worker;
private Unity.InferenceEngine.Model _runtimeModel;
private RenderTexture _resultRenderTexture;
private Unity.InferenceEngine.Tensor<float> _r1, _r2, _r3, _r4, _inputTensor, _downsampleRatioTensor;
private Vector2 _previousResolution;
void Awake()
{
//initialize model
_runtimeModel = Unity.InferenceEngine.ModelLoader.Load(modelAsset);
_worker = new Unity.InferenceEngine.Worker(_runtimeModel, Unity.InferenceEngine.BackendType.GPUCompute);
_r1 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
_r2 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
_r3 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
_r4 = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 1, 1, 1), new float[] { 0.0f });
_inputTensor = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1, 3, 1, 1));
_downsampleRatioTensor = new Unity.InferenceEngine.Tensor<float>(new Unity.InferenceEngine.TensorShape(1), new float[] { 1.0f });
ouputCamera.backgroundColor = new Color(0, 0, 0, 0);
sketchRawImage.material = alphaMaterial;
}
void Start()
{
StartCoroutine(ProcessVideoMatting());
}
void Update()
{
UpdateResultRenderTexture();
UpdateDebugRawImage();
}
public void SetSourceTexture(RenderTexture sourceTexture)
{
this.sourceTexture = sourceTexture;
}
void UpdateResultRenderTexture()
{
bool changedResolution = _previousResolution.x != frameResolution.x || _previousResolution.y != frameResolution.y;
GetOrCreateRenderTexture(ref _resultRenderTexture, (int)frameResolution.x, (int)frameResolution.y, "ResultRT", changedResolution);
if (ouputCamera != null && ouputCamera.targetTexture == null)
ouputCamera.targetTexture = _resultRenderTexture;
_previousResolution = frameResolution;
}
void UpdateDebugRawImage()
{
if(debugRawImage != null && OutputTexture != null)
debugRawImage.texture = OutputTexture;
}
IEnumerator ProcessVideoMatting()
{
while (true)
{
if (sourceTexture == null)
{
yield return null;
continue;
}
int textureWidth = sourceTexture.width;
int textureHeight = sourceTexture.height;
float optimalRatio = CalculateOptimalDownsampleRatio(textureWidth, textureHeight); // choose the downsample ratio
var inputShape = new Unity.InferenceEngine.TensorShape(1, 3, textureHeight, textureWidth); // batch, channel, height, width
if (_inputTensor == null || !_inputTensor.shape.Equals(inputShape))
{
_inputTensor?.Dispose();
_inputTensor = new Unity.InferenceEngine.Tensor<float>(inputShape);
}
Unity.InferenceEngine.TextureConverter.ToTensor(sourceTexture, _inputTensor, new Unity.InferenceEngine.TextureTransform());
_downsampleRatioTensor[0] = optimalRatio;
_worker.SetInput("src", _inputTensor);
_worker.SetInput("r1i", _r1);
_worker.SetInput("r2i", _r2);
_worker.SetInput("r3i", _r3);
_worker.SetInput("r4i", _r4);
_worker.SetInput("downsample_ratio", _downsampleRatioTensor);
_worker.Schedule();
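// Schedule() only enqueues the model on the GPU; yielding a frame here
// gives it time to execute before the outputs are read below.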
yield return null;
var foregroundTensor = _worker.PeekOutput("fgr") as Unity.InferenceEngine.Tensor<float>;
var alphaTensor = _worker.PeekOutput("pha") as Unity.InferenceEngine.Tensor<float>;
GetOrCreateRenderTexture(ref _foregroundTexture, textureWidth, textureHeight, "ForegroundRT");
GetOrCreateRenderTexture(ref _alphaTexture, textureWidth, textureHeight, "AlphaRT");
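// The async readbacks below serve purely as a completion fence: the CPU
// clones are disposed unused, and RenderToTexture() renders the GPU tensors
// directly. Waiting on them keeps this coroutine from racing the GPU.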
var fgrAwaiter = foregroundTensor.ReadbackAndCloneAsync().GetAwaiter();
var alphaAwaiter = alphaTensor.ReadbackAndCloneAsync().GetAwaiter();
while (!fgrAwaiter.IsCompleted || !alphaAwaiter.IsCompleted)
{
yield return null;
}
using (var foregroundOut = fgrAwaiter.GetResult())
using (var alphaOut = alphaAwaiter.GetResult())
{
Unity.InferenceEngine.TextureConverter.RenderToTexture(foregroundTensor, _foregroundTexture);
Unity.InferenceEngine.TextureConverter.RenderToTexture(alphaTensor, _alphaTexture);
}
try
{
if(sketchRawImage != null)
{
sketchRawImage.material.SetTexture("_FgrTex", _foregroundTexture);
sketchRawImage.material.SetTexture("_PhaTex", _alphaTexture);
}
}
catch (System.Exception e)
{
Debug.LogError("NOTE: Please make sure the RawImage has a material using the VideoMatting shader. Exception: " + e.Message);
}
}
}
private RenderTexture GetOrCreateRenderTexture(ref RenderTexture renderTexture, int width, int height, string name, bool forceCreate = false)
{
if (renderTexture == null || renderTexture.width != width || renderTexture.height != height || forceCreate)
{
if (renderTexture != null)
{
renderTexture.Release();
DestroyImmediate(renderTexture);
}
renderTexture = new RenderTexture(width, height, 24, RenderTextureFormat.ARGB32);
renderTexture.name = name;
renderTexture.Create();
}
return renderTexture;
}
// | Resolution | Portrait | Full-Body |
// | ---------- | -------- | --------- |
// | <= 512x512 | 1        | 1         |
// | 1280x720   | 0.375    | 0.6       |
// | 1920x1080  | 0.25     | 0.4       |
// | 3840x2160  | 0.125    | 0.2       |
// The model downsamples its input 5 times internally; pick values so that
// none of those 5 halvings produces an odd dimension.
// Keep this in mind when changing the width/height.
private float CalculateOptimalDownsampleRatio(int width, int height)
{
int imagePixelCount = width * height;
if (imagePixelCount <= 512 * 512)
{
return 1.0f; // keep the original size
}
else if (imagePixelCount <= 1280 * 720)
{
return 0.6f;
}
else if (imagePixelCount <= 1920 * 1080)
{
return 0.4f;
}
else if (imagePixelCount <= 3840 * 2160)
{
return 0.2f;
}
else
{
return 0.1f;
}
}
void OnDestroy()
{
_r1?.Dispose();
_r2?.Dispose();
_r3?.Dispose();
_r4?.Dispose();
_inputTensor?.Dispose();
_downsampleRatioTensor?.Dispose();
_worker?.Dispose();
}
}
```
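One notable difference from the reference inference loop (reference1 in the comments above): the recurrent outputs `r1o`..`r4o` are never fed back, so `_r1`..`_r4` stay zero and the model's temporal memory is effectively disabled. A possible fix, as an untested sketch; it assumes Inference Engine's synchronous `ReadbackAndClone()` (the counterpart of the `ReadbackAndCloneAsync()` used above) and accepts one GPU-to-CPU round trip per frame:

```csharp
// Sketch: after _worker.Schedule(), feed the recurrent outputs back in as the
// next frame's r1i..r4i, as the RVM inference documentation describes.
void UpdateRecurrentStates()
{
    var names = new[] { "r1o", "r2o", "r3o", "r4o" };
    var states = new[] { _r1, _r2, _r3, _r4 };
    for (int i = 0; i < names.Length; i++)
    {
        if (_worker.PeekOutput(names[i]) is Unity.InferenceEngine.Tensor<float> output)
        {
            states[i]?.Dispose();                  // release last frame's state
            states[i] = output.ReadbackAndClone(); // CPU copy we own across frames
        }
    }
    _r1 = states[0]; _r2 = states[1]; _r3 = states[2]; _r4 = states[3];
}
```

Holding the states on the CPU is the simplest option; keeping them GPU-side would avoid the round trip but needs a different copy mechanism.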
Detour repository: https://github.com/xue-fei/RobustVideoMatting-unity.git
Project repository
https://github.com/xue-fei/robustvideomatting-unity-sentis.git
Forked from
https://github.com/realmorm/unity-sentis-robustvideomatting-example.git