BS-RoFormer,目前音频分离SOTA

留下阅读 (2023)Music Source Separation with Band-Split RoPE Transformer 和(2023)Mel-Band RoFormer for Music Source Separation 的痕迹。

论文提出的模型针对音乐分离任务(Music source separation, MSS),用于分离音频中的元素,例如将音乐分离成伴奏和人声。

1. 总览

mvsep 的 leaderboard 来看,BS-RoFormer 和 Mel-Band RoFormer 性能领先,是目前的 SOTA(20250819)。两个模型都是字节跳动做出来的,Mel-Band RoFormer 是 BS-RoFormer 的变体。

BS-RoFormer 已经达成了 SOTA,但其频段分割方案没有理论支撑,是凭经验确定的。这个 Mel-Band RoFormer 就是来补足这个缺陷,根据 mel scale 设计出符合人类耳朵的频段分割方案。

数据集使用 MUSDB18HQ,验证方法有效性。

2. BS-RoFormer

2.1. 实现思路与损失函数

设 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 代表输入音频, <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 是其复频域形式。

期望 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 经过含有可学习参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的网络层 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ f_\theta </math>fθ 预测出 cIRMs,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> M ^ = f θ ( X ) \hat{M}=f_\theta(X) </math>M^=fθ(X)

借由这个 cIRMs 进行音频分离,最后就可以用逆短时傅里叶变换 iSTFT 获得预测结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Y ^ = M ^ ⊙ X \hat{Y}=\hat{M}\odot X </math>Y^=M^⊙X

什么是 cIRM?

理想比值掩蔽(Ideal Ratio Mask, IRM),可以理解为频谱的 mask。原频谱乘上 IRM 就能分离出想要的元素。

与之对应的是理想二值掩蔽(Ideal Binary Mask, IBM),这个 mask 就不是连续取值,而是只能取 0 和 1。

模型使用复数频谱,那就需要 complex IRMs,简称 cIRMs。

一句话总结:模型生成一个 mask,音频在复频域乘上这个 mask,获得推理结果。

关于损失函数,先看公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> loss = ∥ y − y ^ ∥ + ∑ s = 0 S − 1 ∥ Y ( s ) − Y ^ ( s ) ∥ \text{loss}= \Vert y-\hat{y} \Vert + \sum^{S-1}_{s=0}\Vert Y^{(s)}-\hat{Y}^{(s)} \Vert </math>loss=∥y−y^∥+s=0∑S−1∥Y(s)−Y^(s)∥

第一项很好理解,时域上的目标 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 与预测结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 的 MAE。

第二项是因为模型用到了多个不同窗口大小的 STFT,所以在频域上也计算 MAE。

模型的 STFT 窗口大小有 [4096, 2048, 1024, 512, 256],步长 147,等价于每秒 300 帧。

2.2. 模型流程

输入时域音频 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R B × C × L x\in\mathbb{R}^{B\times C\times L} </math>x∈RB×C×L。 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B 是 batch, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C 是通道数, <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 是采样数。

首先,对 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 进行分频与映射。

  1. 输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 并 STFT 转换到复频域 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ C B × C × T × F X\in\mathbb{C}^{B\times C\times T\times F} </math>X∈CB×C×T×F。 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 是帧数, <math xmlns="http://www.w3.org/1998/Math/MathML"> F F </math>F 是频段数
  2. 对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X,分为 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 个不均匀不重叠的子频段数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> X n ∈ C B × C × T × F n X_n\in\mathbb{C}^{B\times C\times T\times F_n} </math>Xn∈CB×C×T×Fn
  3. 每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> X n X_n </math>Xn 经由一个由 RMSNorm 和 Linear 构成的 MLP,将 <math xmlns="http://www.w3.org/1998/Math/MathML"> C × F n C\times F_n </math>C×Fn 维度映射为 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D,获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n ∈ C B × T × D H_n\in\mathbb{C}^{B\times T\times D} </math>Hn∈CB×T×D
  4. 将所有 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n H_n </math>Hn 进行拼接,获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> H ∈ C B × T × N × D H\in\mathbb{C}^{B\times T\times N\times D} </math>H∈CB×T×N×D

记这个 <math xmlns="http://www.w3.org/1998/Math/MathML"> H H </math>H 为 <math xmlns="http://www.w3.org/1998/Math/MathML"> H l H^l </math>Hl( <math xmlns="http://www.w3.org/1998/Math/MathML"> l ∈ [ 1 , ... , L ] l\in[1, \dots, L] </math>l∈[1,...,L]),然后在时间和频率两个维度分别使用 Transformer。

  1. <math xmlns="http://www.w3.org/1998/Math/MathML"> H l H^l </math>Hl 变换维度到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B × N ) × T × D (B\times N)\times T\times D </math>(B×N)×T×D
  2. 施加 Transformer
  3. <math xmlns="http://www.w3.org/1998/Math/MathML"> H l H^l </math>Hl 变换维度到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B × T ) × N × D (B\times T)\times N\times D </math>(B×T)×N×D
  4. 施加 Transformer
  5. 输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> H l + 1 ∈ C B × T × N × D H^{l+1}\in\mathbb{C}^{B\times T\times N\times D} </math>Hl+1∈CB×T×N×D

最终获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> H L H^L </math>HL,像开头对分频的处理一样进行映射,获得 cIRM。

  1. <math xmlns="http://www.w3.org/1998/Math/MathML"> H L H^L </math>HL 拥有分频 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n L ∈ C B × T × D H^L_n\in \mathbb{C}^{B\times T\times D} </math>HnL∈CB×T×D
  2. 每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n L H^L_n </math>HnL 经由一个由 RMSNorm、Linear、Tanh、Linear 和 GLU 构成的 MLP,将 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 映射回 <math xmlns="http://www.w3.org/1998/Math/MathML"> C × F n C\times F_n </math>C×Fn
  3. 拼接分频,获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ^ ∈ C B × C × T × F \hat{M}\in \mathbb{C}^{B\times C\times T\times F} </math>M^∈CB×C×T×F

2.3. Transformer 设计

很标准的 Transformer 结构。

  • RMSNorm,采用 Pre-Norm 方案
  • 多头注意力机制
  • RoPE 位置编码
  • GELU 激活的 MLP

除此之外:

  • attention、attention 之后的 proj,以及 MLP 的上下投影,都有 dropout

2.4. 数据集与数据增强

Musdb18HQ 数据集有 100 首歌的训练集和 50 首歌的验证集,采样率 44.1kHz,每首歌分轨 vocal bass drum other 四轨。论文额外使用了自己独家的数据进行训练,有 450 首歌的训练集和 50 首歌的验证集。

评估指标为 SDR(signal-to-distortion ratio),其值越大越好。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> SDR ( y , y ^ ) = 10 log ⁡ 10 ∥ y ∥ 2 ∥ y ^ − y ∥ 2 \text{SDR}(y, \hat{y})=10\log_{10}\frac{\Vert y\Vert^2}{\Vert \hat{y}-y\Vert^2} </math>SDR(y,y^)=10log10∥y^−y∥2∥y∥2

SDR 有个变体叫做 SI-SDR(scale-invariant SDR),会在幅度上做归一化以避免 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 音量变化导致分数急剧增大。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> SI-SDR ( y , y ^ ) = 10 log ⁡ 10 ∥ α y ∥ 2 ∥ y ^ − α y ∥ 2 , α = ⟨ y ^ , y ⟩ ∥ y ∥ 2 \text{SI-SDR}(y, \hat{y})=10\log_{10}\frac{\Vert \alpha y\Vert^2}{\Vert \hat{y}-\alpha y\Vert^2}, \alpha=\frac{\langle\hat{y}, y\rangle}{\Vert y\Vert^2} </math>SI-SDR(y,y^)=10log10∥y^−αy∥2∥αy∥2,α=∥y∥2⟨y^,y⟩

训练使用 8 秒的分段,每次取出所有 4 个分轨。各个分轨施加的增强技巧包含:

  • ±3 dB 的随机音量增益
  • 10% 概率替换为空波形
  • 分轨可能来自不同歌曲、同一歌曲的不同部分, <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 将会是分轨的线性和结果
  • 选取的分段音量必须大于 -50dB

2.5. 模型超参与其他细节

STFT 使用长度为 2048 的 Hann 窗口,间距 10ms。

分频总共分为了 64 段:

  • 1000Hz 以下分为 2 段
  • 1000-2000Hz 分为 4 段
  • 2000-4000Hz 分为 12 段
  • 4000-8000Hz 分为 24 段
  • 8000-16000Hz 分为 48 段
  • 16000Hz 以上分为 2 段

关于分频策略,Mel-Band RoFormer 提出了基于 mel-scale 的分段方案。见后文。

Transformer 层数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 12 L=12 </math>L=12,8 个头,dropout=0.1,

模型超参:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 12 L=12 </math>L=12,Transformer 层数
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> D = 384 D=384 </math>D=384,隐空间维度
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> dropout = 0.1 \text{dropout}=0.1 </math>dropout=0.1
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> MLP Scale = 4 \text{MLP Scale}=4 </math>MLP Scale=4

此时模型参数量为 93.4M。

训练细节:

  • 优化器 AdamW
  • 使用 EMA(Exponential Moving Averagin),decay 设置为 0.999
  • 学习率 <math xmlns="http://www.w3.org/1998/Math/MathML"> 5 × 1 0 − 4 5\times 10^{-4} </math>5×10−4,每 4 万步衰减到原来的 0.9 倍
  • 混合精度训练
  • STFT 操作保持 FP32
  • 每张 GPU 的 batch size 为 128
  • 16×A100 训练 4 周

推理时 8 秒分段,4 秒一间隔,相邻重合的 4 秒推理结果取平均。

3. Mel-Band RoFormer

从心理学上讲,人类对低频敏感对高频迟钝,换句话说就是低频部分需要更高的分辨率才更合适。BS-RoFormer 已经从此受益颇多,本文的 Mel-Band RoFormer 将在这方面做得更科学、更好。

论文使用 librosa 的 librosa.filters.mel 方法获得梅尔分频矩阵。

这个矩阵大概长这样:

text 复制代码
[0.      0.00098 0.      0.      0.      0.      0.      0.      0.     ]
[0.      0.00025 0.00041 0.      0.      0.      0.      0.      0.     ]
[0.      0.      0.00024 0.00036 0.00012 0.      0.      0.      0.     ]
[0.      0.      0.      0.00006 0.0002  0.00023 0.00015 0.00008 0.     ]

每一行都是一个梅尔频谱滤波器。左低频右高频。

可见,滤波器都呈刺状,且低频的滤波器能量集中,高频的滤波器能量分散。

Mel-Band RoFormer 选择把这个矩阵二值化,转换为 mask。

4. 碎碎念

论文是两年前的了,模型放到现在依旧能打。很难不怀疑字节内部已经有更好的模型不想也没必要放出来,闷声发大财。恐怖如斯。

训练花了这么多资源和时间,看来我自己是没办法完全复现了。

这个方案简单直观,好评。但这么说要分 4 个轨道就需要训练 4 个不同的模型?那就不太优雅了。之后查阅代码实现时再看是怎么个情况。

看来纯频域的音频处理就能获得很好的效果,不需要像 HDemucs 那样时域频域混合,那问题简单了不少。需要注意 STFT 时保留虚部从而保留相位信息。

5. 参考来源

相关推荐
音视频牛哥14 分钟前
从H.264到AV1:音视频技术演进与模块化SDK架构全解析
人工智能·音视频·大牛直播sdk·rtsp h.265·h.264 h.265 av1·h.265和h.266·enhenced rtmp
AIbase202424 分钟前
如何快速找到最适合的AI绘画工具?避免在200+工具中挑花眼?
人工智能
机器之心1 小时前
DeepSeek开源新基础模型,但不是V4,而是V3.1-Base
人工智能·openai
金融小师妹1 小时前
AI多因子模型解析:黄金涨势受阻与美联储9月降息政策预期重构
大数据·人工智能·算法
R-G-B1 小时前
【P38 6】OpenCV Python——图片的运算(算术运算、逻辑运算)加法add、subtract减法、乘法multiply、除法divide
人工智能·python·opencv·图片的运算·图片加法add·图片subtract减法·图片乘法multiply
拖拖7651 小时前
解读《Thyme: Think Beyond Images》——让大模型“写代码”思考图像
人工智能
双向331 小时前
模型量化大揭秘:INT8、INT4量化对推理速度和精度的影响测试
人工智能
lisuwen1162 小时前
GPT-5 上线风波深度复盘:从口碑两极到策略调整,OpenAI 的变与不变
大数据·人工智能·gpt·chatgpt
硅谷秋水2 小时前
在相机空间中落地动作:以观察为中心的视觉-语言-行动策略
机器学习·计算机视觉·语言模型·机器人
新智元2 小时前
16 岁天才少年炒掉马斯克,空降华尔街巨头!9 岁上大学,14 岁进 SpaceX
人工智能·openai