文声图防御框架原理笔记:Interpret then Deactivate(ItD)

Sparse Autoencoder as a Zero-Shot Classifier for Concept Erasing in Text-to-Image Diffusion Models 这篇论文提出了一种名为Interpret then Deactivate (ItD) 的框架,旨在文本到图像(T2I)扩散模型中实现精准、可扩展的概念擦除(即移除不想要的概念,如有害内容、特定名人等),同时不影响正常概念的生成。以下从思路、方法原理、数学公式推导三方面详细总结:

一、核心思路

现有概念擦除方法存在两大局限:1)微调模型参数会导致正常概念生成质量下降;2)集成定制模块泛化能力弱且需额外训练。为此,ItD框架通过"解释-停用"两步解决问题:

  1. 解释(Interpret):用稀疏自编码器(SAE)将概念分解为稀疏特征的线性组合,明确概念的特征构成;
  2. 停用(Deactivate):仅停用目标概念特有的特征(排除与正常概念共享的特征),实现精准擦除,同时保留正常概念的生成能力。

此外,SAE被复用为零样本分类器,可判断输入是否包含目标概念,仅在必要时应用擦除,进一步减少对正常概念的影响。

二、方法原理

1. 稀疏自编码器(SAE)的训练

SAE的作用是将文本编码器的语义信息(残差流输出)分解为稀疏特征的组合,为概念"解释"提供基础。

  • 训练对象 :文本编码器中transformer块的残差流输出(即token嵌入 \(e_l^h\)),其中 \(l\)为层索引,\(h\)为token索引)。
  • 模型选择:采用K-稀疏自编码器(KSAE),强制每次重构仅保留K个最大激活特征,确保稀疏性。
  • 核心目标 :使SAE能将输入的token嵌入 \(e\)重构为稀疏特征的线性组合,即 \(e \approx \sum_{\rho=1}^{d_{hid}} z^\rho f_\rho\)(\(z^\rho\)为特征激活值,\(f_\rho\)为解码器矩阵的列向量,即特征向量)。

2. 特征选择:定位目标概念的"特有特征"

为避免擦除正常概念的特征,需筛选出目标概念特有的特征:

  • 步骤1:收集目标概念的相关特征
    对目标概念的每个token,通过SAE获取特征激活值,取每个特征在所有token中的最大激活值,筛选出前 \(K_{sel}\)个高激活特征,构成目标特征集 \(F_{tar}\)。
  • 步骤2:排除与正常概念共享的特征
    用正常概念集(retain set)的特征集 \(F_{retain}\)与 \(F_{tar}\)对比,移除两者共有的特征,得到目标概念特有特征集 \(\hat{F}{tar} = F{tar} \setminus \bigcup F_{retain}\)。
  • 多概念擦除 :对多个目标概念,取其特有特征集的并集 \(F_{erase} = \bigcup \hat{F}_{tar}\)。

3. 概念擦除机制

通过调整目标特征的激活值,移除文本嵌入中目标概念的语义信息:

  • 编码与调整 :文本嵌入经SAE编码为特征激活 \(s\)后,将 \(F_{erase}\)中的特征激活值缩放(如乘以小系数 \(\tau\)),削弱其影响;
  • 解码重构:调整后的激活经SAE解码器重构为新的文本嵌入,该嵌入不再包含目标概念的信息,从而阻止扩散模型生成相关图像。

4. 零样本分类器:选择性擦除

利用SAE的重构损失判断输入是否包含目标概念,仅在包含时应用擦除,减少对正常概念的干扰:

  • 若输入文本嵌入 \(e\)包含目标概念,其经SAE重构后的误差 \(\|e - \hat{e}\|\)较小(因目标特征被调整);
  • 若为正常概念,重构误差较大。通过阈值 \(\tau\)区分:\(G(e) = 1\)(含目标概念,应用擦除)若 \(\|e - \hat{e}\|^2 < \tau\),否则为0(不擦除)。

三、数学公式原理推导

1. SAE的编码器与解码器

  • 编码器 :将输入token嵌入 \(e\)转换为稀疏特征激活 \(z\)

    \[z = \text{TopK}(W_{enc}(e - b_{pre})) \]

    其中,\(W_{enc} \in \mathbb{R}^{d_{hid} \times d_{in}}\)为编码器权重,\(b_{pre}\)为偏置,\(\text{TopK}\)保留前K个最大激活值(其余置0),确保稀疏性。

  • 解码器 :将特征激活 \(z\)重构为嵌入 \(\hat{e}\)

    \[\hat{e} = W_{dec} z + b_{pre} \]

    其中,\(W_{dec} \in \mathbb{R}^{d_{in} \times d_{hid}}\)为解码器权重,每列 \(f_\rho\)为特征向量。

2. SAE的训练损失函数

目标是最小化重构误差并保证特征稀疏性,损失函数为:

\[\mathcal{L}(e) = \|e - \hat{e}\|2^2 + \alpha \mathcal{L}{aux} \]

  • 第一项 \(\|e - \hat{e}\|_2^2\):L2重构损失,确保输入与输出接近;
  • 第二项 \(\alpha \mathcal{L}{aux}\):辅助损失,防止"死特征"(即极少激活的特征)。\(\mathcal{L}{aux}\)定义为使用前 \(K_{aux}\)(\(K_{aux} > K\))个特征的重构误差,\(\alpha\)为权重系数。

3. 特征选择公式

  • 目标概念特征集 \(F_{tar}\):

    \[F_{tar} = \{\rho \mid s_C^\rho \in \text{TopK}(s_C^1, ..., s_C^{d_{hid}})\} \]

    其中 \(s_C^\rho = \max(s_1^\rho, ..., s_H^\rho)\),\(s_h^\rho\)为第h个token的第\(\rho\)个特征激活值,\(H\)为目标概念的token数。

  • 特有特征集 \(\hat{F}_{tar}\):

    \[\hat{F}{tar} = F{tar} \setminus \bigcup_{C_r \in C_{retain}} F_{C_r} \]

    其中 \(C_{retain}\)为正常概念集,\(F_{C_r}\)为正常概念的特征集。

    当需要擦除多个目标概念时,总擦除特征集为各目标特有特征集的并集:

\[F_{erase} = \bigcup_{C \in \mathcal{C}_{tar}} \hat{F}_C \]

其中\(\mathcal{C}_{tar}\)为所有目标概念的集合,该式确保一次擦除多个概念且无需额外训练。

4. 概念擦除的激活调整

对特征激活 \(s\)进行调整,削弱目标特征的影响:

\[\hat{s}^\rho = \begin{cases} s^\rho \cdot \tau & \text{if } \rho \in F_{erase} \\ s^\rho & \text{otherwise} \end{cases} \]

其中 \(\tau\)为缩放系数(如 \(\tau < 1\),削弱激活),调整后通过解码器重构为 \(\hat{e} = W_{dec} \hat{s} + b_{pre}\)。

5. 零样本分类器的判断公式

基于重构损失判断是否包含目标概念:

\[G(e) = \begin{cases} 1 & \text{if } \|e - \hat{e}\|^2 < \tau \\ 0 & \text{otherwise} \end{cases} \]

其中 \(\tau\)为阈值,\(G(e)=1\)时应用擦除,否则直接输出原嵌入。

四、总结

ItD框架通过SAE将概念分解为稀疏特征,结合对比特征选择和选择性擦除,实现了精准、可扩展的概念擦除。数学公式确保了SAE的稀疏性、特征的特异性及擦除的针对性,解决了现有方法对正常概念生成的干扰问题,同时通过零样本分类器进一步提升了鲁棒性。