一、基本符号定义
在 CTC(Connectionist Temporal Classification)中,每帧的输出可以是字符或 blank(空白)。定义以下概率变量:
1. 前缀相关概率
令 hhh 表示一个已折叠(collapsed)的前缀序列,即连续的相同字符已被合并,且 blank 已被删除。
-
rtn(h)r_t^n(h)rtn(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧不是 blank的概率。
rtn(h)=P(π1:t⇒h,πt≠blank) r_t^n(h) = P(\pi_{1:t} \Rightarrow h, \pi_t \neq \text{blank}) rtn(h)=P(π1:t⇒h,πt=blank) -
rtb(h)r_t^b(h)rtb(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧是 blank的概率。
rtb(h)=P(π1:t⇒h,πt=blank) r_t^b(h) = P(\pi_{1:t} \Rightarrow h, \pi_t = \text{blank}) rtb(h)=P(π1:t⇒h,πt=blank)
2. 前缀扩展相关概率
在已有前缀 hhh 后面追加一个字符 ccc时,定义:
-
rtn(h,c)r_t^n(h,c)rtn(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c(字符串拼接),且第 ttt 帧是 ccc的概率。
rtn(h,c)=P(π1:t⇒h+c,πt=c) r_t^n(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = c) rtn(h,c)=P(π1:t⇒h+c,πt=c) -
rtb(h,c)r_t^b(h,c)rtb(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c,且第 ttt 帧是 blank的概率。
rtb(h,c)=P(π1:t⇒h+c,πt=blank) r_t^b(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = \text{blank}) rtb(h,c)=P(π1:t⇒h+c,πt=blank)
二、状态转移推导
1. rtn(h,c)r_t^n(h,c)rtn(h,c) 的递推公式
第 ttt 帧预测为字符 ccc,且collapsed后结果为 h+ch+ch+c,可由两种状态转移而来:
(i) 前 t−1t-1t−1 帧已经得到了 h+ch+ch+c,且第 t−1t-1t−1 帧是 ccc
此时第 ttt 帧再预测 ccc,由于 CTC 的collapsed规则(连续相同字符合并),最终结果仍然是 h+ch+ch+c:
rt−1n(h,c)×P(c∣t) r_{t-1}^n(h,c) \times P(c|t) rt−1n(h,c)×P(c∣t)
(ii) 前 t−1t-1t−1 帧得到了前缀 hhh,且第 ttt 帧预测为 ccc
但需要注意:hhh 的最后一个字符不能是 ccc ,否则 h+ch + ch+c collapsed后会变成 hhh(相同字符合并)。
定义 ϕt−1(h,c)\phi_{t-1}(h, c)ϕt−1(h,c) 表示在时刻 t−1t-1t−1 已得到前缀 hhh,且允许在第 ttt 帧添加 ccc 的概率:
ϕt−1(h,c)={rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)rt−1n(h)+rt−1b(h),如果 c≠yL \phi_{t-1}(h, c) = \begin{cases} r^b_{t-1}(h), & \text{如果 } c = y_L \text{(c 等于 h 的最后一个字符)} \\ r^n_{t-1}(h) + r^b_{t-1}(h), & \text{如果 } c \neq y_L \end{cases} ϕt−1(h,c)={rt−1b(h),rt−1n(h)+rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)如果 c=yL
其中 yLy_LyL 是前缀 hhh 的最后一个字符。
综合两种情况,递推公式为:
rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t) r_t^n(h,c) = \big( r_{t-1}^n(h,c) + \phi_{t-1}(h, c) \big) \times P(c|t) rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t)
ϕ\phiϕ代码实现:
python
r_sum = torch.logsumexp(r_prev, 1) # 对 r^n 和 r^b 求和
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
# 特殊处理:若 c 等于 h 的最后一个字符,只能从 blank 状态转移
if scoring_ids is not None:
for idx in range(n_bh):
pos = scoring_idmap[idx, last_ids[idx]]
if pos >= 0:
log_phi[:, idx, pos] = r_prev[:, 1, idx]
2. rtb(h,c)r_t^b(h,c)rtb(h,c) 的递推公式
第 ttt 帧输出为 blank,且collapsed后结果为 h+ch+ch+c,则前 t−1t-1t−1 帧必须已经得到 h+ch+ch+c:
rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t) r_t^b(h,c) = \big( r_{t-1}^n(h,c) + r_{t-1}^b(h,c) \big) \times P(\text{blank}|t) rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t)
rtn(h,c)r_t^n(h,c)rtn(h,c)与rtb(h,c)r_t^b(h,c)rtb(h,c)代码实现:
python
for t in range(start, end):
rp = r[t - 1]
#rr包含了r_t^n和r_t^b
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
r[t] = torch.logsumexp(rr, 1) + x_[:, t] # 加上当前帧的发射概率
3. 前缀概率更新
当计算完 rtn(h,c)r_t^n(h,c)rtn(h,c) 和 rtb(h,c)r_t^b(h,c)rtb(h,c) 后,它们实际上对应新前缀 h+ch+ch+c 的状态:
rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c) r_t^n(h+c) = r_t^n(h,c), \quad r_t^b(h+c) = r_t^b(h,c) rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c)
三、CTC 前缀分数计算
CTC Prefix Score是指以h+ch+ch+c作为前缀的概率,对于固定时刻 ttt,collapsed后结果为 h+ch+ch+c 的概率为:
rt(h+c)=rtn(h,c)+rtb(h,c) r_t(h+c) = r_t^n(h,c) + r_t^b(h,c) rt(h+c)=rtn(h,c)+rtb(h,c)
但 CTC 前缀得分不是简单对所有时刻的 rt(h+c)r_t(h+c)rt(h+c) 求和,原因在于同一条路径可能在不同时刻对应相同的前缀。例如,若在 t1t_1t1 时刻首次得到前缀 h+ch+ch+c,之后在 t2>t1t_2 > t_1t2>t1 时刻,所有帧输出均为 ccc 或 blank,则collapsed结果仍为 h+ch+ch+c。这样,同一条路径会被重复计入不同时刻的概率中。
应计算在时刻 ttt 首次生成前缀 h+ch+ch+c 的概率:
P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t) P(\pi_{1:t} \Rightarrow h + c) = \phi_{t-1}(h, c) \times P(c|t) P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t)
对时间ttt求和得到CTC 前缀得分ψ\psiψ:
ψ(h+c)=∑tP(π1:t⇒h+c)=∑tϕt−1(h,c)×P(c∣t) \psi(h + c) = \sum_{t} P(\pi_{1:t} \Rightarrow h + c) = \sum_{t} \phi_{t-1}(h, c) \times P(c|t) ψ(h+c)=t∑P(π1:t⇒h+c)=t∑ϕt−1(h,c)×P(c∣t)
ψ\psiψ代码实现:
python
# 计算所有时刻的首次出现概率
log_phi_x = log_phi + x_[:, :, None] # 加上字符c的概率
log_psi_ = torch.logsumexp(
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
dim=0
)
# 将结果映射到对应的字符索引
for si in range(n_bh):
log_psi[si, scoring_ids[si]] = log_psi_[si]