Using NovaSR in Unity to upsample muffled 16 kHz audio into clear 48 kHz audio

NovaSR: Pushing the Limits of Extreme Efficiency in Audio Super-Resolution

Original project: https://github.com/ysharma3501/NovaSR.git

NovaSR is a tiny 50 kB audio upsampling model that upscales muffled 16 kHz audio into clear, crisp 48 kHz audio at over 3500x realtime.

Key benefits

Speed: reaches up to 3600x realtime on a single A100 GPU.

Quality: On par with models 5,000x larger.

Size: just 52 kB, several thousand times smaller than most models.

Why is this even useful?

Enhancing models: NovaSR can considerably improve TTS output quality at nearly zero computational cost.

Real-time enhancement: NovaSR enables on-device enhancement of low-quality calls, audio, and so on, while using almost no memory.

Restoring datasets: NovaSR can enhance the audio quality of any audio dataset.

Comparisons

Comparisons were run on an A100 GPU. A higher realtime factor means faster processing.
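The realtime factor is the number of seconds of audio processed per second of wall-clock time. At 3600x realtime, one hour of audio is processed in about one second.

The sketch below shows the arithmetic in Python. The function names are illustrative and are not part of NovaSR. The second helper converts a sample count and an elapsed time in milliseconds into a realtime factor, assuming 16 kHz input. The Unity component later in this post records exactly these two values in `lastInputLength` and `lastInferenceTime`.

```python
def realtime_factor(audio_seconds: float, processing_seconds: float) -> float:
    """Seconds of audio processed per second of wall-clock time."""
    return audio_seconds / processing_seconds


def realtime_factor_from_log(num_samples: int, elapsed_ms: float, sample_rate: int = 16000) -> float:
    """Realtime factor from a sample count and an elapsed time in milliseconds."""
    return realtime_factor(num_samples / sample_rate, elapsed_ms / 1000.0)


# 3600x realtime: one hour (3600 s) of audio processed in 1 s
print(realtime_factor(3600.0, 1.0))  # 3600.0
# 10 s of 16 kHz audio (160000 samples) processed in 50 ms: roughly 200x
print(realtime_factor_from_log(160000, 50.0))
```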

Converting to an ONNX model

python export_onnx.py --checkpoint pytorch_model_v1.bin --output pytorch_model_v1.onnx --opset 17 --input-length 200
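With `--input-length 200`, the export script passes a [1, 1, 200] dummy input through the model and prints the output shape. You can predict that shape by tracing the sequence length through each layer with the standard Conv1d length formula.

The sketch below does this. It assumes the default configuration that the script detects: a single AMPBlock0 with kernel size 11, where the convolutions actually used have dilation 1. The kernel sizes and paddings are hard-coded from the script's code.

```python
def conv1d_len(n: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.Conv1d (formula from the PyTorch Conv1d documentation)."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def activation1d_len(n: int) -> int:
    """Activation1d = UpSample1d then DownSample1d, with the sizes used in the export script."""
    up = 2 * conv1d_len(n + 5, 6) - 4               # pad (2, 3), 6-tap conv, interleave x2, crop [2:-2]
    return conv1d_len(up, 12, stride=2, padding=5)  # 12-tap low-pass, stride 2, padding 5


def novasr_output_len(t: int) -> int:
    """Trace an input of t samples through Generator (one AMPBlock0, kernel size 11)."""
    n = conv1d_len(t, 7, padding=3)     # conv_pre
    n = 3 * n                           # F.interpolate(scale_factor=3.0)
    xt = activation1d_len(n)            # a1
    xt = conv1d_len(xt, 11, padding=5)  # c1 (dilation 1)
    xt = activation1d_len(xt)           # a2
    xt = conv1d_len(xt, 11, padding=5)  # c2
    n = min(n, xt)                      # residual: x.narrow(2, 0, xt.shape[2])
    return conv1d_len(n, 7, padding=3)  # conv_post


print(novasr_output_len(200))    # 596
print(novasr_output_len(16000))  # 47996
```

A T-sample input therefore yields 3T - 4 output samples. Each Activation1d drops two samples: the `[2:-2]` crop after upsampling and the stride-2 low-pass together lose them. The residual branch is then narrowed to match.

This is why the upsampling ratio that the Unity component logs falls slightly below 3.00 for short inputs.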
export_onnx.py (Python):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm
import argparse
import warnings
import onnx
import math

warnings.filterwarnings("ignore", message=".*weight_norm is deprecated.*")

# ==============================================================================
# 🔴 Key problem: the original version used register_buffer + a dynamic repeat,
# which made ONNX export treat the weights as inputs instead of constants
# ==============================================================================

def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """计算 Kaiser-Sinc 滤波器"""
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.: 
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.: 
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else: 
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
    time = (torch.arange(-half_size, half_size) + 0.5) if even else (torch.arange(kernel_size) - half_size)
    filter_ = 2 * cutoff * window * torch.sinc(2 * cutoff * time)
    filter_ /= filter_.sum()
    return filter_.view(1, 1, kernel_size)

# ==============================================================================
# ✅ Fix 1: UpSample1d - expand the weights at construction time instead of a dynamic repeat
# ==============================================================================

class UpSample1d(nn.Module):
    """Fixed version: all weights are expanded in __init__ so ONNX sees them as constants"""
    def __init__(self, ratio=2, kernel_size=12, channels=512):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = kernel_size
        self.channels = channels
        
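        # Compute the Kaiser-sinc filter
        cutoff = 0.5 / ratio
        half_width = 0.5 / ratio
        filter_ = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)

        # Polyphase decomposition (the even/odd split assumes ratio == 2)
        # Note: nn.Conv1d computes cross-correlation, so the polyphase form of a
        # stride-2 transposed convolution needs each phase's taps reversed. For this
        # symmetric filter, reversing p0 gives p1, so [p0, p1] below is the phase
        # order swapped relative to conv_transpose1d; check it against the original
        # NovaSR forward pass.
        w = filter_.view(kernel_size) * ratio
        p0, p1 = w[0::2], w[1::2]
        weight = torch.stack([p0, p1], dim=0).unsqueeze(1)  # [ratio, 1, taps]

        # 🔴 Problematic version: weight.repeat(channels, 1, 1)
        # ran inside forward, so ONNX treated it as a dynamic op

        # ✅ Fix: expand to the full Conv1d weight once, in __init__
        weight = weight.repeat(channels, 1, 1)  # [C*ratio, 1, taps]

        # Create the Conv1d layer (weights already fully expanded)
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels * ratio,
            kernel_size=weight.shape[2],
            stride=1,
            padding=0,
            groups=channels,
            bias=False
        )

        # Copy the weights into the Conv1d Parameter (not a buffer!);
        # parameters are treated as model weights and exported correctly to ONNX
        self.conv.weight.data.copy_(weight)
        self.conv.weight.requires_grad = False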

    def forward(self, x):
        B, C, T = x.shape
        
        # Fixed padding
        x = F.pad(x, (2, 3), mode='constant', value=0.0)

        # Fixed Conv1d (weights are constant)
        out = self.conv(x)

        # Reshape: interleave the two polyphase outputs along time
        out = out.view(B, C, self.ratio, -1)
        out = out.transpose(2, 3).reshape(B, C, -1)

        # Fixed crop
        out = out[..., 2:-2]
        return out

# ==============================================================================
# ✅ Fix 2: LowPassFilter1d - same treatment
# ==============================================================================

class LowPassFilter1d(nn.Module):
    """Fixed version: weights are fully expanded in __init__"""
    def __init__(self, stride=2, kernel_size=12, channels=512):
        super().__init__()
        self.stride = stride
        self.channels = channels
        
        # Compute the filter
        cutoff = 0.5 / stride
        half_width = 0.5 / stride
        filter_ = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)

        # 🔴 Problematic version: repeat inside forward
        # ✅ Fix: repeat once in __init__
        filter_ = filter_.repeat(channels, 1, 1)

        # Create the depthwise Conv1d
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=5,
            groups=channels,
            bias=False
        )

        # Weights stored as a Parameter
        self.conv.weight.data.copy_(filter_)
        self.conv.weight.requires_grad = False

    def forward(self, x):
        return self.conv(x)

class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=12, channels=512):
        super().__init__()
        self.lowpass = LowPassFilter1d(stride=ratio, kernel_size=kernel_size, channels=channels)
    
    def forward(self, x):
        return self.lowpass(x)

# ==============================================================================
# SnakeBeta activation function
# ==============================================================================

class SnakeBeta(nn.Module):
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        init_val = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        
        self.alpha = nn.Parameter(init_val * alpha)
        self.beta = nn.Parameter(init_val * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        if self.alpha_logscale:
            a = torch.exp(self.alpha)
            b = torch.exp(self.beta)
        else:
            a = self.alpha
            b = self.beta
        
        a = a.view(1, -1, 1)
        b = b.view(1, -1, 1)
        eps = 1e-9
        # Same as SnakeBeta x + sin^2(a*x) / b, using sin^2(t) = (1 - cos(2t)) / 2
        return x + (1.0 - torch.cos(2.0 * a * x)) / (2.0 * b + eps)

# ==============================================================================
# Activation1d
# ==============================================================================

class Activation1d(nn.Module):
    def __init__(self, activation, up_ratio=2, down_ratio=2, 
                 up_kernel_size=12, down_kernel_size=12, channels=512):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        
        self.upsample = UpSample1d(up_ratio, up_kernel_size, channels)
        self.downsample = DownSample1d(down_ratio, down_kernel_size, channels)

    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x

# ==============================================================================
# AMPBlock0
# ==============================================================================

def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2

class AMPBlock0(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super().__init__()
        
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, 
                                  dilation=dilation[0],
                                  padding=get_padding(kernel_size, dilation[0])))
        ])
        
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, 
                                  dilation=1,
                                  padding=get_padding(kernel_size, 1)))
        ])
        
        self.num_layers = len(self.convs1) + len(self.convs2)
        
        self.activations = nn.ModuleList([
            Activation1d(
                activation=SnakeBeta(channels, alpha_logscale=True),
                channels=channels
            ) for _ in range(self.num_layers)
        ])

    def forward(self, x):
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, 
                                   self.activations[::2], self.activations[1::2]):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x_residual = x.narrow(2, 0, xt.shape[2])
            x = xt + x_residual
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

# ==============================================================================
# Generator
# ==============================================================================

class Generator(nn.Module):
    def __init__(self, initial_channel, resblock, resblock_kernel_sizes,
                 resblock_dilation_sizes, upsample_initial_channel, gin_channels=0):
        super().__init__()
        
        self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        
        self.resblocks = nn.ModuleList()
        for i in range(1):
            ch = upsample_initial_channel // (2 ** i)
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AMPBlock0(ch, k, d, activation="snakebeta"))
        
        self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)

    def forward(self, x):
        x = self.conv_pre(x)
        x = F.interpolate(x, scale_factor=3.0, mode='linear', align_corners=False)
        xs = self.resblocks[0](x)
        x = self.conv_post(xs)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        for l in self.resblocks:
            l.remove_weight_norm()

# ==============================================================================
# SynthesizerTrn
# ==============================================================================

class SynthesizerTrn(nn.Module):
    def __init__(self, spec_channels, segment_size, resblock,
                 resblock_kernel_sizes, resblock_dilation_sizes, 
                 upsample_initial_channel):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        
        self.dec = Generator(
            1, resblock, resblock_kernel_sizes,
            resblock_dilation_sizes, upsample_initial_channel
        )

    def forward(self, x):
        return self.dec(x)

# ==============================================================================
# Remaining functions (unchanged)
# ==============================================================================

def detect_config(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt)
    
    config = {
        'upsample_initial_channel': None,
        'resblock_kernel_sizes': [11],
        'resblock_dilation_sizes': [[1, 3, 5]]
    }
    
    for key in state_dict:
        if 'dec.conv_pre.weight' in key:
            config['upsample_initial_channel'] = state_dict[key].shape[0]
            print(f"✅ Detected upsample_initial_channel = {config['upsample_initial_channel']}")
            break

    if config['upsample_initial_channel'] is None:
        raise ValueError("❌ Could not detect upsample_initial_channel")
    
    return config

def convert_state_dict(state_dict, model_channels):
    """转换权重,将 buffer 权重转为 Conv1d Parameter"""
    new_state_dict = {}
    
    for k, v in state_dict.items():
        new_key = k.replace('module.', '')
        
        if 'upsample.filter' in new_key:
            ratio = 2
            kernel_size = v.shape[2]
            w = v.view(kernel_size) * ratio
            p0, p1 = w[0::2], w[1::2]
            weight = torch.stack([p0, p1], dim=0).unsqueeze(1)
            weight = weight.repeat(model_channels, 1, 1)
            new_state_dict[new_key.replace('filter', 'conv.weight')] = weight
            print(f"✅ 转换 upsample.filter -> upsample.conv.weight")
            
        elif 'downsample.lowpass.filter' in new_key:
            expanded = v.repeat(model_channels, 1, 1)
            new_state_dict[new_key.replace('filter', 'conv.weight')] = expanded
            print(f"✅ 转换 downsample.lowpass.filter -> downsample.lowpass.conv.weight")
            
        else:
            new_state_dict[new_key] = v
    
    return new_state_dict

def export_to_onnx(checkpoint_path, output_path, opset_version=16, input_length=100):
    print(f"🔍 Loading checkpoint: {checkpoint_path}")
    
    config = detect_config(checkpoint_path)
    model_channels = config['upsample_initial_channel']
    
    model = SynthesizerTrn(
        spec_channels=128,
        segment_size=30,
        resblock="amp",
        resblock_kernel_sizes=config['resblock_kernel_sizes'],
        resblock_dilation_sizes=config['resblock_dilation_sizes'],
        upsample_initial_channel=model_channels
    )
    
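    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt)

    print("🔄 Converting state_dict...")
    new_state_dict = convert_state_dict(state_dict, model_channels)

    # If the checkpoint stores weight_norm parameters (weight_g / weight_v), they must be
    # loaded *before* remove_weight_norm(); otherwise strict=False silently skips them
    # and the resblock convolutions keep their random initialization.
    has_weight_norm = any(k.endswith('weight_g') for k in new_state_dict)
    if not has_weight_norm:
        print("🔧 Removing weight_norm...")
        model.dec.remove_weight_norm()

    print("📦 Loading state_dict...")
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    if missing:
        print(f"⚠️  Missing keys: {missing}")
    if unexpected:
        print(f"⚠️  Unexpected keys: {unexpected}")

    if has_weight_norm:
        print("🔧 Removing weight_norm...")
        model.dec.remove_weight_norm()

    model.eval()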
    
    print(f"📤 Exporting to ONNX (opset={opset_version})...")
    dummy_input = torch.randn(1, 1, input_length, dtype=torch.float32)

    with torch.no_grad():
        print("🧪 Testing PyTorch inference...")
        pt_output = model(dummy_input)
        print(f"   PyTorch output: {pt_output.shape}")
        
        torch.onnx.export(
            model,
            dummy_input,
            output_path,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch', 2: 'time'},
                'output': {0: 'batch', 2: 'time'}
            },
            verbose=False
        )
    
    print(f"✅ ONNX saved: {output_path}")

    try:
        loaded_model = onnx.load(output_path)
        onnx.checker.check_model(loaded_model)
        print(f"✅ ONNX check passed!")
    except Exception as e:
        print(f"⚠️  Check failed: {e}")

def main():
    parser = argparse.ArgumentParser(description='Export NovaSR to ONNX')
    parser.add_argument('--checkpoint', required=True, help='checkpoint path')
    parser.add_argument('--output', default='novasr_model.onnx', help='output path')
    parser.add_argument('--opset', type=int, default=16, help='ONNX opset version')
    parser.add_argument('--input-length', type=int, default=100, help='dummy input length (samples)')
    args = parser.parse_args()

    try:
        export_to_onnx(
            checkpoint_path=args.checkpoint,
            output_path=args.output,
            opset_version=args.opset,
            input_length=args.input_length
        )
        print("\n🎉 Export finished!")
    except Exception as e:
        print(f"\n❌ Export failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    main()
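UpSample1d above replaces a stride-2 transposed convolution with two Conv1d phases. Because nn.Conv1d computes cross-correlation rather than convolution, each phase's taps must be reversed for the two forms to agree. With a symmetric filter such as kaiser_sinc_filter1d, leaving the taps unreversed is equivalent to swapping the two output samples in every pair.

The NumPy sketch below checks this. It makes two assumptions:
- The reference is a transposed convolution, as in BigVGAN's alias-free UpSample1d, which uses F.conv_transpose1d.
- A Hann-windowed sinc can stand in for the Kaiser-sinc filter, since only the filter's symmetry matters here.

If NovaSR's own forward pass already uses the unreversed [p0, p1] layout, the checkpoint was trained with it and the port is faithful as written.

```python
import numpy as np


def windowed_sinc(kernel_size: int = 12, cutoff: float = 0.25) -> np.ndarray:
    """Symmetric even-length low-pass filter (Hann-windowed sinc), a stand-in for
    kaiser_sinc_filter1d: only its symmetry matters for this check."""
    t = np.arange(kernel_size) - (kernel_size - 1) / 2  # -5.5 ... 5.5 for 12 taps
    h = np.hanning(kernel_size) * np.sinc(2 * cutoff * t)
    return h / h.sum()


def upsample_reference(x: np.ndarray, h: np.ndarray, ratio: int = 2) -> np.ndarray:
    """Transposed-convolution upsampling: insert zeros between samples, then convolve."""
    z = np.zeros(len(x) * ratio)
    z[::ratio] = x
    return ratio * np.convolve(z, h)  # full convolution, i.e. conv_transpose1d without cropping


def upsample_polyphase_conv1d(x: np.ndarray, h: np.ndarray, reverse_taps: bool, ratio: int = 2) -> np.ndarray:
    """Polyphase upsampling built from cross-correlations, as nn.Conv1d computes them.

    Phase r uses taps h[r::ratio]; reverse_taps=True reverses each phase, which turns the
    cross-correlation back into the convolution that the reference performs."""
    taps = len(h) // ratio
    xp = np.pad(x, (taps - 1, taps - 1))  # zero padding so every overlap is computed
    out = np.empty((len(x) + taps - 1) * ratio)
    for r in range(ratio):
        w = h[r::ratio][::-1] if reverse_taps else h[r::ratio]
        out[r::ratio] = np.correlate(xp, w, mode="valid")
    return ratio * out


rng = np.random.default_rng(0)
x = rng.standard_normal(64)
h = windowed_sinc()

ref = upsample_reference(x, h)
fixed = upsample_polyphase_conv1d(x, h, reverse_taps=True)
unreversed = upsample_polyphase_conv1d(x, h, reverse_taps=False)

n = len(fixed)
print("reversed taps match reference:", np.allclose(fixed, ref[:n]))        # True
print("unreversed taps match reference:", np.allclose(unreversed, ref[:n]))  # False
```

If this check points to the phase order, the corresponding change in both UpSample1d and convert_state_dict would be to stack the reversed phases. For a symmetric filter that is equivalent to `[p1, p0]`. This change is untested against the NovaSR checkpoint.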

Implementation in Unity

OnnxModelInference.cs (C#):
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

/// <summary>
/// NovaSR ONNX model inference component (fixed version)
/// </summary>
public class OnnxModelInference : MonoBehaviour
{
    [Header("模型设置")]
    [Tooltip("ONNX模型文件路径(相对于StreamingAssets文件夹)")]
    public string modelFileName = "pytorch_model_v2.onnx";

    [Tooltip("是否在Start时自动加载模型")]
    public bool autoLoadOnStart = true;

    [Tooltip("是否使用GPU加速")]
    public bool useGpu = false;

    [Header("音频设置")]
    [Tooltip("输入采样率(应为16000 Hz)")]
    public int inputSampleRate = 16000;

    [Tooltip("输出采样率(应为48000 Hz)")]
    public int outputSampleRate = 48000;

    [Header("性能设置")]
    [Tooltip("推理线程数")]
    public int inferenceThreads = 4;

    [Tooltip("是否启用模型优化")]
    public bool enableOptimization = true;

    [Header("调试信息")]
    public bool showDebugInfo = true;
    public float lastInferenceTime = 0f;
    public int lastInputLength = 0;
    public int lastOutputLength = 0;

    private InferenceSession _session;
    private string _inputName;
    private string _outputName;
    private bool _isModelLoaded = false;

    public bool IsModelLoaded => _isModelLoaded;

    void Start()
    {
        if (autoLoadOnStart)
        {
            LoadModel();
        }
    }

    public bool LoadModel()
    {
        try
        {
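            // Note: on Android, StreamingAssets is packed inside the APK, so System.IO
            // cannot open files there directly
            string modelPath = System.IO.Path.Combine(Application.streamingAssetsPath, modelFileName);
            if (!System.IO.File.Exists(modelPath))
            {
                Debug.LogError($"❌ Model file not found: {modelPath}");
                return false;
            }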

            var sessionOptions = new SessionOptions();

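            // GPU acceleration (needs an ONNX Runtime build with the CUDA execution provider)
            if (useGpu)
            {
                try
                {
                    sessionOptions.AppendExecutionProvider_CUDA(0);
                    Debug.Log("✅ Using GPU acceleration");
                }
                catch (Exception ex)
                {
                    Debug.LogWarning($"⚠️ GPU unavailable, falling back to CPU: {ex.Message}");
                }
            }

            // Performance options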
            sessionOptions.InterOpNumThreads = 1;
            sessionOptions.IntraOpNumThreads = inferenceThreads;

            if (enableOptimization)
            {
                sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
                sessionOptions.ExecutionMode = ExecutionMode.ORT_PARALLEL;
            }

            _session = new InferenceSession(modelPath, sessionOptions);
            _inputName = _session.InputMetadata.Keys.First();
            _outputName = _session.OutputMetadata.Keys.First();

            // Validate the model: expect an input of shape [batch, 1, time]
            var inputMeta = _session.InputMetadata[_inputName];
            var outputMeta = _session.OutputMetadata[_outputName];

            if (inputMeta.Dimensions.Length != 3 || inputMeta.Dimensions[1] != 1)
            {
                Debug.LogError($"❌ Model input shape mismatch! Expected [batch, 1, time], got [{string.Join(", ", inputMeta.Dimensions)}]");
                _session.Dispose();
                _session = null;
                return false;
            }

            _isModelLoaded = true;

            if (showDebugInfo)
            {
                Debug.Log($"✅ Model loaded: {modelFileName}");
                Debug.Log($"   Input: {_inputName} [{string.Join(", ", inputMeta.Dimensions)}]");
                Debug.Log($"   Output: {_outputName} [{string.Join(", ", outputMeta.Dimensions)}]");
                Debug.Log($"   Expected sample rates: {inputSampleRate} Hz → {outputSampleRate} Hz");
            }

            return true;
        }
        catch (Exception ex)
        {
            Debug.LogError($"❌ Failed to load model: {ex.Message}\n{ex.StackTrace}");
            return false;
        }
    }

    /// <summary>
    /// Run inference synchronously (core method)
    /// </summary>
    public float[] Infer(float[] inputData)
    {
        if (!_isModelLoaded)
        {
            Debug.LogError("❌ Model not loaded!");
            return null;
        }

        if (inputData == null || inputData.Length == 0)
        {
            Debug.LogError("❌ Input data is empty!");
            return null;
        }

        try
        {
            var startTime = Time.realtimeSinceStartup;

            // ✅ Fix 1: don't clamp! AudioClip data is already in [-1, 1]
            // ❌ Removed this code:
            // for (int i = 0; i < inputData.Length; i++)
            // {
            //     inputData[i] = Mathf.Clamp(inputData[i], -1f, 1f);
            // }

            // ✅ Fix 2: make sure the input has no NaN or Inf
            // (invalid samples are replaced in place, so the caller's array is modified)
            bool hasInvalidData = false;
            for (int i = 0; i < inputData.Length; i++)
            {
                if (float.IsNaN(inputData[i]) || float.IsInfinity(inputData[i]))
                {
                    inputData[i] = 0f;
                    hasInvalidData = true;
                }
            }

            if (hasInvalidData && showDebugInfo)
            {
                Debug.LogWarning("⚠️ Input contained NaN/Inf; replaced with 0");
            }

            // Create the input tensor with shape [1, 1, T]
            var inputDimensions = new[] { 1, 1, inputData.Length };
            var inputTensor = new DenseTensor<float>(inputData, inputDimensions);
            var inputs = new[] { NamedOnnxValue.CreateFromTensor(_inputName, inputTensor) };

            // Run inference
            float[] output;
            using (var results = _session.Run(inputs))
            {
                output = results.First().AsEnumerable<float>().ToArray();
            }

            lastInferenceTime = (Time.realtimeSinceStartup - startTime) * 1000f;
            lastInputLength = inputData.Length;
            lastOutputLength = output.Length;

            if (showDebugInfo)
            {
                float inputRMS = CalculateRMS(inputData);
                float outputRMS = CalculateRMS(output);
                Debug.Log($"✅ Inference finished:\n" +
                         $"   Input: {inputData.Length} samples, RMS={inputRMS:F4}\n" +
                         $"   Output: {output.Length} samples, RMS={outputRMS:F4}\n" +
                         $"   Upsampling ratio: {(float)output.Length / inputData.Length:F2}x\n" +
                         $"   Elapsed: {lastInferenceTime:F2}ms");
            }

            return output;
        }
        catch (Exception ex)
        {
            Debug.LogError($"❌ Inference failed: {ex.Message}\n{ex.StackTrace}");
            return null;
        }
    }

    /// <summary>
    /// Run inference on an AudioClip (improved version)
    /// </summary>
    public float[] InferFromAudioClip(AudioClip audioClip, bool resampleTo16k = false)
    {
        if (audioClip == null)
        {
            Debug.LogError("❌ AudioClip is null!");
            return null;
        }

        // Get the raw sample data (interleaved if multi-channel)
        float[] samples = new float[audioClip.samples * audioClip.channels];
        audioClip.GetData(samples, 0);

        // ✅ Fix 3: downmix stereo to mono correctly
        if (audioClip.channels == 2)
        {
            float[] mono = new float[audioClip.samples];
            for (int i = 0; i < audioClip.samples; i++)
            {
                // Standard stereo-to-mono downmix: average the left and right samples
                mono[i] = (samples[i * 2] + samples[i * 2 + 1]) * 0.5f;
            }
            samples = mono;
        }

        // ✅ Fix 4: check the sample rate
        if (audioClip.frequency != inputSampleRate)
        {
            if (resampleTo16k)
            {
                Debug.LogWarning($"⚠️ AudioClip sample rate is {audioClip.frequency} Hz; resampling to {inputSampleRate} Hz");
                samples = SimpleResample(samples, audioClip.frequency, inputSampleRate);
            }
            else
            {
                Debug.LogError($"❌ AudioClip sample rate mismatch! Expected {inputSampleRate} Hz, got {audioClip.frequency} Hz\n" +
                              $"Set resampleTo16k=true or use audio at the correct sample rate");
                return null;
            }
        }

        return Infer(samples);
    }

    /// <summary>
    /// Create the output AudioClip
    /// </summary>
    public AudioClip CreateOutputAudioClip(float[] outputData, string clipName = "NovaSR_Output")
    {
        if (outputData == null || outputData.Length == 0)
        {
            Debug.LogError("❌ Output data is empty!");
            return null;
        }

        // ✅ Fix 5: make sure the output is within [-1, 1]
        float maxAbs = 0f;
        for (int i = 0; i < outputData.Length; i++)
        {
            float abs = Mathf.Abs(outputData[i]);
            if (abs > maxAbs) maxAbs = abs;
        }

        // If it is out of range, normalize by the peak (modifies outputData in place)
        if (maxAbs > 1f)
        {
            Debug.LogWarning($"⚠️ Output out of range (max={maxAbs:F3}); normalizing");
            for (int i = 0; i < outputData.Length; i++)
            {
                outputData[i] /= maxAbs;
            }
        }

        AudioClip clip = AudioClip.Create(clipName, outputData.Length, 1, outputSampleRate, false);
        clip.SetData(outputData, 0);
        return clip;
    }

    /// <summary>
    /// Complete audio super-resolution pipeline
    /// </summary>
    public AudioClip ProcessAudio(AudioClip inputClip, string outputName = "Enhanced_Audio")
    {
        if (!_isModelLoaded)
        {
            Debug.LogError("❌ Model not loaded!");
            return null;
        }

        // Run inference on the AudioClip (resampled to 16 kHz if needed)
        float[] outputData = InferFromAudioClip(inputClip, resampleTo16k: true);
        if (outputData == null) return null;

        // Create the output AudioClip
        return CreateOutputAudioClip(outputData, outputName);
    }

    /// <summary>
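    /// Batch inference (runs each input separately; failed inputs are skipped,
    /// so the result list may be shorter than inputBatch)
    /// </summary>
    public List<float[]> InferBatch(List<float[]> inputBatch)
    {
        if (!_isModelLoaded)
        {
            Debug.LogError("❌ Model not loaded!");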
            return null;
        }

        var results = new List<float[]>();
        foreach (var input in inputBatch)
        {
            var output = Infer(input);
            if (output != null)
            {
                results.Add(output);
            }
        }
        return results;
    }

    // ========== Helper methods ==========

    /// <summary>
    /// Compute the RMS (root mean square) for debugging
    /// </summary>
    private float CalculateRMS(float[] samples)
    {
        if (samples == null || samples.Length == 0) return 0f;

        double sum = 0;
        for (int i = 0; i < samples.Length; i++)
        {
            sum += samples[i] * samples[i];
        }
        return (float)Math.Sqrt(sum / samples.Length);
    }

    /// <summary>
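    /// Simple linear resampling (sample-rate conversion only).
    /// No anti-aliasing low-pass filter is applied, so when downsampling (e.g. 48 kHz to 16 kHz)
    /// content above the new Nyquist frequency folds back into the output.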
    /// </summary>
    private float[] SimpleResample(float[] input, int fromRate, int toRate)
    {
        if (fromRate == toRate) return input;

        double ratio = (double)fromRate / toRate;
        int outputLength = (int)(input.Length / ratio);
        float[] output = new float[outputLength];

        for (int i = 0; i < outputLength; i++)
        {
            double srcIndex = i * ratio;
            int idx1 = (int)srcIndex;
            int idx2 = Math.Min(idx1 + 1, input.Length - 1);
            float frac = (float)(srcIndex - idx1);

            // Linear interpolation
            output[i] = input[idx1] * (1f - frac) + input[idx2] * frac;
        }

        return output;
    }

    /// <summary>
    /// Sanity-check the model output
    /// </summary>
    public bool ValidateModelOutput(AudioClip testClip)
    {
        Debug.Log("🔍 Starting model validation...");

        float[] output = InferFromAudioClip(testClip, resampleTo16k: true);
        if (output == null)
        {
            Debug.LogError("❌ Validation failed: inference returned null");
            return false;
        }

        // Check the output for NaN/Inf, silence and extreme values
        float rms = CalculateRMS(output);
        float maxAbs = 0f;
        int nanCount = 0;

        for (int i = 0; i < output.Length; i++)
        {
            if (float.IsNaN(output[i]) || float.IsInfinity(output[i]))
            {
                nanCount++;
            }
            float abs = Mathf.Abs(output[i]);
            if (abs > maxAbs) maxAbs = abs;
        }

        Debug.Log($"📊 Validation results:\n" +
                 $"   Output length: {output.Length}\n" +
                 $"   RMS: {rms:F4}\n" +
                 $"   Max abs value: {maxAbs:F4}\n" +
                 $"   NaN/Inf count: {nanCount}");

        bool isValid = nanCount == 0 && rms > 0.001f && maxAbs < 100f;
        Debug.Log(isValid ? "✅ Model validation passed" : "❌ Model validation failed");
        return isValid;
    }

    void OnDestroy()
    {
        if (_session != null)
        {
            _session.Dispose();
            _session = null;
        }
        _isModelLoaded = false;
    }

    void OnApplicationQuit() => OnDestroy();
}
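SimpleResample above interpolates linearly and applies no anti-aliasing low-pass filter beforehand. Take a 48 kHz clip brought down to 16 kHz: the ratio is exactly 3, so the conversion reduces to keeping every third sample. Anything above the new 8 kHz Nyquist limit then folds back into the band that the model sees; for example, a 10 kHz tone reappears at 6 kHz.

Whether this contributes to the quality gap noted at the end of this post is untested. The Python sketch below only demonstrates the effect. As the swapped-in fix it uses a Blackman-windowed sinc low-pass followed by integer decimation. That choice is illustrative and works only for integer ratios.

```python
import numpy as np


def linear_resample(x: np.ndarray, from_rate: int, to_rate: int) -> np.ndarray:
    """Python mirror of SimpleResample: linear interpolation, no low-pass filter."""
    ratio = from_rate / to_rate
    n_out = int(len(x) / ratio)
    src = np.arange(n_out) * ratio
    idx1 = src.astype(int)
    idx2 = np.minimum(idx1 + 1, len(x) - 1)
    frac = src - idx1
    return x[idx1] * (1 - frac) + x[idx2] * frac


def lowpass_then_decimate(x: np.ndarray, factor: int, num_taps: int = 101) -> np.ndarray:
    """Blackman-windowed sinc low-pass at 0.9x the new Nyquist, then keep every factor-th sample."""
    cutoff = 0.9 * 0.5 / factor  # in cycles per input sample
    t = np.arange(num_taps) - (num_taps - 1) / 2
    h = np.blackman(num_taps) * np.sinc(2 * cutoff * t)
    h /= h.sum()  # unity gain at DC
    return np.convolve(x, h, mode="same")[::factor]


def tone_amplitude(y: np.ndarray, freq: float, rate: int) -> float:
    """Amplitude of a sinusoid at freq, measured by projecting onto a complex exponential."""
    n = np.arange(len(y))
    return 2 * abs(np.sum(y * np.exp(-2j * np.pi * freq * n / rate))) / len(y)


rate_in, rate_out = 48000, 16000
t = np.arange(rate_in // 2) / rate_in  # 0.5 s at 48 kHz
x = np.sin(2 * np.pi * 10000 * t)      # 10 kHz tone, above the 8 kHz target Nyquist

naive = linear_resample(x, rate_in, rate_out)
filtered = lowpass_then_decimate(x, rate_in // rate_out)

seg = slice(2000, 3600)  # 0.1 s window away from the filter's edge transients
print("alias at 6 kHz, linear:  ", tone_amplitude(naive[seg], 6000, rate_out))     # ~1.0
print("alias at 6 kHz, filtered:", tone_amplitude(filtered[seg], 6000, rate_out))  # close to 0
```

For a non-integer ratio, such as 44.1 kHz to 16 kHz, this simple filter-and-decimate approach does not apply directly.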
OnnxInferenceExample.cs (usage example):
using System.Collections;
using System.IO;
using UnityEngine;

/// <summary>
/// Example MonoBehaviour that runs OnnxModelInference on a test clip
/// </summary>
public class OnnxInferenceExample : MonoBehaviour
{
    public OnnxModelInference modelInference;
    public AudioClip testAudioClip;

    void Start()
    {
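        // Loom is not defined in this listing and is not used below;
        // remove this call if your project does not include it
        Loom.Initialize();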

        if (modelInference == null)
        {
            modelInference = GetComponent<OnnxModelInference>();
        }

        // Wait for the model to load, then run the example
        StartCoroutine(RunExamples());
    }

    IEnumerator RunExamples()
    {
        // Wait until the model has loaded
        while (!modelInference.IsModelLoaded)
        {
            yield return new WaitForSeconds(0.1f);
        }
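
        if (testAudioClip != null)
        {
            // Infer() does not check the sample rate: testAudioClip must already be 16 kHz
            // (InferFromAudioClip performs that check and can resample)
            float[] audioOutput = GetSamples(testAudioClip);
            audioOutput = modelInference.Infer(audioOutput);
            if (audioOutput != null)
            {
                Debug.Log($"Example 3 finished: audio inference output length={audioOutput.Length}");
                SaveAudioAsWav(audioOutput, 48000, Application.dataPath + "/after.wav");
            }
            else
            {
                Debug.LogError("audioOutput == null");
            }
        }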
    }

    float[] GetSamples(AudioClip audioClip)
    {
        if (audioClip == null)
        {
            Debug.LogError("AudioClip is null!");
            return null;
        }

        // Extract the sample data (interleaved if multi-channel)
        float[] samples = new float[audioClip.samples * audioClip.channels];
        audioClip.GetData(samples, 0);

        // If stereo, downmix to mono
        if (audioClip.channels == 2)
        {
            float[] mono = new float[audioClip.samples];
            for (int i = 0; i < audioClip.samples; i++)
            {
                mono[i] = (samples[i * 2] + samples[i * 2 + 1]) / 2f;
            }
            samples = mono;
        }
        return samples;
    }

    /// <summary>
    /// Save float[] audio data as a mono 16-bit PCM WAV file
    /// </summary>
    /// <param name="samples">Mono float audio data (range -1.0 to 1.0)</param>
    /// <param name="sampleRate">Sample rate (e.g. 22050, 44100)</param>
    /// <param name="filePath">Output path (e.g. "output.wav")</param>
    public void SaveAudioAsWav(float[] samples, int sampleRate, string filePath)
    {
        if (samples == null || samples.Length == 0)
        {
            Debug.LogError("Audio data is empty!");
            return;
        }

        // Clamp values to [-1, 1]
        float[] clampedSamples = new float[samples.Length];
        for (int i = 0; i < samples.Length; i++)
        {
            clampedSamples[i] = Mathf.Clamp(samples[i], -1f, 1f);
        }

        // Convert to 16-bit PCM
        short[] pcm = new short[clampedSamples.Length];
        for (int i = 0; i < clampedSamples.Length; i++)
        {
            pcm[i] = (short)(clampedSamples[i] * 32767); // 32767 = Int16.MaxValue
        }

        // Write the WAV file (BinaryWriter writes little-endian, as WAV requires)
        using (FileStream fs = new FileStream(filePath, FileMode.Create))
        using (BinaryWriter writer = new BinaryWriter(fs))
        {
            // RIFF header
            writer.Write(System.Text.Encoding.ASCII.GetBytes("RIFF"));
            writer.Write(36 + pcm.Length * 2); // ChunkSize
            writer.Write(System.Text.Encoding.ASCII.GetBytes("WAVE"));

            // fmt subchunk
            writer.Write(System.Text.Encoding.ASCII.GetBytes("fmt "));
            writer.Write(16); // Subchunk1Size
            writer.Write((short)1); // AudioFormat (1 = PCM)
            writer.Write((short)1); // NumChannels (1 = mono)
            writer.Write(sampleRate); // SampleRate
            writer.Write(sampleRate * 2); // ByteRate
            writer.Write((short)2); // BlockAlign
            writer.Write((short)16); // BitsPerSample

            // data subchunk
            writer.Write(System.Text.Encoding.ASCII.GetBytes("data"));
            writer.Write(pcm.Length * 2); // Subchunk2Size
            foreach (short sample in pcm)
            {
                writer.Write(sample);
            }
        }

        Debug.Log($"✅ Audio saved to: {filePath}");
    }
}
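SaveAudioAsWav writes the canonical 44-byte header for mono 16-bit PCM by hand. The Python sketch below builds the same layout with struct. It then reads the result back with the standard-library wave module as a cross-check. The name wav_bytes is illustrative and is not part of the project.

```python
import io
import struct
import wave


def wav_bytes(samples, sample_rate: int) -> bytes:
    """Mono 16-bit PCM WAV in memory, with the same header fields as SaveAudioAsWav."""
    # Clamp to [-1, 1], scale by 32767 and truncate toward zero, like (short)(x * 32767) in C#
    pcm = [int(max(-1.0, min(1.0, s)) * 32767) for s in samples]
    data = struct.pack(f"<{len(pcm)}h", *pcm)
    channels, bits = 1, 16
    block_align = channels * bits // 8     # 2 bytes per frame
    byte_rate = sample_rate * block_align  # sampleRate * 2 in the C# code
    header = b"RIFF" + struct.pack("<I", 36 + len(data)) + b"WAVE"
    header += b"fmt " + struct.pack("<IHHIIHH", 16, 1, channels, sample_rate, byte_rate, block_align, bits)
    header += b"data" + struct.pack("<I", len(data))
    return header + data


buf = wav_bytes([0.0, 0.5, -0.5, 1.5], 48000)
with wave.open(io.BytesIO(buf), "rb") as w:
    print(w.getnchannels(), w.getsampwidth(), w.getframerate(), w.getnframes())  # 1 2 48000 4
    print(struct.unpack("<4h", w.readframes(4)))  # (0, 16383, -16383, 32767)
```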

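Finally, on quality: compared with the output of the original implementation, the Unity result is slightly worse, although the muffled sound is much reduced. I haven't found the cause yet...

before1 is the original sample 1, after1 is the original implementation's output for sample 1, and after is the output produced in Unity.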
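One way to narrow down where the gap comes from is to run the same 16 kHz input through three pipelines: the original PyTorch model, the exported ONNX model, and the Unity build. Save each output and compare them sample by sample. Wherever the outputs first diverge points to the stage responsible.

Below is a minimal NumPy sketch for the comparison. It assumes the outputs are already loaded as float arrays, for example from the saved WAV files. The small lag search allows for the pipelines being offset by a few samples, which is an assumption rather than something measured here.

```python
import numpy as np


def best_lag(ref: np.ndarray, cand: np.ndarray, max_lag: int = 64) -> int:
    """Lag (cand index = ref index + lag) with the lowest mean squared error over the overlap."""
    best, best_err = 0, np.inf
    for lag in range(-max_lag, max_lag + 1):
        r = ref[max(0, -lag):]
        c = cand[max(0, lag):]
        n = min(len(r), len(c))
        if n == 0:
            continue
        err = np.mean((r[:n] - c[:n]) ** 2)
        if err < best_err:
            best, best_err = lag, err
    return best


def compare(ref, cand, max_lag: int = 64):
    """Align cand to ref, then return (lag, max absolute error, SNR in dB)."""
    ref = np.asarray(ref, dtype=np.float64)
    cand = np.asarray(cand, dtype=np.float64)
    lag = best_lag(ref, cand, max_lag)
    r = ref[max(0, -lag):]
    c = cand[max(0, lag):]
    n = min(len(r), len(c))
    diff = c[:n] - r[:n]
    noise = np.sum(diff ** 2)
    snr = np.inf if noise == 0 else 10 * np.log10(np.sum(r[:n] ** 2) / noise)
    return lag, float(np.max(np.abs(diff))), float(snr)


rng = np.random.default_rng(0)
ref = rng.standard_normal(4800)
cand = np.concatenate([np.zeros(3), 1.01 * ref])  # delayed by 3 samples, 1% gain error
print(compare(ref, cand))  # lag 3, SNR about 40 dB
```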

Project repository:
https://github.com/xue-fei/novasr-unity.git
