NovaSR: Pushing the Limits of Extreme Efficiency in Audio Super-Resolution
Original project: https://github.com/ysharma3501/NovaSR.git
NovaSR is a tiny ~50 kB audio upsampling model that upscales muffled 16 kHz audio into clear, crisp 48 kHz audio at over 3500x realtime.
Key benefits
Speed: reaches 3600x realtime on a single A100 GPU.
Quality: On par with models 5,000x larger.
Size: just 52 kB, several thousand times smaller than most comparable models.
Why is this even useful?
Enhancing models: NovaSR can considerably improve TTS output quality at nearly zero computational cost.
Real-time enhancement: NovaSR enables on-device enhancement of low-quality calls and other audio while using almost no memory.
Restoring datasets: NovaSR can enhance the audio quality of any audio dataset.
Comparisons
Comparisons were run on an A100 GPU. A higher realtime factor means faster processing.
Converting the model to ONNX
```bash
python export_onnx.py --checkpoint pytorch_model_v1.bin --output pytorch_model_v1.onnx --opset 17 --input-length 200
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm
import argparse
import warnings
import onnx
import math

warnings.filterwarnings("ignore", message=".*weight_norm is deprecated.*")

# ==============================================================================
# 🔴 Key problem: the original version uses register_buffer + a dynamic repeat,
# which makes ONNX export treat the weights as inputs rather than constants.
# ==============================================================================
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Compute a Kaiser-windowed sinc filter."""
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
    time = (torch.arange(-half_size, half_size) + 0.5) if even else (torch.arange(kernel_size) - half_size)
    filter_ = 2 * cutoff * window * torch.sinc(2 * cutoff * time)
    filter_ /= filter_.sum()
    return filter_.view(1, 1, kernel_size)

# ==============================================================================
# ✅ Fix 1: UpSample1d — expand the weights at construction time instead of
# repeating them dynamically in forward
# ==============================================================================
class UpSample1d(nn.Module):
    """Fixed version: all weights are expanded in __init__ so ONNX sees them as constants."""
    def __init__(self, ratio=2, kernel_size=12, channels=512):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = kernel_size
        self.channels = channels
        # Compute the Kaiser filter
        cutoff = 0.5 / ratio
        half_width = 0.5 / ratio
        filter_ = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        # Polyphase decomposition
        w = filter_.view(kernel_size) * ratio
        p0, p1 = w[0::2], w[1::2]
        weight = torch.stack([p0, p1], dim=0).unsqueeze(1)  # [ratio, 1, taps]
        # 🔴 Problematic version: weight.repeat(channels, 1, 1) executed inside
        # forward, which ONNX treats as a dynamic operation.
        # ✅ Fix: expand the weight into the full Conv1d weight here in __init__.
        weight = weight.repeat(channels, 1, 1)  # [C*ratio, 1, taps]
        # Create the Conv1d layer (weights fully expanded at init time)
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels * ratio,
            kernel_size=weight.shape[2],
            stride=1,
            padding=0,
            groups=channels,
            bias=False
        )
        # Assign the weight directly to the Conv1d Parameter (not a buffer!).
        # Parameters are treated as model weights, so ONNX exports them correctly.
        self.conv.weight.data.copy_(weight)
        self.conv.weight.requires_grad = False

    def forward(self, x):
        B, C, T = x.shape
        # Fixed padding
        x = F.pad(x, (2, 3), mode='constant', value=0.0)
        # Fixed Conv1d (weights already baked in)
        out = self.conv(x)
        # Reshape: interleave the polyphase outputs
        out = out.view(B, C, self.ratio, -1)
        out = out.transpose(2, 3).reshape(B, C, -1)
        # Fixed crop
        out = out[..., 2:-2]
        return out

# ==============================================================================
# ✅ Fix 2: LowPassFilter1d — same treatment
# ==============================================================================
class LowPassFilter1d(nn.Module):
    """Fixed version: weights are fully expanded in __init__."""
    def __init__(self, stride=2, kernel_size=12, channels=512):
        super().__init__()
        self.stride = stride
        self.channels = channels
        # Compute the filter
        cutoff = 0.5 / stride
        half_width = 0.5 / stride
        filter_ = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        # 🔴 Problematic version: repeat in forward.
        # ✅ Fix: repeat here in __init__.
        filter_ = filter_.repeat(channels, 1, 1)
        # Create the Conv1d
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=5,
            groups=channels,
            bias=False
        )
        # Weights stored as a Parameter
        self.conv.weight.data.copy_(filter_)
        self.conv.weight.requires_grad = False

    def forward(self, x):
        return self.conv(x)

class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=12, channels=512):
        super().__init__()
        self.lowpass = LowPassFilter1d(stride=ratio, kernel_size=kernel_size, channels=channels)

    def forward(self, x):
        return self.lowpass(x)

# ==============================================================================
# SnakeBeta activation
# ==============================================================================
class SnakeBeta(nn.Module):
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        init_val = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = nn.Parameter(init_val * alpha)
        self.beta = nn.Parameter(init_val * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        if self.alpha_logscale:
            a = torch.exp(self.alpha)
            b = torch.exp(self.beta)
        else:
            a = self.alpha
            b = self.beta
        a = a.view(1, -1, 1)
        b = b.view(1, -1, 1)
        eps = 1e-9
        return x + (1.0 - torch.cos(2.0 * a * x)) / (2.0 * b + eps)

# ==============================================================================
# Activation1d
# ==============================================================================
class Activation1d(nn.Module):
    def __init__(self, activation, up_ratio=2, down_ratio=2,
                 up_kernel_size=12, down_kernel_size=12, channels=512):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size, channels)
        self.downsample = DownSample1d(down_ratio, down_kernel_size, channels)

    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x

# ==============================================================================
# AMPBlock0
# ==============================================================================
def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2

class AMPBlock0(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super().__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1,
                                  dilation=dilation[0],
                                  padding=get_padding(kernel_size, dilation[0])))
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1,
                                  dilation=1,
                                  padding=get_padding(kernel_size, 1)))
        ])
        self.num_layers = len(self.convs1) + len(self.convs2)
        self.activations = nn.ModuleList([
            Activation1d(
                activation=SnakeBeta(channels, alpha_logscale=True),
                channels=channels
            ) for _ in range(self.num_layers)
        ])

    def forward(self, x):
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2,
                                  self.activations[::2], self.activations[1::2]):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x_residual = x.narrow(2, 0, xt.shape[2])
            x = xt + x_residual
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

# ==============================================================================
# Generator
# ==============================================================================
class Generator(nn.Module):
    def __init__(self, initial_channel, resblock, resblock_kernel_sizes,
                 resblock_dilation_sizes, upsample_initial_channel, gin_channels=0):
        super().__init__()
        self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.resblocks = nn.ModuleList()
        for i in range(1):
            ch = upsample_initial_channel // (2 ** i)
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AMPBlock0(ch, k, d, activation="snakebeta"))
        self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)

    def forward(self, x):
        x = self.conv_pre(x)
        x = F.interpolate(x, scale_factor=3.0, mode='linear', align_corners=False)
        xs = self.resblocks[0](x)
        x = self.conv_post(xs)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        for l in self.resblocks:
            l.remove_weight_norm()

# ==============================================================================
# SynthesizerTrn
# ==============================================================================
class SynthesizerTrn(nn.Module):
    def __init__(self, spec_channels, segment_size, resblock,
                 resblock_kernel_sizes, resblock_dilation_sizes,
                 upsample_initial_channel):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.dec = Generator(
            1, resblock, resblock_kernel_sizes,
            resblock_dilation_sizes, upsample_initial_channel
        )

    def forward(self, x):
        return self.dec(x)

# ==============================================================================
# Remaining functions unchanged
# ==============================================================================
def detect_config(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt)
    config = {
        'upsample_initial_channel': None,
        'resblock_kernel_sizes': [11],
        'resblock_dilation_sizes': [[1, 3, 5]]
    }
    for key in state_dict:
        if 'dec.conv_pre.weight' in key:
            config['upsample_initial_channel'] = state_dict[key].shape[0]
            print(f"✅ Detected upsample_initial_channel = {config['upsample_initial_channel']}")
            break
    if config['upsample_initial_channel'] is None:
        raise ValueError("❌ Could not detect upsample_initial_channel")
    return config

def convert_state_dict(state_dict, model_channels):
    """Convert the weights: buffer filters become Conv1d Parameters."""
    new_state_dict = {}
    for k, v in state_dict.items():
        new_key = k.replace('module.', '')
        if 'upsample.filter' in new_key:
            ratio = 2
            kernel_size = v.shape[2]
            w = v.view(kernel_size) * ratio
            p0, p1 = w[0::2], w[1::2]
            weight = torch.stack([p0, p1], dim=0).unsqueeze(1)
            weight = weight.repeat(model_channels, 1, 1)
            new_state_dict[new_key.replace('filter', 'conv.weight')] = weight
            print(f"✅ Converted upsample.filter -> upsample.conv.weight")
        elif 'downsample.lowpass.filter' in new_key:
            expanded = v.repeat(model_channels, 1, 1)
            new_state_dict[new_key.replace('filter', 'conv.weight')] = expanded
            print(f"✅ Converted downsample.lowpass.filter -> downsample.lowpass.conv.weight")
        else:
            new_state_dict[new_key] = v
    return new_state_dict

def export_to_onnx(checkpoint_path, output_path, opset_version=16, input_length=100):
    print(f"🔍 Loading checkpoint: {checkpoint_path}")
    config = detect_config(checkpoint_path)
    model_channels = config['upsample_initial_channel']
    model = SynthesizerTrn(
        spec_channels=128,
        segment_size=30,
        resblock="amp",
        resblock_kernel_sizes=config['resblock_kernel_sizes'],
        resblock_dilation_sizes=config['resblock_dilation_sizes'],
        upsample_initial_channel=model_channels
    )
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt)
    print("🔧 Removing weight_norm...")
    model.dec.remove_weight_norm()
    print("🔄 Converting state_dict...")
    new_state_dict = convert_state_dict(state_dict, model_channels)
    print("📦 Loading state_dict...")
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    print(f"📤 Exporting to ONNX (opset={opset_version})...")
    dummy_input = torch.randn(1, 1, input_length, dtype=torch.float32)
    with torch.no_grad():
        print("🧪 Testing PyTorch inference...")
        pt_output = model(dummy_input)
        print(f"   PyTorch output: {pt_output.shape}")
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch', 2: 'time'},
            'output': {0: 'batch', 2: 'time'}
        },
        verbose=False
    )
    print(f"✅ ONNX saved: {output_path}")
    try:
        loaded_model = onnx.load(output_path)
        onnx.checker.check_model(loaded_model)
        print("✅ ONNX model check passed!")
    except Exception as e:
        print(f"⚠️ Validation failed: {e}")

def main():
    parser = argparse.ArgumentParser(description='Export NovaSR to ONNX')
    parser.add_argument('--checkpoint', required=True, help='path to the checkpoint')
    parser.add_argument('--output', default='novasr_model.onnx', help='output path')
    parser.add_argument('--opset', type=int, default=16, help='ONNX opset version')
    parser.add_argument('--input-length', type=int, default=100, help='dummy input length')
    args = parser.parse_args()
    try:
        export_to_onnx(
            checkpoint_path=args.checkpoint,
            output_path=args.output,
            opset_version=args.opset,
            input_length=args.input_length
        )
        print("\n🎉 Export complete!")
    except Exception as e:
        print(f"\n❌ Export failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    main()
```
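After exporting, a quick parity check in Python can confirm the graph loads and produces sensible shapes. A minimal sketch using onnxruntime (the `input`/`output` names and the `pytorch_model_v1.onnx` file come from the export command above; note the output is roughly 3x the input length, minus a few samples trimmed by the anti-aliasing filters):

```python
import numpy as np
import onnxruntime as ort

# Load the exported model on CPU.
sess = ort.InferenceSession("pytorch_model_v1.onnx", providers=["CPUExecutionProvider"])

# 0.2 s of dummy 16 kHz audio, shaped [batch, 1, time] as in the export script.
x = np.random.randn(1, 1, 3200).astype(np.float32)
y = sess.run(["output"], {"input": x})[0]

# Expect roughly 3x the input length (16 kHz -> 48 kHz); the Activation1d
# pad/crop trims a few samples, so it is not exactly 3x.
print(x.shape, "->", y.shape)
```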
Implementation in Unity
```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

/// <summary>
/// NovaSR ONNX model inference component (fixed version)
/// </summary>
public class OnnxModelInference : MonoBehaviour
{
    [Header("Model Settings")]
    [Tooltip("ONNX model file path (relative to the StreamingAssets folder)")]
    public string modelFileName = "pytorch_model_v2.onnx";
    [Tooltip("Load the model automatically in Start")]
    public bool autoLoadOnStart = true;
    [Tooltip("Use GPU acceleration")]
    public bool useGpu = false;

    [Header("Audio Settings")]
    [Tooltip("Input sample rate (should be 16000 Hz)")]
    public int inputSampleRate = 16000;
    [Tooltip("Output sample rate (should be 48000 Hz)")]
    public int outputSampleRate = 48000;

    [Header("Performance Settings")]
    [Tooltip("Number of inference threads")]
    public int inferenceThreads = 4;
    [Tooltip("Enable graph optimization")]
    public bool enableOptimization = true;

    [Header("Debug Info")]
    public bool showDebugInfo = true;
    public float lastInferenceTime = 0f;
    public int lastInputLength = 0;
    public int lastOutputLength = 0;

    private InferenceSession _session;
    private string _inputName;
    private string _outputName;
    private bool _isModelLoaded = false;

    public bool IsModelLoaded => _isModelLoaded;

    void Start()
    {
        if (autoLoadOnStart)
        {
            LoadModel();
        }
    }

    public bool LoadModel()
    {
        try
        {
            string modelPath = System.IO.Path.Combine(Application.streamingAssetsPath, modelFileName);
            if (!System.IO.File.Exists(modelPath))
            {
                Debug.LogError($"❌ Model file not found: {modelPath}");
                return false;
            }
            var sessionOptions = new SessionOptions();
            // GPU acceleration
            if (useGpu)
            {
                try
                {
                    sessionOptions.AppendExecutionProvider_CUDA(0);
                    Debug.Log("✅ Using GPU acceleration");
                }
                catch (Exception ex)
                {
                    Debug.LogWarning($"⚠️ GPU unavailable, falling back to CPU: {ex.Message}");
                }
            }
            // Performance tuning
            sessionOptions.InterOpNumThreads = 1;
            sessionOptions.IntraOpNumThreads = inferenceThreads;
            if (enableOptimization)
            {
                sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
                sessionOptions.ExecutionMode = ExecutionMode.ORT_PARALLEL;
            }
            _session = new InferenceSession(modelPath, sessionOptions);
            _inputName = _session.InputMetadata.Keys.First();
            _outputName = _session.OutputMetadata.Keys.First();
            // Validate the model
            var inputMeta = _session.InputMetadata[_inputName];
            var outputMeta = _session.OutputMetadata[_outputName];
            if (inputMeta.Dimensions.Length != 3 || inputMeta.Dimensions[1] != 1)
            {
                Debug.LogError($"❌ Model input shape mismatch! Expected [batch, 1, time], got [{string.Join(", ", inputMeta.Dimensions)}]");
                _session.Dispose();
                return false;
            }
            _isModelLoaded = true;
            if (showDebugInfo)
            {
                Debug.Log($"✅ Model loaded: {modelFileName}");
                Debug.Log($"   Input: {_inputName} [{string.Join(", ", inputMeta.Dimensions)}]");
                Debug.Log($"   Output: {_outputName} [{string.Join(", ", outputMeta.Dimensions)}]");
                Debug.Log($"   Expected sample rates: {inputSampleRate} Hz → {outputSampleRate} Hz");
            }
            return true;
        }
        catch (Exception ex)
        {
            Debug.LogError($"❌ Failed to load model: {ex.Message}\n{ex.StackTrace}");
            return false;
        }
    }

    /// <summary>
    /// Run inference (synchronous) — the core method
    /// </summary>
    public float[] Infer(float[] inputData)
    {
        if (!_isModelLoaded)
        {
            Debug.LogError("❌ Model not loaded!");
            return null;
        }
        if (inputData == null || inputData.Length == 0)
        {
            Debug.LogError("❌ Input data is empty!");
            return null;
        }
        try
        {
            var startTime = Time.realtimeSinceStartup;
            // ✅ Fix 1: do NOT clamp! AudioClip data is already in [-1, 1].
            // ❌ Removed:
            // for (int i = 0; i < inputData.Length; i++)
            // {
            //     inputData[i] = Mathf.Clamp(inputData[i], -1f, 1f);
            // }

            // ✅ Fix 2: make sure the input contains no NaN or Inf values
            bool hasInvalidData = false;
            for (int i = 0; i < inputData.Length; i++)
            {
                if (float.IsNaN(inputData[i]) || float.IsInfinity(inputData[i]))
                {
                    inputData[i] = 0f;
                    hasInvalidData = true;
                }
            }
            if (hasInvalidData && showDebugInfo)
            {
                Debug.LogWarning("⚠️ Input contained NaN/Inf values; replaced with 0");
            }
            // Create the input tensor
            var inputDimensions = new[] { 1, 1, inputData.Length };
            var inputTensor = new DenseTensor<float>(inputData, inputDimensions);
            var inputs = new[] { NamedOnnxValue.CreateFromTensor(_inputName, inputTensor) };
            // Run inference
            float[] output;
            using (var results = _session.Run(inputs))
            {
                output = results.First().AsEnumerable<float>().ToArray();
            }
            lastInferenceTime = (Time.realtimeSinceStartup - startTime) * 1000f;
            lastInputLength = inputData.Length;
            lastOutputLength = output.Length;
            if (showDebugInfo)
            {
                float inputRMS = CalculateRMS(inputData);
                float outputRMS = CalculateRMS(output);
                Debug.Log($"✅ Inference done:\n" +
                          $"   Input: {inputData.Length} samples, RMS={inputRMS:F4}\n" +
                          $"   Output: {output.Length} samples, RMS={outputRMS:F4}\n" +
                          $"   Upsampling ratio: {(float)output.Length / inputData.Length:F2}x\n" +
                          $"   Time: {lastInferenceTime:F2}ms");
            }
            return output;
        }
        catch (Exception ex)
        {
            Debug.LogError($"❌ Inference failed: {ex.Message}\n{ex.StackTrace}");
            return null;
        }
    }

    /// <summary>
    /// Run inference on an AudioClip (improved version)
    /// </summary>
    public float[] InferFromAudioClip(AudioClip audioClip, bool resampleTo16k = false)
    {
        if (audioClip == null)
        {
            Debug.LogError("❌ AudioClip is null!");
            return null;
        }
        // Fetch the raw samples
        float[] samples = new float[audioClip.samples * audioClip.channels];
        audioClip.GetData(samples, 0);
        // ✅ Fix 3: handle stereo correctly
        if (audioClip.channels == 2)
        {
            float[] mono = new float[audioClip.samples];
            for (int i = 0; i < audioClip.samples; i++)
            {
                // Standard stereo-to-mono downmix
                mono[i] = (samples[i * 2] + samples[i * 2 + 1]) * 0.5f;
            }
            samples = mono;
        }
        // ✅ Fix 4: check the sample rate
        if (audioClip.frequency != inputSampleRate)
        {
            if (resampleTo16k)
            {
                Debug.LogWarning($"⚠️ AudioClip sample rate is {audioClip.frequency} Hz; resampling to {inputSampleRate} Hz");
                samples = SimpleResample(samples, audioClip.frequency, inputSampleRate);
            }
            else
            {
                Debug.LogError($"❌ AudioClip sample rate mismatch! Expected {inputSampleRate} Hz, got {audioClip.frequency} Hz\n" +
                               $"Set resampleTo16k=true or supply audio at the correct sample rate");
                return null;
            }
        }
        return Infer(samples);
    }

    /// <summary>
    /// Create the output AudioClip
    /// </summary>
    public AudioClip CreateOutputAudioClip(float[] outputData, string clipName = "NovaSR_Output")
    {
        if (outputData == null || outputData.Length == 0)
        {
            Debug.LogError("❌ Output data is empty!");
            return null;
        }
        // ✅ Fix 5: make sure the output stays within [-1, 1]
        float maxAbs = 0f;
        for (int i = 0; i < outputData.Length; i++)
        {
            float abs = Mathf.Abs(outputData[i]);
            if (abs > maxAbs) maxAbs = abs;
        }
        // Normalize if out of range
        if (maxAbs > 1f)
        {
            Debug.LogWarning($"⚠️ Output out of range (max={maxAbs:F3}); normalizing");
            for (int i = 0; i < outputData.Length; i++)
            {
                outputData[i] /= maxAbs;
            }
        }
        AudioClip clip = AudioClip.Create(clipName, outputData.Length, 1, outputSampleRate, false);
        clip.SetData(outputData, 0);
        return clip;
    }

    /// <summary>
    /// Full audio super-resolution pipeline
    /// </summary>
    public AudioClip ProcessAudio(AudioClip inputClip, string outputName = "Enhanced_Audio")
    {
        if (!_isModelLoaded)
        {
            Debug.LogError("❌ Model not loaded!");
            return null;
        }
        // Run inference on the AudioClip
        float[] outputData = InferFromAudioClip(inputClip, resampleTo16k: true);
        if (outputData == null) return null;
        // Create the output AudioClip
        return CreateOutputAudioClip(outputData, outputName);
    }

    /// <summary>
    /// Batch inference
    /// </summary>
    public List<float[]> InferBatch(List<float[]> inputBatch)
    {
        if (!_isModelLoaded)
        {
            Debug.LogError("❌ Model not loaded!");
            return null;
        }
        var results = new List<float[]>();
        foreach (var input in inputBatch)
        {
            var output = Infer(input);
            if (output != null)
            {
                results.Add(output);
            }
        }
        return results;
    }

    // ========== Helper methods ==========

    /// <summary>
    /// Compute RMS (root mean square) for debugging
    /// </summary>
    private float CalculateRMS(float[] samples)
    {
        if (samples == null || samples.Length == 0) return 0f;
        double sum = 0;
        for (int i = 0; i < samples.Length; i++)
        {
            sum += samples[i] * samples[i];
        }
        return (float)Math.Sqrt(sum / samples.Length);
    }

    /// <summary>
    /// Simple linear resampling (only for sample-rate conversion)
    /// </summary>
    private float[] SimpleResample(float[] input, int fromRate, int toRate)
    {
        if (fromRate == toRate) return input;
        double ratio = (double)fromRate / toRate;
        int outputLength = (int)(input.Length / ratio);
        float[] output = new float[outputLength];
        for (int i = 0; i < outputLength; i++)
        {
            double srcIndex = i * ratio;
            int idx1 = (int)srcIndex;
            int idx2 = Math.Min(idx1 + 1, input.Length - 1);
            float frac = (float)(srcIndex - idx1);
            // Linear interpolation
            output[i] = input[idx1] * (1f - frac) + input[idx2] * frac;
        }
        return output;
    }

    /// <summary>
    /// Validate the model's output quality
    /// </summary>
    public bool ValidateModelOutput(AudioClip testClip)
    {
        Debug.Log("🔍 Starting model validation...");
        float[] output = InferFromAudioClip(testClip, resampleTo16k: true);
        if (output == null)
        {
            Debug.LogError("❌ Validation failed: inference returned null");
            return false;
        }
        // Check output quality
        float rms = CalculateRMS(output);
        float maxAbs = 0f;
        int nanCount = 0;
        for (int i = 0; i < output.Length; i++)
        {
            if (float.IsNaN(output[i]) || float.IsInfinity(output[i]))
            {
                nanCount++;
            }
            float abs = Mathf.Abs(output[i]);
            if (abs > maxAbs) maxAbs = abs;
        }
        Debug.Log($"📊 Validation results:\n" +
                  $"   Output length: {output.Length}\n" +
                  $"   RMS: {rms:F4}\n" +
                  $"   Max: {maxAbs:F4}\n" +
                  $"   Invalid values: {nanCount}");
        bool isValid = nanCount == 0 && rms > 0.001f && maxAbs < 100f;
        Debug.Log(isValid ? "✅ Model validation passed" : "❌ Model validation failed");
        return isValid;
    }

    void OnDestroy()
    {
        if (_session != null)
        {
            _session.Dispose();
            _session = null;
        }
        _isModelLoaded = false;
    }

    void OnApplicationQuit() => OnDestroy();
}
```
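Below is a minimal usage component: it waits for the model to load, runs inference on a test clip, and saves the 48 kHz result as a WAV file under Application.dataPath.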
```csharp
using System.Collections;
using System.IO;
using UnityEngine;

/// <summary>
/// Example usage MonoBehaviour
/// </summary>
public class OnnxInferenceExample : MonoBehaviour
{
    public OnnxModelInference modelInference;
    public AudioClip testAudioClip;

    void Start()
    {
        // Loom.Initialize(); // external main-thread helper from the original project; not needed for this example
        if (modelInference == null)
        {
            modelInference = GetComponent<OnnxModelInference>();
        }
        // Wait for the model to load
        StartCoroutine(RunExamples());
    }

    IEnumerator RunExamples()
    {
        // Wait for the model to load
        while (!modelInference.IsModelLoaded)
        {
            yield return new WaitForSeconds(0.1f);
        }
        if (testAudioClip != null)
        {
            float[] audioOutput = GetSamples(testAudioClip);
            audioOutput = modelInference.Infer(audioOutput);
            if (audioOutput != null)
            {
                Debug.Log($"Example 3 finished: audio inference output length={audioOutput.Length}");
                SaveAudioAsWav(audioOutput, 48000, Application.dataPath + "/after.wav");
            }
            else
            {
                Debug.LogError("audioOutput == null");
            }
        }
    }

    float[] GetSamples(AudioClip audioClip)
    {
        if (audioClip == null)
        {
            Debug.LogError("AudioClip is null!");
            return null;
        }
        // Extract the audio data
        float[] samples = new float[audioClip.samples * audioClip.channels];
        audioClip.GetData(samples, 0);
        // Convert stereo to mono if needed
        if (audioClip.channels == 2)
        {
            float[] mono = new float[audioClip.samples];
            for (int i = 0; i < audioClip.samples; i++)
            {
                mono[i] = (samples[i * 2] + samples[i * 2 + 1]) / 2f;
            }
            samples = mono;
        }
        return samples;
    }

    /// <summary>
    /// Save float[] audio data as a WAV file
    /// </summary>
    /// <param name="samples">Mono float audio data (range -1.0 to 1.0)</param>
    /// <param name="sampleRate">Sample rate (e.g. 22050, 44100)</param>
    /// <param name="filePath">Output path (e.g. "output.wav")</param>
    public void SaveAudioAsWav(float[] samples, int sampleRate, string filePath)
    {
        if (samples == null || samples.Length == 0)
        {
            Debug.LogError("Audio data is empty!");
            return;
        }
        // Clamp values to [-1, 1]
        float[] clampedSamples = new float[samples.Length];
        for (int i = 0; i < samples.Length; i++)
        {
            clampedSamples[i] = Mathf.Clamp(samples[i], -1f, 1f);
        }
        // Convert to 16-bit PCM
        short[] pcm = new short[clampedSamples.Length];
        for (int i = 0; i < clampedSamples.Length; i++)
        {
            pcm[i] = (short)(clampedSamples[i] * 32767); // 32767 = Int16.MaxValue
        }
        // Write the WAV file
        using (FileStream fs = new FileStream(filePath, FileMode.Create))
        using (BinaryWriter writer = new BinaryWriter(fs))
        {
            // RIFF header
            writer.Write(System.Text.Encoding.ASCII.GetBytes("RIFF"));
            writer.Write(36 + pcm.Length * 2); // ChunkSize
            writer.Write(System.Text.Encoding.ASCII.GetBytes("WAVE"));
            // fmt subchunk
            writer.Write(System.Text.Encoding.ASCII.GetBytes("fmt "));
            writer.Write(16); // Subchunk1Size
            writer.Write((short)1); // AudioFormat (1 = PCM)
            writer.Write((short)1); // NumChannels (1 = mono)
            writer.Write(sampleRate); // SampleRate
            writer.Write(sampleRate * 2); // ByteRate = SampleRate * BlockAlign
            writer.Write((short)2); // BlockAlign = NumChannels * BitsPerSample / 8
            writer.Write((short)16); // BitsPerSample
            // data subchunk
            writer.Write(System.Text.Encoding.ASCII.GetBytes("data"));
            writer.Write(pcm.Length * 2); // Subchunk2Size
            foreach (short sample in pcm)
            {
                writer.Write(sample);
            }
        }
        Debug.Log($"✅ Audio saved to: {filePath}");
    }
}
```
Finally, regarding quality: compared with the original pipeline's output, the Unity result is slightly worse, although the muffled quality is much reduced; the cause has not been found yet...
before1 is the original sample 1, after1 is the original pipeline's result 1, and after is the result processed in Unity.
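To narrow down where the Unity result diverges from the reference, a rough spectral A/B comparison of the two files can help. A sketch, assuming after1.wav and after.wav follow the naming above and both are mono 48 kHz:

```python
import numpy as np
from scipy.io import wavfile
from scipy.signal import stft

# Load both results; file names are assumptions based on the naming above.
sr_ref, ref = wavfile.read("after1.wav")   # reference pipeline output
sr_uni, uni = wavfile.read("after.wav")    # Unity output

def to_float(x):
    # 16-bit PCM -> float in [-1, 1]; float WAVs pass through unchanged.
    return x.astype(np.float32) / 32768.0 if x.dtype == np.int16 else x.astype(np.float32)

ref, uni = to_float(ref), to_float(uni)
n = min(len(ref), len(uni))

# Compare overall level and high-band energy; a deficit above ~8 kHz in the
# Unity file would point at the input path (e.g. the linear SimpleResample)
# rather than at the model itself.
for name, x, sr in (("reference", ref[:n], sr_ref), ("unity", uni[:n], sr_uni)):
    f, _, Z = stft(x, fs=sr, nperseg=1024)
    mag = np.abs(Z).mean(axis=1)
    hi = mag[f > 8000].sum() / (mag.sum() + 1e-9)
    print(f"{name}: RMS={np.sqrt(np.mean(x**2)):.4f}, energy>8kHz={hi:.4f}")
```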
