CTC Prefix Score计算

一、基本符号定义

在 CTC(Connectionist Temporal Classification)中,每帧的输出可以是字符或 blank(空白)。定义以下概率变量:

1. 前缀相关概率

令 hhh 表示一个已折叠(collapsed)的前缀序列,即连续的相同字符已被合并,且 blank 已被删除。

  • rtn(h)r_t^n(h)rtn(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧不是 blank的概率。
    rtn(h)=P(π1:t⇒h,πt≠blank) r_t^n(h) = P(\pi_{1:t} \Rightarrow h, \pi_t \neq \text{blank}) rtn(h)=P(π1:t⇒h,πt=blank)

  • rtb(h)r_t^b(h)rtb(h) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 hhh,且第 ttt 帧是 blank的概率。
    rtb(h)=P(π1:t⇒h,πt=blank) r_t^b(h) = P(\pi_{1:t} \Rightarrow h, \pi_t = \text{blank}) rtb(h)=P(π1:t⇒h,πt=blank)

2. 前缀扩展相关概率

在已有前缀 hhh 后面追加一个字符 ccc时,定义:

  • rtn(h,c)r_t^n(h,c)rtn(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c(字符串拼接),且第 ttt 帧是 ccc的概率。
    rtn(h,c)=P(π1:t⇒h+c,πt=c) r_t^n(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = c) rtn(h,c)=P(π1:t⇒h+c,πt=c)

  • rtb(h,c)r_t^b(h,c)rtb(h,c) :表示从第 1 帧到第 ttt 帧的输出,collapsed后结果为 h+ch+ch+c,且第 ttt 帧是 blank的概率。
    rtb(h,c)=P(π1:t⇒h+c,πt=blank) r_t^b(h,c) = P(\pi_{1:t} \Rightarrow h + c, \pi_t = \text{blank}) rtb(h,c)=P(π1:t⇒h+c,πt=blank)

二、状态转移推导

1. rtn(h,c)r_t^n(h,c)rtn(h,c) 的递推公式

第 ttt 帧预测为字符 ccc,且collapsed后结果为 h+ch+ch+c,可由两种状态转移而来:

(i) 前 t−1t-1t−1 帧已经得到了 h+ch+ch+c,且第 t−1t-1t−1 帧是 ccc

此时第 ttt 帧再预测 ccc,由于 CTC 的collapsed规则(连续相同字符合并),最终结果仍然是 h+ch+ch+c:
rt−1n(h,c)×P(c∣t) r_{t-1}^n(h,c) \times P(c|t) rt−1n(h,c)×P(c∣t)

(ii) 前 t−1t-1t−1 帧得到了前缀 hhh,且第 ttt 帧预测为 ccc

但需要注意:hhh 的最后一个字符不能是 ccc ,否则 h+ch + ch+c collapsed后会变成 hhh(相同字符合并)。

定义 ϕt−1(h,c)\phi_{t-1}(h, c)ϕt−1(h,c) 表示在时刻 t−1t-1t−1 已得到前缀 hhh,且允许在第 ttt 帧添加 ccc 的概率:
ϕt−1(h,c)={rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)rt−1n(h)+rt−1b(h),如果 c≠yL \phi_{t-1}(h, c) = \begin{cases} r^b_{t-1}(h), & \text{如果 } c = y_L \text{(c 等于 h 的最后一个字符)} \\ r^n_{t-1}(h) + r^b_{t-1}(h), & \text{如果 } c \neq y_L \end{cases} ϕt−1(h,c)={rt−1b(h),rt−1n(h)+rt−1b(h),如果 c=yL(c 等于 h 的最后一个字符)如果 c=yL

其中 yLy_LyL 是前缀 hhh 的最后一个字符。

综合两种情况,递推公式为:
rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t) r_t^n(h,c) = \big( r_{t-1}^n(h,c) + \phi_{t-1}(h, c) \big) \times P(c|t) rtn(h,c)=(rt−1n(h,c)+ϕt−1(h,c))×P(c∣t)

ϕ\phiϕ代码实现:

python 复制代码
r_sum = torch.logsumexp(r_prev, 1)  # 对 r^n 和 r^b 求和
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
# 特殊处理:若 c 等于 h 的最后一个字符,只能从 blank 状态转移
if scoring_ids is not None:
    for idx in range(n_bh):
        pos = scoring_idmap[idx, last_ids[idx]]
        if pos >= 0:
            log_phi[:, idx, pos] = r_prev[:, 1, idx]

2. rtb(h,c)r_t^b(h,c)rtb(h,c) 的递推公式

第 ttt 帧输出为 blank,且collapsed后结果为 h+ch+ch+c,则前 t−1t-1t−1 帧必须已经得到 h+ch+ch+c:
rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t) r_t^b(h,c) = \big( r_{t-1}^n(h,c) + r_{t-1}^b(h,c) \big) \times P(\text{blank}|t) rtb(h,c)=(rt−1n(h,c)+rt−1b(h,c))×P(blank∣t)

rtn(h,c)r_t^n(h,c)rtn(h,c)与rtb(h,c)r_t^b(h,c)rtb(h,c)代码实现:

python 复制代码
for t in range(start, end):
    rp = r[t - 1]
    #rr包含了r_t^n和r_t^b
    rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
    r[t] = torch.logsumexp(rr, 1) + x_[:, t]  # 加上当前帧的发射概率

3. 前缀概率更新

当计算完 rtn(h,c)r_t^n(h,c)rtn(h,c) 和 rtb(h,c)r_t^b(h,c)rtb(h,c) 后,它们实际上对应新前缀 h+ch+ch+c 的状态:
rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c) r_t^n(h+c) = r_t^n(h,c), \quad r_t^b(h+c) = r_t^b(h,c) rtn(h+c)=rtn(h,c),rtb(h+c)=rtb(h,c)

三、CTC 前缀分数计算

CTC Prefix Score是指以h+ch+ch+c作为前缀的概率,对于固定时刻 ttt,collapsed后结果为 h+ch+ch+c 的概率为:
rt(h+c)=rtn(h,c)+rtb(h,c) r_t(h+c) = r_t^n(h,c) + r_t^b(h,c) rt(h+c)=rtn(h,c)+rtb(h,c)

但 CTC 前缀得分不是简单对所有时刻的 rt(h+c)r_t(h+c)rt(h+c) 求和,原因在于同一条路径可能在不同时刻对应相同的前缀。例如,若在 t1t_1t1 时刻首次得到前缀 h+ch+ch+c,之后在 t2>t1t_2 > t_1t2>t1 时刻,所有帧输出均为 ccc 或 blank,则collapsed结果仍为 h+ch+ch+c。这样,同一条路径会被重复计入不同时刻的概率中。

应计算在时刻 ttt 首次生成前缀 h+ch+ch+c 的概率:
P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t) P(\pi_{1:t} \Rightarrow h + c) = \phi_{t-1}(h, c) \times P(c|t) P(π1:t⇒h+c)=ϕt−1(h,c)×P(c∣t)

对时间ttt求和得到CTC 前缀得分ψ\psiψ:
ψ(h+c)=∑tP(π1:t⇒h+c)=∑tϕt−1(h,c)×P(c∣t) \psi(h + c) = \sum_{t} P(\pi_{1:t} \Rightarrow h + c) = \sum_{t} \phi_{t-1}(h, c) \times P(c|t) ψ(h+c)=t∑P(π1:t⇒h+c)=t∑ϕt−1(h,c)×P(c∣t)

ψ\psiψ代码实现:

python 复制代码
# 计算所有时刻的首次出现概率
log_phi_x = log_phi + x_[:, :, None]  # 加上字符c的概率
log_psi_ = torch.logsumexp(
    torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
    dim=0
)
# 将结果映射到对应的字符索引
for si in range(n_bh):
    log_psi[si, scoring_ids[si]] = log_psi_[si]
相关推荐
Z1Jxxx3 分钟前
C++ P1150 Peter 的烟
数据结构·c++·算法
踮起脚看烟花10 分钟前
chapter10_泛型算法
c++·算法
笨笨饿10 分钟前
# 52_浅谈为什么工程基本进入复数域?
linux·服务器·c语言·数据结构·人工智能·算法·学习方法
Code-keys11 分钟前
ADSP/ARM 性能/稳定性排查专栏总述
arm开发·算法·边缘计算·dsp开发
山栀shanzhi14 分钟前
C++四大常见排序对比
c++·算法·排序算法
Allen_LVyingbo28 分钟前
量子测量三部曲:投影测量、POVM 与坍缩之谜—从形式主义到物理图像
算法·性能优化·健康医疗·量子计算·空间计算
qiqsevenqiqiqiqi33 分钟前
位运算 计算
算法
人工智能培训34 分钟前
多模态AI模型融合难?核心问题与解决思路
人工智能·机器学习·prompt·agent·智能体
甄心爱学习44 分钟前
【最优化】1-6章习题
人工智能·算法
PD我是你的真爱粉1 小时前
向量数据库原理与检索算法入门:ANN、HNSW、LSH、PQ 与相似度计算
数据库·人工智能·算法