留下阅读 (2023)Music Source Separation with Band-Split RoPE Transformer 和(2023)Mel-Band RoFormer for Music Source Separation 的痕迹。
论文提出的模型针对音乐分离任务(Music source separation, MSS),用于分离音频中的元素,例如将音乐分离成伴奏和人声。
1. 总览
从 mvsep 的 leaderboard 来看,BS-RoFormer 和 Mel-Band RoFormer 性能领先,是目前的 SOTA(20250819)。两个模型都是字节跳动做出来的,Mel-Band RoFormer 是 BS-RoFormer 的变体。
BS-RoFormer 已经达成了 SOTA,但其频段分割方案没有理论支撑,是凭经验确定的。这个 Mel-Band RoFormer 就是来补足这个缺陷,根据 mel scale 设计出符合人类耳朵的频段分割方案。
数据集使用 MUSDB18HQ,验证方法有效性。
2. BS-RoFormer
2.1. 实现思路与损失函数
设 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 代表输入音频, <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 是其复频域形式。
期望 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 经过含有可学习参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的网络层 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ f_\theta </math>fθ 预测出 cIRMs,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> M ^ = f θ ( X ) \hat{M}=f_\theta(X) </math>M^=fθ(X)
借由这个 cIRMs 进行音频分离,最后就可以用逆短时傅里叶变换 iSTFT 获得预测结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Y ^ = M ^ ⊙ X \hat{Y}=\hat{M}\odot X </math>Y^=M^⊙X
什么是 cIRM?
理想比值掩蔽(Ideal Ratio Mask, IRM),可以理解为频谱的 mask。原频谱乘上 IRM 就能分离出想要的元素。
与之对应的是理想二值掩蔽(Ideal Binary Mask, IBM),这个 mask 就不是连续取值,而是只能取 0 和 1。
模型使用复数频谱,那就需要 complex IRMs,简称 cIRMs。
一句话总结:模型生成一个 mask,音频在复频域乘上这个 mask,获得推理结果。
关于损失函数,先看公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> loss = ∥ y − y ^ ∥ + ∑ s = 0 S − 1 ∥ Y ( s ) − Y ^ ( s ) ∥ \text{loss}= \Vert y-\hat{y} \Vert + \sum^{S-1}_{s=0}\Vert Y^{(s)}-\hat{Y}^{(s)} \Vert </math>loss=∥y−y^∥+s=0∑S−1∥Y(s)−Y^(s)∥
第一项很好理解,时域上的目标 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 与预测结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 的 MAE。
第二项是因为模型用到了多个不同窗口大小的 STFT,所以在频域上也计算 MAE。
模型的 STFT 窗口大小有 [4096, 2048, 1024, 512, 256],步长 147,等价于每秒 300 帧。
2.2. 模型流程
输入时域音频 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R B × C × L x\in\mathbb{R}^{B\times C\times L} </math>x∈RB×C×L。 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B 是 batch, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C 是通道数, <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 是采样数。
首先,对 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 进行分频与映射。
- 输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 并 STFT 转换到复频域 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ C B × C × T × F X\in\mathbb{C}^{B\times C\times T\times F} </math>X∈CB×C×T×F。 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 是帧数, <math xmlns="http://www.w3.org/1998/Math/MathML"> F F </math>F 是频段数
- 对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X,分为 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 个不均匀不重叠的子频段数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> X n ∈ C B × C × T × F n X_n\in\mathbb{C}^{B\times C\times T\times F_n} </math>Xn∈CB×C×T×Fn
- 每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> X n X_n </math>Xn 经由一个由 RMSNorm 和 Linear 构成的 MLP,将 <math xmlns="http://www.w3.org/1998/Math/MathML"> C × F n C\times F_n </math>C×Fn 维度映射为 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D,获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n ∈ C B × T × D H_n\in\mathbb{C}^{B\times T\times D} </math>Hn∈CB×T×D
- 将所有 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n H_n </math>Hn 进行拼接,获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> H ∈ C B × T × N × D H\in\mathbb{C}^{B\times T\times N\times D} </math>H∈CB×T×N×D
记这个 <math xmlns="http://www.w3.org/1998/Math/MathML"> H H </math>H 为 <math xmlns="http://www.w3.org/1998/Math/MathML"> H l H^l </math>Hl( <math xmlns="http://www.w3.org/1998/Math/MathML"> l ∈ [ 1 , ... , L ] l\in[1, \dots, L] </math>l∈[1,...,L]),然后在时间和频率两个维度分别使用 Transformer。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> H l H^l </math>Hl 变换维度到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B × N ) × T × D (B\times N)\times T\times D </math>(B×N)×T×D
- 施加 Transformer
- <math xmlns="http://www.w3.org/1998/Math/MathML"> H l H^l </math>Hl 变换维度到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B × T ) × N × D (B\times T)\times N\times D </math>(B×T)×N×D
- 施加 Transformer
- 输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> H l + 1 ∈ C B × T × N × D H^{l+1}\in\mathbb{C}^{B\times T\times N\times D} </math>Hl+1∈CB×T×N×D
最终获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> H L H^L </math>HL,像开头对分频的处理一样进行映射,获得 cIRM。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> H L H^L </math>HL 拥有分频 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n L ∈ C B × T × D H^L_n\in \mathbb{C}^{B\times T\times D} </math>HnL∈CB×T×D
- 每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> H n L H^L_n </math>HnL 经由一个由 RMSNorm、Linear、Tanh、Linear 和 GLU 构成的 MLP,将 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 映射回 <math xmlns="http://www.w3.org/1998/Math/MathML"> C × F n C\times F_n </math>C×Fn
- 拼接分频,获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> M ^ ∈ C B × C × T × F \hat{M}\in \mathbb{C}^{B\times C\times T\times F} </math>M^∈CB×C×T×F
2.3. Transformer 设计
很标准的 Transformer 结构。
- RMSNorm,采用 Pre-Norm 方案
- 多头注意力机制
- RoPE 位置编码
- GELU 激活的 MLP
除此之外:
- attention、attention 之后的 proj,以及 MLP 的上下投影,都有 dropout
2.4. 数据集与数据增强
Musdb18HQ 数据集有 100 首歌的训练集和 50 首歌的验证集,采样率 44.1kHz,每首歌分轨 vocal bass drum other 四轨。论文额外使用了自己独家的数据进行训练,有 450 首歌的训练集和 50 首歌的验证集。
评估指标为 SDR(signal-to-distortion ratio),其值越大越好。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> SDR ( y , y ^ ) = 10 log 10 ∥ y ∥ 2 ∥ y ^ − y ∥ 2 \text{SDR}(y, \hat{y})=10\log_{10}\frac{\Vert y\Vert^2}{\Vert \hat{y}-y\Vert^2} </math>SDR(y,y^)=10log10∥y^−y∥2∥y∥2
SDR 有个变体叫做 SI-SDR(scale-invariant SDR),会在幅度上做归一化以避免 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^ 音量变化导致分数急剧增大。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> SI-SDR ( y , y ^ ) = 10 log 10 ∥ α y ∥ 2 ∥ y ^ − α y ∥ 2 , α = ⟨ y ^ , y ⟩ ∥ y ∥ 2 \text{SI-SDR}(y, \hat{y})=10\log_{10}\frac{\Vert \alpha y\Vert^2}{\Vert \hat{y}-\alpha y\Vert^2}, \alpha=\frac{\langle\hat{y}, y\rangle}{\Vert y\Vert^2} </math>SI-SDR(y,y^)=10log10∥y^−αy∥2∥αy∥2,α=∥y∥2⟨y^,y⟩
训练使用 8 秒的分段,每次取出所有 4 个分轨。各个分轨施加的增强技巧包含:
- ±3 dB 的随机音量增益
- 10% 概率替换为空波形
- 分轨可能来自不同歌曲、同一歌曲的不同部分, <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 将会是分轨的线性和结果
- 选取的分段音量必须大于 -50dB
2.5. 模型超参与其他细节
STFT 使用长度为 2048 的 Hann 窗口,间距 10ms。
分频总共分为了 64 段:
- 1000Hz 以下分为 2 段
- 1000-2000Hz 分为 4 段
- 2000-4000Hz 分为 12 段
- 4000-8000Hz 分为 24 段
- 8000-16000Hz 分为 48 段
- 16000Hz 以上分为 2 段
关于分频策略,Mel-Band RoFormer 提出了基于 mel-scale 的分段方案。见后文。
Transformer 层数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 12 L=12 </math>L=12,8 个头,dropout=0.1,
模型超参:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 12 L=12 </math>L=12,Transformer 层数
- <math xmlns="http://www.w3.org/1998/Math/MathML"> D = 384 D=384 </math>D=384,隐空间维度
- <math xmlns="http://www.w3.org/1998/Math/MathML"> dropout = 0.1 \text{dropout}=0.1 </math>dropout=0.1
- <math xmlns="http://www.w3.org/1998/Math/MathML"> MLP Scale = 4 \text{MLP Scale}=4 </math>MLP Scale=4
此时模型参数量为 93.4M。
训练细节:
- 优化器 AdamW
- 使用 EMA(Exponential Moving Averagin),decay 设置为 0.999
- 学习率 <math xmlns="http://www.w3.org/1998/Math/MathML"> 5 × 1 0 − 4 5\times 10^{-4} </math>5×10−4,每 4 万步衰减到原来的 0.9 倍
- 混合精度训练
- STFT 操作保持 FP32
- 每张 GPU 的 batch size 为 128
- 16×A100 训练 4 周
推理时 8 秒分段,4 秒一间隔,相邻重合的 4 秒推理结果取平均。
3. Mel-Band RoFormer
从心理学上讲,人类对低频敏感对高频迟钝,换句话说就是低频部分需要更高的分辨率才更合适。BS-RoFormer 已经从此受益颇多,本文的 Mel-Band RoFormer 将在这方面做得更科学、更好。
论文使用 librosa 的 librosa.filters.mel
方法获得梅尔分频矩阵。
这个矩阵大概长这样:
text[0. 0.00098 0. 0. 0. 0. 0. 0. 0. ] [0. 0.00025 0.00041 0. 0. 0. 0. 0. 0. ] [0. 0. 0.00024 0.00036 0.00012 0. 0. 0. 0. ] [0. 0. 0. 0.00006 0.0002 0.00023 0.00015 0.00008 0. ]
每一行都是一个梅尔频谱滤波器。左低频右高频。
可见,滤波器都呈刺状,且低频的滤波器能量集中,高频的滤波器能量分散。
Mel-Band RoFormer 选择把这个矩阵二值化,转换为 mask。
4. 碎碎念
论文是两年前的了,模型放到现在依旧能打。很难不怀疑字节内部已经有更好的模型不想也没必要放出来,闷声发大财。恐怖如斯。
训练花了这么多资源和时间,看来我自己是没办法完全复现了。
这个方案简单直观,好评。但这么说要分 4 个轨道就需要训练 4 个不同的模型?那就不太优雅了。之后查阅代码实现时再看是怎么个情况。
看来纯频域的音频处理就能获得很好的效果,不需要像 HDemucs 那样时域频域混合,那问题简单了不少。需要注意 STFT 时保留虚部从而保留相位信息。
5. 参考来源
- 索罗格,"基于Mask的语音分离",zhuanlan.zhihu.com/p/139077771