CTC Prefix Score计算

一、基本符号定义

在 CTC(Connectionist Temporal Classification)中,每帧的输出可以是字符或 blank(空白)。定义以下概率变量:

1. 前缀相关概率

令 hhh 表示一个已折叠(collapsed)的前缀序列,即连续的相同字符已被合并,且 blank 已被删除。

  • rtn(h)r_t^n(h)rtn(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧不是 blank的概率。
    rtn(h)=P(π1:t⇒h,πt≠blank) r_t^n(h) = P(\pi_{1:t} \Rightarrow h, \pi_t \neq \text{blank}) rtn(h)=P(π1:t⇒h,πt=blank)

  • rtb(h)r_t^b(h)rtb(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧是 blank的概率。
    rtb(h)=P(π1:t⇒h,πt=blank) r_t^b(h) = P(\pi_{1:t} \Rightarrow h, \pi_t = \text{blank}) rtb(h)=P(π1:t⇒h,πt=blank)

2. 前缀扩展相关概率

在已有前缀 hhh 后面追加一个字符 ccc时,定义:

  • rtn(h,c)r_t^n(h,c)rtn(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c(字符串拼接),且第 ttt 帧是 ccc的概率。
    rtn(h,c)=P(π1:t⇒h+c,πt=c) r_t^n(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = c) rtn(h,c)=P(π1:t⇒h+c,πt=c)

  • rtb(h,c)r_t^b(h,c)rtb(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c,且第 ttt 帧是 blank的概率。
    rtb(h,c)=P(π1:t⇒h+c,πt=blank) r_t^b(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = \text{blank}) rtb(h,c)=P(π1:t⇒h+c,πt=blank)

二、状态转移推导

1. rtn(h,c)r_t^n(h,c)rtn(h,c) 的递推公式

第 ttt 帧预测为字符 ccc,且collapsed后结果为 h+ch+ch+c,可由两种状态转移而来:

(i) 前 t−1t-1t−1 帧已经得到了 h+ch+ch+c,且第 t−1t-1t−1 帧是 ccc

此时第 ttt 帧再预测 ccc,由于 CTC 的collapsed规则(连续相同字符合并),最终结果仍然是 h+ch+ch+c:
rt−1n(h,c)×P(c∣t) r_{t-1}^n(h,c) \times P(c|t) rt−1n(h,c)×P(c∣t)

(ii) 前 t−1t-1t−1 帧得到了前缀 hhh,且第 ttt 帧预测为 ccc

但需要注意:hhh 的最后一个字符不能是 ccc ,否则 h+ch + ch+c collapsed后会变成 hhh(相同字符合并)。

定义 ϕt−1(h,c)\phi_{t-1}(h, c)ϕt−1(h,c) 表示在时刻 t−1t-1t−1 已得到前缀 hhh,且允许在第 ttt 帧添加 ccc 的概率:
ϕt−1(h,c)={rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)rt−1n(h)+rt−1b(h),如果 c≠yL \phi_{t-1}(h, c) = \begin{cases} r^b_{t-1}(h), & \text{如果 } c = y_L \text{(c 等于 h 的最后一个字符)} \\ r^n_{t-1}(h) + r^b_{t-1}(h), & \text{如果 } c \neq y_L \end{cases} ϕt−1(h,c)={rt−1b(h),rt−1n(h)+rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)如果 c=yL

其中 yLy_LyL 是前缀 hhh 的最后一个字符。

综合两种情况,递推公式为:
rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t) r_t^n(h,c) = \big( r_{t-1}^n(h,c) + \phi_{t-1}(h, c) \big) \times P(c|t) rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t)

ϕ\phiϕ代码实现:

python 复制代码
r_sum = torch.logsumexp(r_prev, 1)  # 对 r^n 和 r^b 求和
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
# 特殊处理:若 c 等于 h 的最后一个字符,只能从 blank 状态转移
if scoring_ids is not None:
    for idx in range(n_bh):
        pos = scoring_idmap[idx, last_ids[idx]]
        if pos >= 0:
            log_phi[:, idx, pos] = r_prev[:, 1, idx]

2. rtb(h,c)r_t^b(h,c)rtb(h,c) 的递推公式

第 ttt 帧输出为 blank,且collapsed后结果为 h+ch+ch+c,则前 t−1t-1t−1 帧必须已经得到 h+ch+ch+c:
rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t) r_t^b(h,c) = \big( r_{t-1}^n(h,c) + r_{t-1}^b(h,c) \big) \times P(\text{blank}|t) rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t)

rtn(h,c)r_t^n(h,c)rtn(h,c)与rtb(h,c)r_t^b(h,c)rtb(h,c)代码实现:

python 复制代码
for t in range(start, end):
    rp = r[t - 1]
    #rr包含了r_t^n和r_t^b
    rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
    r[t] = torch.logsumexp(rr, 1) + x_[:, t]  # 加上当前帧的发射概率

3. 前缀概率更新

当计算完 rtn(h,c)r_t^n(h,c)rtn(h,c) 和 rtb(h,c)r_t^b(h,c)rtb(h,c) 后,它们实际上对应新前缀 h+ch+ch+c 的状态:
rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c) r_t^n(h+c) = r_t^n(h,c), \quad r_t^b(h+c) = r_t^b(h,c) rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c)

三、CTC 前缀分数计算

CTC Prefix Score是指以h+ch+ch+c作为前缀的概率,对于固定时刻 ttt,collapsed后结果为 h+ch+ch+c 的概率为:
rt(h+c)=rtn(h,c)+rtb(h,c) r_t(h+c) = r_t^n(h,c) + r_t^b(h,c) rt(h+c)=rtn(h,c)+rtb(h,c)

但 CTC 前缀得分不是简单对所有时刻的 rt(h+c)r_t(h+c)rt(h+c) 求和,原因在于同一条路径可能在不同时刻对应相同的前缀。例如,若在 t1t_1t1 时刻首次得到前缀 h+ch+ch+c,之后在 t2>t1t_2 > t_1t2>t1 时刻,所有帧输出均为 ccc 或 blank,则collapsed结果仍为 h+ch+ch+c。这样,同一条路径会被重复计入不同时刻的概率中。

应计算在时刻 ttt 首次生成前缀 h+ch+ch+c 的概率:
P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t) P(\pi_{1:t} \Rightarrow h + c) = \phi_{t-1}(h, c) \times P(c|t) P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t)

对时间ttt求和得到CTC 前缀得分ψ\psiψ:
ψ(h+c)=∑tP(π1:t⇒h+c)=∑tϕt−1(h,c)×P(c∣t) \psi(h + c) = \sum_{t} P(\pi_{1:t} \Rightarrow h + c) = \sum_{t} \phi_{t-1}(h, c) \times P(c|t) ψ(h+c)=t∑P(π1:t⇒h+c)=t∑ϕt−1(h,c)×P(c∣t)

ψ\psiψ代码实现:

python 复制代码
# 计算所有时刻的首次出现概率
log_phi_x = log_phi + x_[:, :, None]  # 加上字符c的概率
log_psi_ = torch.logsumexp(
    torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
    dim=0
)
# 将结果映射到对应的字符索引
for si in range(n_bh):
    log_psi[si, scoring_ids[si]] = log_psi_[si]
相关推荐
A尘埃7 小时前
超市购物篮关联分析与货架优化(Apriori算法)
算法
.小墨迹7 小时前
apollo学习之借道超车的速度规划
linux·c++·学习·算法·ubuntu
不穿格子的程序员7 小时前
从零开始刷算法——贪心篇1:跳跃游戏1 + 跳跃游戏2
算法·游戏·贪心
大江东去浪淘尽千古风流人物7 小时前
【SLAM新范式】几何主导=》几何+学习+语义+高效表示的融合
深度学习·算法·slam
重生之我是Java开发战士7 小时前
【优选算法】模拟算法:替换所有的问号,提莫攻击,N字形变换,外观数列,数青蛙
算法
仟濹7 小时前
算法打卡 day1 (2026-02-06 周四) | 算法: DFS | 1_卡码网98 可达路径 | 2_力扣797_所有可能的路径
算法·leetcode·深度优先
yang)7 小时前
欠采样时的相位倒置问题
算法
历程里程碑7 小时前
Linux20 : IO
linux·c语言·开发语言·数据结构·c++·算法
A尘埃7 小时前
物流公司配送路径动态优化(Q-Learning算法)
算法
天若有情6738 小时前
【自研实战】轻量级ASCII字符串加密算法:从设计到落地(防查岗神器版)
网络·c++·算法·安全·数据安全·加密