CTC Prefix Score计算

一、基本符号定义

在 CTC(Connectionist Temporal Classification)中,每帧的输出可以是字符或 blank(空白)。定义以下概率变量:

1. 前缀相关概率

令 hhh 表示一个已折叠(collapsed)的前缀序列,即连续的相同字符已被合并,且 blank 已被删除。

  • rtn(h)r_t^n(h)rtn(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧不是 blank的概率。
    rtn(h)=P(π1:t⇒h,πt≠blank) r_t^n(h) = P(\pi_{1:t} \Rightarrow h, \pi_t \neq \text{blank}) rtn(h)=P(π1:t⇒h,πt=blank)

  • rtb(h)r_t^b(h)rtb(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧是 blank的概率。
    rtb(h)=P(π1:t⇒h,πt=blank) r_t^b(h) = P(\pi_{1:t} \Rightarrow h, \pi_t = \text{blank}) rtb(h)=P(π1:t⇒h,πt=blank)

2. 前缀扩展相关概率

在已有前缀 hhh 后面追加一个字符 ccc时,定义:

  • rtn(h,c)r_t^n(h,c)rtn(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c(字符串拼接),且第 ttt 帧是 ccc的概率。
    rtn(h,c)=P(π1:t⇒h+c,πt=c) r_t^n(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = c) rtn(h,c)=P(π1:t⇒h+c,πt=c)

  • rtb(h,c)r_t^b(h,c)rtb(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c,且第 ttt 帧是 blank的概率。
    rtb(h,c)=P(π1:t⇒h+c,πt=blank) r_t^b(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = \text{blank}) rtb(h,c)=P(π1:t⇒h+c,πt=blank)

二、状态转移推导

1. rtn(h,c)r_t^n(h,c)rtn(h,c) 的递推公式

第 ttt 帧预测为字符 ccc,且collapsed后结果为 h+ch+ch+c,可由两种状态转移而来:

(i) 前 t−1t-1t−1 帧已经得到了 h+ch+ch+c,且第 t−1t-1t−1 帧是 ccc

此时第 ttt 帧再预测 ccc,由于 CTC 的collapsed规则(连续相同字符合并),最终结果仍然是 h+ch+ch+c:
rt−1n(h,c)×P(c∣t) r_{t-1}^n(h,c) \times P(c|t) rt−1n(h,c)×P(c∣t)

(ii) 前 t−1t-1t−1 帧得到了前缀 hhh,且第 ttt 帧预测为 ccc

但需要注意:hhh 的最后一个字符不能是 ccc ,否则 h+ch + ch+c collapsed后会变成 hhh(相同字符合并)。

定义 ϕt−1(h,c)\phi_{t-1}(h, c)ϕt−1(h,c) 表示在时刻 t−1t-1t−1 已得到前缀 hhh,且允许在第 ttt 帧添加 ccc 的概率:
ϕt−1(h,c)={rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)rt−1n(h)+rt−1b(h),如果 c≠yL \phi_{t-1}(h, c) = \begin{cases} r^b_{t-1}(h), & \text{如果 } c = y_L \text{(c 等于 h 的最后一个字符)} \\ r^n_{t-1}(h) + r^b_{t-1}(h), & \text{如果 } c \neq y_L \end{cases} ϕt−1(h,c)={rt−1b(h),rt−1n(h)+rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)如果 c=yL

其中 yLy_LyL 是前缀 hhh 的最后一个字符。

综合两种情况,递推公式为:
rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t) r_t^n(h,c) = \big( r_{t-1}^n(h,c) + \phi_{t-1}(h, c) \big) \times P(c|t) rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t)

ϕ\phiϕ代码实现:

python 复制代码
r_sum = torch.logsumexp(r_prev, 1)  # 对 r^n 和 r^b 求和
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
# 特殊处理:若 c 等于 h 的最后一个字符,只能从 blank 状态转移
if scoring_ids is not None:
    for idx in range(n_bh):
        pos = scoring_idmap[idx, last_ids[idx]]
        if pos >= 0:
            log_phi[:, idx, pos] = r_prev[:, 1, idx]

2. rtb(h,c)r_t^b(h,c)rtb(h,c) 的递推公式

第 ttt 帧输出为 blank,且collapsed后结果为 h+ch+ch+c,则前 t−1t-1t−1 帧必须已经得到 h+ch+ch+c:
rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t) r_t^b(h,c) = \big( r_{t-1}^n(h,c) + r_{t-1}^b(h,c) \big) \times P(\text{blank}|t) rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t)

rtn(h,c)r_t^n(h,c)rtn(h,c)与rtb(h,c)r_t^b(h,c)rtb(h,c)代码实现:

python 复制代码
for t in range(start, end):
    rp = r[t - 1]
    #rr包含了r_t^n和r_t^b
    rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
    r[t] = torch.logsumexp(rr, 1) + x_[:, t]  # 加上当前帧的发射概率

3. 前缀概率更新

当计算完 rtn(h,c)r_t^n(h,c)rtn(h,c) 和 rtb(h,c)r_t^b(h,c)rtb(h,c) 后,它们实际上对应新前缀 h+ch+ch+c 的状态:
rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c) r_t^n(h+c) = r_t^n(h,c), \quad r_t^b(h+c) = r_t^b(h,c) rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c)

三、CTC 前缀分数计算

CTC Prefix Score是指以h+ch+ch+c作为前缀的概率,对于固定时刻 ttt,collapsed后结果为 h+ch+ch+c 的概率为:
rt(h+c)=rtn(h,c)+rtb(h,c) r_t(h+c) = r_t^n(h,c) + r_t^b(h,c) rt(h+c)=rtn(h,c)+rtb(h,c)

但 CTC 前缀得分不是简单对所有时刻的 rt(h+c)r_t(h+c)rt(h+c) 求和,原因在于同一条路径可能在不同时刻对应相同的前缀。例如,若在 t1t_1t1 时刻首次得到前缀 h+ch+ch+c,之后在 t2>t1t_2 > t_1t2>t1 时刻,所有帧输出均为 ccc 或 blank,则collapsed结果仍为 h+ch+ch+c。这样,同一条路径会被重复计入不同时刻的概率中。

应计算在时刻 ttt 首次生成前缀 h+ch+ch+c 的概率:
P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t) P(\pi_{1:t} \Rightarrow h + c) = \phi_{t-1}(h, c) \times P(c|t) P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t)

对时间ttt求和得到CTC 前缀得分ψ\psiψ:
ψ(h+c)=∑tP(π1:t⇒h+c)=∑tϕt−1(h,c)×P(c∣t) \psi(h + c) = \sum_{t} P(\pi_{1:t} \Rightarrow h + c) = \sum_{t} \phi_{t-1}(h, c) \times P(c|t) ψ(h+c)=t∑P(π1:t⇒h+c)=t∑ϕt−1(h,c)×P(c∣t)

ψ\psiψ代码实现:

python 复制代码
# 计算所有时刻的首次出现概率
log_phi_x = log_phi + x_[:, :, None]  # 加上字符c的概率
log_psi_ = torch.logsumexp(
    torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
    dim=0
)
# 将结果映射到对应的字符索引
for si in range(n_bh):
    log_psi[si, scoring_ids[si]] = log_psi_[si]
相关推荐
Stardep2 小时前
算法入门21——二分查找算法——山脉数组的峰顶索引
数据结构·算法·leetcode
mjhcsp2 小时前
P3145 [USACO16OPEN] Splitting the Field G(题解)
开发语言·c++·算法
空空潍2 小时前
hot100-合并区间(day14)
c++·算法·leetcode
橘颂TA2 小时前
【剑斩OFFER】算法的暴力美学——力扣 675 题:为高尔夫比赛砍树
数据结构·算法·c·结构与算法
rit84324992 小时前
UVE算法提取光谱特征波长的MATLAB实现与应用
开发语言·算法·matlab
是娇娇公主~2 小时前
算法——【最大子数组和】
数据结构·c++·算法
tkevinjd2 小时前
力扣hot100-283移动零(盲人拉屎)
算法·leetcode
POLITE32 小时前
Leetcode 94. 二叉树的中序遍历 104. 二叉树的最大深度 226. 翻转二叉树 101. 对称二叉树 (Day 13)
算法·leetcode·职场和发展
老鼠只爱大米2 小时前
LeetCode经典算法面试题 #2:两数相加(迭代法、字符串修改法等多种实现方案详解)
算法·leetcode·链表·两数相加·字符串修改法·两数相减·大数运算