代码层面上解读ACE-Step

总览

ACE-Step 是文生音频模型。比 LLM 方法更快,目标是成为音频生成领域的 Stable Diffusion。好大的口气。

论文还没有发布,就只能看代码了。

arxiv 上有论文预印本。后面有时间看一看。

模型构成

先从 ACEStepPipeline 的代码里看看模型大致构成。

从需要下载的权重来看,主要分为三部分:

  • self.ace_step_transformerACEStepTransformer2DModel,权重名称 ace_step_transformer,大小在 6.2GB。是整个模型的重点
  • self.music_dcaeMusicDCAE,权重名称 music_dcae_f8c8 和 music_vocoder,来源于 Deep Compression AutoEncoder for Efficient High-Resolution Diffusion Models(2024),大小在 300MB
  • self.text_encoder_modelUMT5EncoderModel,权重名称 umt5-base,大小在 1.1GB,是个文本编码器。用于将提示词转换为嵌入,维度 768

就像 Stable Diffusion 一样,模型使用到了许多其他项目的代码片段。以下是一些我发现的:

self.lyric_tokenizer,取自 coqui-ai 团队 TTS 项目代码的 VoiceBpeTokenizer,是一个支持多种语言的分词器。用于将歌词转换为 ids。

self.ace_step_transformer.lyric_encoder,取自 FunAudioLLM 团队 CosyVoice 项目的 ConformerEncoder。用于将歌词 ids 编码为嵌入,维度 1024。应该是用预训练模型集成进来后经过了微调。

self.ace_step_transformer.rotary_emb,取自 transformers 库中的 mixtral.modeling_mixtral.MixtralRotaryEmbedding。用于获得 RoPE 编码。比较奇怪的是项目把 Mixtral 换成了 Qwen2 这个字眼,原因未知。

self.scheduler 改变自 diffusors 库的 FlowMatchEulerDiscreteScheduler,主要改动是 step() 函数在叠加噪声时增加了 omega 和 均值控制(具体来说,假设 prev_sample = sample + (sigma_next - sigma) * model_output 是原来的实现,那么现在的实现就变为了 dx = (sigma_next - sigma) * model_outputprev_sample = sample + (dx - dx.mean()) * omega + dx.mean())。

附上论文里给出的模型结构示意图:

推理步骤

ACEStepPipline

查阅 ACEStepPipeline 代码,里面除了 get_text_embeddings() 函数用于获得文本嵌入外,还有个长得很像的 get_text_embeddings_null()。区别是后者对 attention 层中的线性层 q 绑定了一个 hook,让 query 总是被乘上 0.01。总共有 12 个 encoder 层,这个操作只影响其中第 9 第 10 层。要解释这个操作,应该是让注意力分数分散到各个 token,使得注意更加全局的信息。最终会获得 encoder_text_hidden_states_null

查阅 ACEStepPipeline 代码,里面除了 self.ace_step_transformer.encode() 用于获得整体潜在嵌入外,还有个功能很像的 forward_encoder_with_temperature()。区别是后者对 attention 层中的线性层 q 绑定了一个 hook,让 query 总是被乘上 0.01。总共有 6 个 encoder 层,这个操作只影响其中第 5 第 6 层。要解释这个操作,应该是让注意力分数分散到各个歌词 token,使得注意更加全局的歌词信息。最终会获得 encoder_hidden_states_null

self.text_tokenizer 来自 t5,self.lyric_tokenizer 是专门用来转换歌词的 tokenizer。

self.text2music_diffusion_process() 降噪完毕后,使用 self.latents2audio 将潜在表示解码为音频。会用到 MusicDCAE 模型。

ACEStepPipline.text2music_diffusion_process()

文生音频是要进入到 self.text2music_diffusion_process() 进行的。函数内涉及了采样方法定义、target_latents 的生成,以及调用 self.ace_step_transformer.encode() 获得控制信息嵌入、反复调用 self.ace_step_transformer.decode() 进行去噪的过程。

去噪前值得注意的地方:

  • 和 Stable Diffusion 3 一样使用 FlowMatchEulerDiscreteScheduler 进行采样,shift 为 3,获得采样的步骤调度。该类改写于 diffusors 库同名类,主要改动是 step() 函数在叠加噪声时增加了 omega 和 均值控制
  • 时间与 latent 序列长度的转换规则:x * 44100 / 512 / 8。其中 x 单位为秒。
  • 使用 diffusers 库的 randn_tensor() 方法,结合预先准备的 torch.Generator 对象 random_generators,为各个 batch 生成不同的随机 target_latents。维度 [batch, 8, 16, frame_length]

去噪时值得注意的地方:

  • guidance_interval 作为输入参数,会控制 Classifier-Free Guidance 的介入步骤时机。基本思路是,在初始阶段自由探索潜在空间,在中间阶段受到引导引导,在最终阶段再次自由发展(TODO 理论该是这样。但代码在介入时机外反而使用了条件引导。也许需要看论文才能知道为什么)
  • Classifier-Free Guidance 介入时,无条件的推理类似于 encoder_text_hidden_states_nullencoder_hidden_states_null,会绑定 hook 到自注意力和交叉注意力的 to_q,使得 15 到 20 层的 query 总被乘上 0.01
  • Classifier-Free Guidance 介入时,条件输出和无条件输出的合并不是直接相加那样简单(被称为 CFG 方法),而是使用名为 APG 的方案进行合并:
    • 两者相减获得 diff,在 MomentumBuffer 的辅佐下平滑 diff(running average 被更新为平滑后的 diff)
    • 自适应乘上一个 scale 以避免 diff 二范数过大
    • diff 分解为两个向量 diff_orthogonal diff_parallel,分别平行和正交于条件输出
    • normalized_update = diff_orthogonal + eta * diff_parallel
    • pred_guided = pred_cond + (guidance_scale - 1) * normalized_update

ACEStepTransformer2DModel.encode()

self.ace_step_transformer.encode() 负责将提示词嵌入、歌词 ids 和不知道干什么的 speaker_embds 融合成一个统一的 hidden states。

输入参数,

  • encoder_text_hidden_states,提示词文本
  • text_attention_mask,在 token 层面的提示词文本 mask
  • speaker_embeds,维度在 [batch, 512] 的、目前没有输入途径的变量。现在是全 0 代替
  • lyric_token_idx,歌词,还在 ids 形态
  • lyric_mask,在 token 层面的歌词 mask

其中 peaker_embeds encoder_text_hidden_states lyric_token_idx,分别会被:

  • self.genre_embedder,Linear(in_features=768, out_features=2560, bias=True)
  • self.speaker_embedder,Linear(in_features=512, out_features=2560, bias=True)
  • self.forward_lyric_encoder,包含
    • self.lyric_embs,Embedding(6693, 1024)
    • self.lyric_encoder,借用了 CosyVoice 项目的 ConformerEncoder
    • self.lyric_proj,Linear(in_features=1024, out_features=2560, bias=True)

获得:

  • encoder_spk_hidden_states
  • encoder_text_hidden_states
  • encoder_lyric_hidden_states

三个 tensor 的通道数量都为 2560。在 seq 维度上拼接为一个大一统的 encoder_hidden_states。顺便会拼接出对应的 mask encoder_hidden_mask

ACEStepTransformer2DModel.decode()

self.ace_step_transformer.decode() 用于计算噪声,输入有 代表音频的 hidden_states、控制信息 encoder_hidden_states 和时间步 timestep

时间步 timestep 通过 diffusers.models.embeddings.Timesteps 转换为正余弦编码,然后经过 Linear(256, 2560) > SiLU > Linear(2560, 2560) 获得 embedded_timestep,然后进一步 SiLU > Linear(2560, 15360) 获得 temb

hidden_states 维度是 [batch, 8, 16, frame_length],此时视为 BCHW 进行二维卷积。通过

  • Conv2d(8, 2048, kernel_size=(16, 1), stride=(16, 1)),抹掉 H 维
  • GroupNorm(32, 2048, eps=1e-06, affine=True)
  • Conv2d(2048, 2560, kernel_size=(1, 1), stride=(1, 1)),只升维

与此同时加上维度变换操作,维度变化:

  • batch, 8, 16, frame_length

  • 经过 Conv2d,[batch, 2048, 1, frame_length]
  • 经过 Conv2d,[batch, 2560, 1, frame_length]
  • batch, frame_length, 2560

现在使用 self.rotary_emb 获得 RoPE 位置编码。获得了 hidden_states 的位置编码 rotary_freqs_cisencoder_hidden_states 的位置编码 encoder_rotary_freqs_cis

项目用的位置编码不太讲究。

控制信息 encoder_hidden_states 被视为连续序列获得了位置编码,有点不符合常理。毕竟控制信息是三个不连续不同来源的张量拼接而来的。

另外模型涉及到音频特征和控制信息的交叉注意力,但这两个张量是用同一个方法获取的 RoPE 编码。全然不顾 RoPE 编码在交叉注意力时内积的意义------RoPE 拥有一定相对位置编码的特性,序列的内积结果会根据相对位置距离的变化而变化。

这样都能训练收敛得到还行的效果,只能说明模型自己足够鲁棒。

万事俱备。总结一下现在手上有哪些张量:

  1. hidden_states,代表音频本身
  2. attention_mask,应该是音频的 mask。被设置为了全 1
  3. encoder_hidden_states,控制信息
  4. encoder_hidden_mask,控制信息 mask
  5. rotary_freqs_cis,音频的位置编码
  6. encoder_rotary_freqs_cis,控制信息的位置编码
  7. temb,时间步编码

接下来将这些张量输入到 self.transformer_blocks 不断对 hidden_states 进行降噪处理。总共有 24 个 block,每个 block 是 LinearTransformerBlock 对象。

在完成特定 block 的处理后,会存储当前的 hidden_states,用在之计算出额外的 loss,对齐音频特征 hidden_states 与 mert 和 hubert 模型输出的特征 ssl_hidden_states。这个过程具体来说,先用包含六个线形层的 MLP 映射一下 inner_hidden_states,然后用 F.interpolate 和 F.normalize 将 inner_hidden_statesssl_hidden_states 转换到同一尺度,最后用 nn.CosineEmbeddingLoss 获得损失。若是在推理阶段则不会传入 ssl_hidden_states,也就不会计算这个 loss 了。

最后经过 self.final_layer,一个 RMSNorm 和一个 Linear(2560 -> 128)。完成一个时间步的噪声预测。

LinearTransformerBlock

有一个 AdaLN 层。将已经从正余弦编码映射到 15360 通道的 temb reshape 到维度 [6, 2560],加上一个大小同样是 [6, 2560] 的可学习权重后,分出六个张量,分别作为 shift_msa scale_msa gate_msa shift_mlp scale_mlp gate_mlp

AdaLN 是把控制信息融入到 Transformer 的一种方法,在 DiT 中非常常见。在此不展开说明。

接下来经过以下步骤。使用伪代码展示。

python 复制代码
# norm
norm_hidden_states = RMSNorm(hidden_states)
# AdaLN
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

# self attention
attn_output = SelfAttention(norm_hidden_states)  # 线性注意力
# AdaLN
attn_output = gate_msa * attn_output
# residual
hidden_states = attn_output + hidden_states # 残差

# cross attention
attn_output = CrossAttention(hidden_states, encoder_hidden_states)  # 点积注意力
# residual
hidden_states = attn_output + hidden_states  # 残差

# norm
norm_hidden_states = RMSNorm(hidden_states)
# AdaLN
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

# mlp
ff_output = GLUMBConv(norm_hidden_states)
# AdaLN
ff_output = gate_mlp * ff_output
# residual
hidden_states = hidden_states + ff_output  # 残差

return hidden_states

SelfAttention 和 CrossAttention 部分借助了 diffusers 库,并自定义其中的 processor。自注意力和交叉注意力分别自定义为了 CustomLiteLAProcessor2_0CustomerAttnProcessor2_0。前者实现了 ReLU 激活的线性注意力,后者实现了标准的点积注意力。

GLUMBConv,

python 复制代码
x = x.transpose(1, 2)
x = ConvLayer(x)  # 2560 -> 12800, kernel=1, act=SiLU
x = ConvLayer(x)  # 12800 -> 12800, kernel=3, padding=1, groups=12800

x, gate = torch.chunk(x, 2, dim=1)
gate = SiLU(gate)
x = x * gate

x = ConvLayerv(x)  # 6400 -> 2560, kernel=1, padding=1, bias=False
x = x.transpose(1, 2)

ConvLayer,

python 复制代码
x = Conv1d(x)
if self.act:
    x = self.act(x)
return x

额外说明

线性 attention

项目使用了最经典的 linear attention 解决方案,就是去掉 softmax 并使用 ReLU 激活保证 qeury 和 key 的内积非负,将 (query × key) × value 的运算转变为 query × (key × value)

Adaptive Projected Guidance(APG)

Classifier-Free Guidance(CFG)方法允许控制单步预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ \hat{x} </math>x^ 遵循控制信息 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 的力度。以下是获得一步噪声的方法。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ = ( 1 − γ ) m o d e l ( ∅ ) + γ m o d e l ( y ) \hat{x}=(1-\gamma)\mathrm{model}(\varnothing) + \gamma\mathrm{model}(y) </math>x^=(1−γ)model(∅)+γmodel(y)

可见,是通过条件生成结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> m o d e l ( y ) \mathrm{model}(y) </math>model(y) 和无条件生成结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> m o d e l ( ∅ ) \mathrm{model}(\varnothing) </math>model(∅) 的线性组合获得的。其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 代表指导力度。

Adaptive Projected Guidance 是论文 (2024) Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models 提出的另一种组合方案。可用以下式子表示(写得有点乱。可以直接看后面的 Python 代码示例):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Δ D t = m o d e l ( y ) − m o d e l ( ∅ ) Δ D t = ∫ 0 t Δ D t β ( t − i ) d i Δ D t = Δ D t ⋅ min ⁡ ( 1 , r ∥ Δ D t ∥ ) Δ D t parallel , Δ D t orthogonal = proj ( Δ D t , m o d e l ( y ) ) Δ D t = Δ D t orthogonal + η ⋅ Δ D t parallel x ^ = m o d e l ( y ) + ( γ − 1 ) ⋅ Δ D t \begin{aligned} &\Delta D_t=\mathrm{model}(y)-\mathrm{model}(\varnothing)\\ &\Delta D_t=\int^t_0 \Delta D_t\beta^{(t-i)}\mathrm{d}i\\ &\Delta D_t=\Delta D_t·\min(1, \frac{r}{\Vert \Delta D_t\Vert})\\ &\Delta D_t^\text{parallel}, \Delta D_t^\text{orthogonal}=\text{proj}(\Delta D_t, \mathrm{model}(y))\\ &\Delta D_t=\Delta D_t^\text{orthogonal} + \eta ·\Delta D_t^\text{parallel}\\ &\hat{x}=\mathrm{model}(y) + (\gamma-1)·\Delta D_t \end{aligned} </math>ΔDt=model(y)−model(∅)ΔDt=∫0tΔDtβ(t−i)diΔDt=ΔDt⋅min(1,∥ΔDt∥r)ΔDtparallel,ΔDtorthogonal=proj(ΔDt,model(y))ΔDt=ΔDtorthogonal+η⋅ΔDtparallelx^=model(y)+(γ−1)⋅ΔDt

  • 第一个式子获取条件输出于非条件输出的差距 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ D t \Delta D_t </math>ΔDt
  • 第二个式子是想表达 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ D t ← Δ D t + β ⋅ Δ D t − 1 \Delta D_t \leftarrow \Delta D_t + \beta ·\Delta D_{t-1} </math>ΔDt←ΔDt+β⋅ΔDt−1 的意思。其中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β β </math>β 要求为负数
  • 第三个式子避免 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ D t \Delta D_t </math>ΔDt 一步更新过多,缩放到 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r 范围内
  • 第四个式子获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ D t \Delta D_t </math>ΔDt 关于 <math xmlns="http://www.w3.org/1998/Math/MathML"> m o d e l ( y ) \mathrm{model}(y) </math>model(y) 平行和正交的两个分量
  • 第五个式子使用超参 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η 重组 <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ D t \Delta D_t </math>ΔDt
  • 第六个式子正式获得 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ \hat{x} </math>x^

可见,当 <math xmlns="http://www.w3.org/1998/Math/MathML"> β = 0 , r = ∞ , η = 1 \beta=0,r=\infin,\eta=1 </math>β=0,r=∞,η=1 时,APG 等价于原始的 CFG 方法。

论文给出的实践结果可以看出,取值 <math xmlns="http://www.w3.org/1998/Math/MathML"> β = − 0.75 , r = 2.5 , η = 0 \beta=-0.75,r=2.5,\eta=0 </math>β=−0.75,r=2.5,η=0 效果通常最好。ACE-Step 就是用的这组超参。

APG 即插即用,能直接用于现有的扩散模型。以下是论文给出的 Python 代码(摘自 github.com/MythicalChu... )。

python 复制代码
import torch

class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0
        
    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
        
def project( v0: torch.Tensor, v1: torch.Tensor,):
    dtype = v0.dtype
    #v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
    
def normalized_guidance( pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, momentum_buffer: MomentumBuffer = None, eta: float = 1.0, norm_threshold: float = 0.0,):
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    
    return pred_guided

目前(20250609)diffusers 库上已有 APG 相关的 Pull Requests,暂时没并入仓库。

临时

TODO

  • random_generators 是 torch.Generator 对象的 list。这是怎样被用上的?
    • 答:使用 diffusers 库的 randn_tensor 方法生成 target_latents
  • oss_steps(默认为空)是什么意思?
    • 答:是 diffusion 的 Euler 采样器的备选方案,手动指定采样步骤。
  • use_erg_tag(默认为 True)控制了什么行为
    • 会获取一个 encoder_text_hidden_states_null 作为 encoder_text_hidden_states 的补充?
  • use_erg_lyric(默认为 True)控制了什么行为
  • use_erg_diffusion(默认为 True)控制了什么行为
  • speaker_embeds 似乎是可选的额外参数,但没有公开输出形式。目前是全 zero,维度 [batch, 512]
  • retake 意味着什么?似乎是一个加噪操作,需要 task in ("retake", "repaint", "extend") 为真才生效
  • momentum_buffer 是干嘛的?
    • 用于 APG 方法的 Classifier-Free Guidance,平滑降噪过程

其他参考来源

相关推荐
小天才才20 分钟前
【大模型】解耦大语言模型中的记忆与推理能力
人工智能·深度学习·语言模型·自然语言处理
AI大模型学习教程33 分钟前
前端学AI之LangChain.js入门教程:实现智能对话机器人
人工智能·langchain
Java中文社群40 分钟前
超实用!手把手教你Dify版本升级
人工智能·后端
nbbsn1 小时前
第四十天打卡
python·深度学习·机器学习
奔跑吧邓邓子1 小时前
DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”
人工智能·deepseek·无人农场·协同作业·农田管理·神经网
面朝大海,春不暖,花不开1 小时前
Spring AI与Spring Modulith核心技术解析
人工智能·spring·flask
jieshenai1 小时前
Mac M4 芯片运行大模型指南,包括模型微调与推理
人工智能·自然语言处理
爱写代码的小朋友1 小时前
破局与重构:人工智能深度赋能基础教育变革研究
人工智能·重构
聚客AI2 小时前
大厂特邀大咖万字深度穿透:Transformer核心模块实现细节大揭秘
人工智能·神经网络·掘金·日新计划
Blossom.1182 小时前
基于区块链的供应链溯源系统:构建与实践
人工智能·python·深度学习·机器学习·计算机视觉·flask·区块链