Transformer 代码剖析16 - BLEU分数(pytorch实现)

一、BLEU算法全景图

是 否 开始 初始化统计信息列表 添加假设序列长度 添加参考序列长度 n=1 统计假设序列1-gram 统计参考序列1-gram n=2 统计假设序列2-gram 统计参考序列2-gram n=3 统计假设序列3-gram 统计参考序列3-gram n=4 统计假设序列4-gram 统计参考序列4-gram 遍历n=1到4 计算共有n-gram总数 计算假设序列中可能的n-gram总数 结束统计信息收集 根据统计信息计算BLEU分数 检查统计信息中是否有0 BLEU分数为0 计算修正后的n-gram精度几何平均值 计算BLEU分数 遍历假设序列和参考序列 累加BLEU统计信息 获取验证集BLEU分数 结束

二、核心函数模块解析

2.1 统计信息收集:bleu_stats

python 复制代码
def bleu_stats(hypothesis, reference):
    stats = []
    stats.append(len(hypothesis))
    stats.append(len(reference))
    for n in range(1, 5):
        s_ngrams = Counter(
            [tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) + 1 - n)]
        )
        r_ngrams = Counter(
            [tuple(reference[i:i + n]) for i in range(len(reference) + 1 - n)]
        )
        stats.append(max([sum((s_ngrams & r_ngrams).values()), 0]))
        stats.append(max([len(hypothesis) + 1 - n, 0]))
    return stats 

数据流可视化
假设序列 长度统计 参考序列 长度统计 n=1 1-gram统计 1-gram匹配 n=2 2-gram统计 2-gram匹配 n=3 3-gram统计 3-gram匹配 n=4 4-gram统计 4-gram匹配 统计结果 统计信息列表

2.2 BLEU分数计算:bleu

python 复制代码
def bleu(stats):
    if len(list(filter(lambda x: x == 0, stats))) > 0:
        return 0 
    (c, r) = stats[:2]
    log_bleu_prec = sum(
        [math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])]
    ) / 4.
    return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec)

计算流程分解
1. 零值检查: 确保所有统计信息非零
2. 长度提取: 获取假设序列长度c和参考序列长度r
3. 精度计算: 计算n-gram精度的几何平均值的对数
4. 惩罚因子: 引入长度惩罚因子 m i n ( 0 , 1 − r c ) min(0,1-\frac{r}{c}) min(0,1−cr)
5. 指数转换: 将对数结果转换回概率值

2.3 验证集BLEU计算:get_bleu

python 复制代码
def get_bleu(hypotheses, reference):
    stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
    for hyp, ref in zip(hypotheses, reference):
        stats += np.array(bleu_stats(hyp, ref))
    return 100 * bleu(stats)

批处理流程
初始化统计数组 遍历假设-参考对 累加统计信息 计算最终BLEU分数 转换为百分比

2.4 索引转换:idx_to_word

python 复制代码
def idx_to_word(x, vocab):
    words = []
    for i in x:
        word = vocab.itos[i]
        if '<' not in word:
            words.append(word)
    words = " ".join(words)
    return words 

转换机制
索引列表 遍历索引 查询词汇表 过滤特殊符号 构建单词列表 生成字符串

三、数学运算可视化推演

3.1 示例数据

复制代码
假设序列: [1, 2, 3, 4]
参考序列: [1, 2, 4, 5]

3.2 统计信息生成

统计项 计算过程 结果值
假设长度 len(hypothesis) 4
参考长度 len(reference) 4
1-gram匹配 sum((s_ngrams & r_ngrams).values()) 3
1-gram总数 len(hypothesis) + 1 - 1 4
2-gram匹配 sum((s_ngrams & r_ngrams).values()) 2
2-gram总数 len(hypothesis) + 1 - 2 3
3-gram匹配 sum((s_ngrams & r_ngrams).values()) 1
3-gram总数 len(hypothesis) + 1 - 3 2
4-gram匹配 sum((s_ngrams & r_ngrams).values()) 0
4-gram总数 len(hypothesis) + 1 - 4 1

3.3 BLEU分数计算

1. 精度计算:

  • 1-gram: 3/4 = 0.75
  • 2-gram: 2/3 ≈ 0.6667
  • 3-gram: 1/2 = 0.5
  • 4-gram: 0/1 = 0

2. 几何平均:

  • ( 0.75 × 0.6667 × 0.5 × 0 ) 1 4 = 0 (0.75 × 0.6667 × 0.5 × 0)^{\frac{1}{4}} = 0 (0.75×0.6667×0.5×0)41=0

3. 长度惩罚:

  • m i n ( 0 , 1 − 4 / 4 ) = 0 min(0, 1-4/4) = 0 min(0,1−4/4)=0

4. 最终BLEU:

  • e x p ( 0 + l o g ( 0 ) ) = 0 exp(0 + log(0)) = 0 exp(0+log(0))=0

四、工程实践要点

4.1 性能优化

  • 向量化操作: 使用numpy数组提高统计信息累加效率
  • 内存优化: 使用生成器表达式减少内存占用
  • 并行计算: 支持多线程处理大规模数据

4.2 扩展应用

  • 多参考支持: 扩展支持多个参考序列
  • 权重调整: 支持自定义n-gram权重
  • 平滑处理: 引入平滑方法处理零值问题

五、BLEU分数的意义与设计思想

5.1 BLEU的核心意义

BLEU(Bilingual Evaluation Understudy)是机器翻译领域最具代表性的自动评估指标,其核心意义在于通过量化生成文本与参考文本的n-gram匹配程度,为翻译质量提供可计算的客观标准。它的设计初衷是解决人工评估成本高、一致性差的问题,通过以下特性实现高效评估:
1. 自动化与可扩展性: 可快速处理大规模数据,支持模型训练中的实时反馈。
2. 多粒度匹配: 结合1-gram到4-gram的共现统计,兼顾局部词汇匹配和长距离语义连贯性。
3. 长度敏感性: 通过惩罚因子抑制过短译文的得分膨胀,避免逐字翻译的投机行为。

5.2 设计思想与权衡

5.2.1 精度导向的评估逻辑

BLEU以修正后的n-gram精度(Modified n-gram Precision)为基础,统计假设文本中每个n-gram在参考文本中的最大出现次数。例如,若参考文本中出现2次"the cat",而假设文本出现3次,则匹配计数为2。这种设计通过截断计数法(Clipped Count)避免高频词汇的过度奖励,体现了对翻译"准确性"而非"冗余性"的追求。

5.2.2 几何平均与对数空间

对不同n-gram的精度取几何平均(而非算术平均),强化低精度值的惩罚效应。例如,若某n-gram精度为0,则整体分数直接归零,这迫使模型必须全面兼顾不同粒度的匹配。而对数转换(log(float(x)/y))则将乘法运算转化为加法,避免连乘导致的数值下溢问题。

5.2.3 长度惩罚机制

通过惩罚因子 min(0, 1 - r/c) 解决短译文问题:

  • 当假设文本长度(c)小于参考文本(r)时,惩罚因子为 1 - r/c,显著降低分数。
  • 当c ≥ r时,惩罚因子为0,保留原始精度值。此机制迫使模型生成与参考长度匹配的译文,而非通过截断获取高分。

5.3 数学表达解析

BLEU分数的完整计算公式为:

B L E U = B P ⋅ e x p ( ∑ n = 1 N w n l o g p n ) BLEU = BP · exp(∑_{n=1}^N w_n log p_n) BLEU=BP⋅exp(∑n=1Nwnlogpn)

其中:

  • BP(Brevity Penalty): 长度惩罚因子,计算公式为 B P = e m i n ( 0 , 1 − r c ) BP = e^{min(0, 1 - \frac{r}{c})} BP=emin(0,1−cr)
  • p n p_n pn: n-gram精度,计算方式为 匹配的n-gram数 / 假设文本中n-gram总数
  • w n w_n wn: 权重系数,默认取均匀权重1/4(N=4时)

该公式在代码中体现为两个阶段:
1. 统计阶段(bleu_stats函数): 遍历1-4 gram,分别计算匹配数和总数。
2. 合成阶段(bleu函数): 对精度取对数平均,叠加长度惩罚,最终通过指数映射得到0-1区间值。

5.4 优势与局限性

优势

  • 高效客观: 摆脱人工评估的主观性与滞后性,支持模型迭代的快速验证。
  • 多语言通用: 不依赖语言特定规则,适用于任意语言对的翻译评估。
  • 强相关性: 在多数研究中,BLEU与人工评分呈现显著正相关(相关系数约0.6-0.8)。

局限性

  • 语义盲区: 无法捕捉同义词替换、语序调换等语义等价变化。
  • 多样性惩罚: 对创造性译文施加不公平惩罚(例如用近义词替换参考文本中的词汇)。
  • 参考依赖: 分数质量高度依赖参考文本的覆盖度和多样性,多参考文本可部分缓解此问题。

原项目代码+注释(附)

python 复制代码
"""
@author : Hyunwoong
@when : 2019-12-22
@homepage : https://github.com/gusdnd852
"""

import math
from collections import Counter

import numpy as np

# 计算BLEU分数的统计信息
def bleu_stats(hypothesis, reference):
    """计算BLEU分数的统计信息。"""
    stats = []  # 初始化一个空列表用于存储统计信息
    stats.append(len(hypothesis))  # 添加假设序列的长度
    stats.append(len(reference))  # 添加参考序列的长度
    for n in range(1, 5):  # 对于n-gram,n从1到4
        s_ngrams = Counter(
            [tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) + 1 - n)]
        )  # 统计假设序列中所有n-gram出现的次数
        r_ngrams = Counter(
            [tuple(reference[i:i + n]) for i in range(len(reference) + 1 - n)]
        )  # 统计参考序列中所有n-gram出现的次数

        stats.append(max([sum((s_ngrams & r_ngrams).values()), 0]))  # 添加假设序列和参考序列中共有的n-gram的总数
        stats.append(max([len(hypothesis) + 1 - n, 0]))  # 添加假设序列中可能的n-gram的总数
    return stats  # 返回统计信息列表

# 根据统计信息计算BLEU分数
def bleu(stats):
    """根据n-gram统计信息计算BLEU分数。"""
    if len(list(filter(lambda x: x == 0, stats))) > 0:  # 如果统计信息中有任何一项为0,则BLEU分数为0
        return 0
    (c, r) = stats[:2]  # 假设序列长度和参考序列长度
    log_bleu_prec = sum(
        [math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])]
    ) / 4.  # 计算修正后的n-gram精度的几何平均值(对数形式)
    return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec)  # 计算BLEU分数

# 获取验证集的BLEU分数
def get_bleu(hypotheses, reference):
    """获取验证集的BLEU分数。"""
    stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])  # 初始化一个numpy数组用于累加统计信息
    for hyp, ref in zip(hypotheses, reference):  # 遍历假设序列和参考序列
        stats += np.array(bleu_stats(hyp, ref))  # 累加每个假设序列和参考序列的BLEU统计信息
    return 100 * bleu(stats)  # 返回BLEU分数(乘以100转换为百分比形式)

# 将索引转换为单词
def idx_to_word(x, vocab):
    """将索引列表转换为单词序列。"""
    words = []  # 初始化一个空列表用于存储单词
    for i in x:  # 遍历索引列表
        word = vocab.itos[i]  # 根据索引查找对应的单词
        if '<' not in word:  # 如果单词中不包含特殊符号'<',则添加到单词列表中
            words.append(word)
    words = " ".join(words)  # 将单词列表转换为以空格分隔的字符串
    return words  # 返回单词序列
相关推荐
lijianhua_97125 小时前
国内某顶级大学内部用的ai自动生成论文的提示词
人工智能
EDPJ5 小时前
当图像与文本 “各说各话” —— CLIP 中的模态鸿沟与对象偏向
深度学习·计算机视觉
蔡俊锋5 小时前
用AI实现乐高式大型可插拔系统的技术方案
人工智能·ai工程·ai原子能力·ai乐高工程
自然语5 小时前
人工智能之数字生命 认知架构白皮书 第7章
人工智能·架构
大熊背5 小时前
利用ISP离线模式进行分块LSC校正的方法
人工智能·算法·机器学习
eastyuxiao5 小时前
如何在不同的机器上运行多个OpenClaw实例?
人工智能·git·架构·github·php
诸葛务农5 小时前
AGI 主要技术路径及核心技术:归一融合及未来之路5
大数据·人工智能
光影少年5 小时前
AI Agent智能体开发
人工智能·aigc·ai编程
极梦网络无忧5 小时前
OpenClaw 基础使用说明(中文版)
python
codeJinger5 小时前
【Python】操作Excel文件
python·excel