Transformer 代码剖析16 - BLEU分数(pytorch实现)

一、BLEU算法全景图

是 否 开始 初始化统计信息列表 添加假设序列长度 添加参考序列长度 n=1 统计假设序列1-gram 统计参考序列1-gram n=2 统计假设序列2-gram 统计参考序列2-gram n=3 统计假设序列3-gram 统计参考序列3-gram n=4 统计假设序列4-gram 统计参考序列4-gram 遍历n=1到4 计算共有n-gram总数 计算假设序列中可能的n-gram总数 结束统计信息收集 根据统计信息计算BLEU分数 检查统计信息中是否有0 BLEU分数为0 计算修正后的n-gram精度几何平均值 计算BLEU分数 遍历假设序列和参考序列 累加BLEU统计信息 获取验证集BLEU分数 结束

二、核心函数模块解析

2.1 统计信息收集:bleu_stats

python 复制代码
def bleu_stats(hypothesis, reference):
    stats = []
    stats.append(len(hypothesis))
    stats.append(len(reference))
    for n in range(1, 5):
        s_ngrams = Counter(
            [tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) + 1 - n)]
        )
        r_ngrams = Counter(
            [tuple(reference[i:i + n]) for i in range(len(reference) + 1 - n)]
        )
        stats.append(max([sum((s_ngrams & r_ngrams).values()), 0]))
        stats.append(max([len(hypothesis) + 1 - n, 0]))
    return stats 

数据流可视化
假设序列 长度统计 参考序列 长度统计 n=1 1-gram统计 1-gram匹配 n=2 2-gram统计 2-gram匹配 n=3 3-gram统计 3-gram匹配 n=4 4-gram统计 4-gram匹配 统计结果 统计信息列表

2.2 BLEU分数计算:bleu

python 复制代码
def bleu(stats):
    if len(list(filter(lambda x: x == 0, stats))) > 0:
        return 0 
    (c, r) = stats[:2]
    log_bleu_prec = sum(
        [math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])]
    ) / 4.
    return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec)

计算流程分解
1. 零值检查: 确保所有统计信息非零
2. 长度提取: 获取假设序列长度c和参考序列长度r
3. 精度计算: 计算n-gram精度的几何平均值的对数
4. 惩罚因子: 引入长度惩罚因子 m i n ( 0 , 1 − r c ) min(0,1-\frac{r}{c}) min(0,1−cr)
5. 指数转换: 将对数结果转换回概率值

2.3 验证集BLEU计算:get_bleu

python 复制代码
def get_bleu(hypotheses, reference):
    stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
    for hyp, ref in zip(hypotheses, reference):
        stats += np.array(bleu_stats(hyp, ref))
    return 100 * bleu(stats)

批处理流程
初始化统计数组 遍历假设-参考对 累加统计信息 计算最终BLEU分数 转换为百分比

2.4 索引转换:idx_to_word

python 复制代码
def idx_to_word(x, vocab):
    words = []
    for i in x:
        word = vocab.itos[i]
        if '<' not in word:
            words.append(word)
    words = " ".join(words)
    return words 

转换机制
索引列表 遍历索引 查询词汇表 过滤特殊符号 构建单词列表 生成字符串

三、数学运算可视化推演

3.1 示例数据

假设序列: [1, 2, 3, 4]
参考序列: [1, 2, 4, 5]

3.2 统计信息生成

统计项 计算过程 结果值
假设长度 len(hypothesis) 4
参考长度 len(reference) 4
1-gram匹配 sum((s_ngrams & r_ngrams).values()) 3
1-gram总数 len(hypothesis) + 1 - 1 4
2-gram匹配 sum((s_ngrams & r_ngrams).values()) 2
2-gram总数 len(hypothesis) + 1 - 2 3
3-gram匹配 sum((s_ngrams & r_ngrams).values()) 1
3-gram总数 len(hypothesis) + 1 - 3 2
4-gram匹配 sum((s_ngrams & r_ngrams).values()) 0
4-gram总数 len(hypothesis) + 1 - 4 1

3.3 BLEU分数计算

1. 精度计算:

  • 1-gram: 3/4 = 0.75
  • 2-gram: 2/3 ≈ 0.6667
  • 3-gram: 1/2 = 0.5
  • 4-gram: 0/1 = 0

2. 几何平均:

  • ( 0.75 × 0.6667 × 0.5 × 0 ) 1 4 = 0 (0.75 × 0.6667 × 0.5 × 0)^{\frac{1}{4}} = 0 (0.75×0.6667×0.5×0)41=0

3. 长度惩罚:

  • m i n ( 0 , 1 − 4 / 4 ) = 0 min(0, 1-4/4) = 0 min(0,1−4/4)=0

4. 最终BLEU:

  • e x p ( 0 + l o g ( 0 ) ) = 0 exp(0 + log(0)) = 0 exp(0+log(0))=0

四、工程实践要点

4.1 性能优化

  • 向量化操作: 使用numpy数组提高统计信息累加效率
  • 内存优化: 使用生成器表达式减少内存占用
  • 并行计算: 支持多线程处理大规模数据

4.2 扩展应用

  • 多参考支持: 扩展支持多个参考序列
  • 权重调整: 支持自定义n-gram权重
  • 平滑处理: 引入平滑方法处理零值问题

五、BLEU分数的意义与设计思想

5.1 BLEU的核心意义

BLEU(Bilingual Evaluation Understudy)是机器翻译领域最具代表性的自动评估指标,其核心意义在于通过量化生成文本与参考文本的n-gram匹配程度,为翻译质量提供可计算的客观标准。它的设计初衷是解决人工评估成本高、一致性差的问题,通过以下特性实现高效评估:
1. 自动化与可扩展性: 可快速处理大规模数据,支持模型训练中的实时反馈。
2. 多粒度匹配: 结合1-gram到4-gram的共现统计,兼顾局部词汇匹配和长距离语义连贯性。
3. 长度敏感性: 通过惩罚因子抑制过短译文的得分膨胀,避免逐字翻译的投机行为。

5.2 设计思想与权衡

5.2.1 精度导向的评估逻辑

BLEU以修正后的n-gram精度(Modified n-gram Precision)为基础,统计假设文本中每个n-gram在参考文本中的最大出现次数。例如,若参考文本中出现2次"the cat",而假设文本出现3次,则匹配计数为2。这种设计通过截断计数法(Clipped Count)避免高频词汇的过度奖励,体现了对翻译"准确性"而非"冗余性"的追求。

5.2.2 几何平均与对数空间

对不同n-gram的精度取几何平均(而非算术平均),强化低精度值的惩罚效应。例如,若某n-gram精度为0,则整体分数直接归零,这迫使模型必须全面兼顾不同粒度的匹配。而对数转换(log(float(x)/y))则将乘法运算转化为加法,避免连乘导致的数值下溢问题。

5.2.3 长度惩罚机制

通过惩罚因子 min(0, 1 - r/c) 解决短译文问题:

  • 当假设文本长度(c)小于参考文本(r)时,惩罚因子为 1 - r/c,显著降低分数。
  • 当c ≥ r时,惩罚因子为0,保留原始精度值。此机制迫使模型生成与参考长度匹配的译文,而非通过截断获取高分。

5.3 数学表达解析

BLEU分数的完整计算公式为:

B L E U = B P ⋅ e x p ( ∑ n = 1 N w n l o g p n ) BLEU = BP · exp(∑_{n=1}^N w_n log p_n) BLEU=BP⋅exp(∑n=1Nwnlogpn)

其中:

  • BP(Brevity Penalty): 长度惩罚因子,计算公式为 B P = e m i n ( 0 , 1 − r c ) BP = e^{min(0, 1 - \frac{r}{c})} BP=emin(0,1−cr)
  • p n p_n pn: n-gram精度,计算方式为 匹配的n-gram数 / 假设文本中n-gram总数
  • w n w_n wn: 权重系数,默认取均匀权重1/4(N=4时)

该公式在代码中体现为两个阶段:
1. 统计阶段(bleu_stats函数): 遍历1-4 gram,分别计算匹配数和总数。
2. 合成阶段(bleu函数): 对精度取对数平均,叠加长度惩罚,最终通过指数映射得到0-1区间值。

5.4 优势与局限性

优势

  • 高效客观: 摆脱人工评估的主观性与滞后性,支持模型迭代的快速验证。
  • 多语言通用: 不依赖语言特定规则,适用于任意语言对的翻译评估。
  • 强相关性: 在多数研究中,BLEU与人工评分呈现显著正相关(相关系数约0.6-0.8)。

局限性

  • 语义盲区: 无法捕捉同义词替换、语序调换等语义等价变化。
  • 多样性惩罚: 对创造性译文施加不公平惩罚(例如用近义词替换参考文本中的词汇)。
  • 参考依赖: 分数质量高度依赖参考文本的覆盖度和多样性,多参考文本可部分缓解此问题。

原项目代码+注释(附)

python 复制代码
"""
@author : Hyunwoong
@when : 2019-12-22
@homepage : https://github.com/gusdnd852
"""

import math
from collections import Counter

import numpy as np

# 计算BLEU分数的统计信息
def bleu_stats(hypothesis, reference):
    """计算BLEU分数的统计信息。"""
    stats = []  # 初始化一个空列表用于存储统计信息
    stats.append(len(hypothesis))  # 添加假设序列的长度
    stats.append(len(reference))  # 添加参考序列的长度
    for n in range(1, 5):  # 对于n-gram,n从1到4
        s_ngrams = Counter(
            [tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) + 1 - n)]
        )  # 统计假设序列中所有n-gram出现的次数
        r_ngrams = Counter(
            [tuple(reference[i:i + n]) for i in range(len(reference) + 1 - n)]
        )  # 统计参考序列中所有n-gram出现的次数

        stats.append(max([sum((s_ngrams & r_ngrams).values()), 0]))  # 添加假设序列和参考序列中共有的n-gram的总数
        stats.append(max([len(hypothesis) + 1 - n, 0]))  # 添加假设序列中可能的n-gram的总数
    return stats  # 返回统计信息列表

# 根据统计信息计算BLEU分数
def bleu(stats):
    """根据n-gram统计信息计算BLEU分数。"""
    if len(list(filter(lambda x: x == 0, stats))) > 0:  # 如果统计信息中有任何一项为0,则BLEU分数为0
        return 0
    (c, r) = stats[:2]  # 假设序列长度和参考序列长度
    log_bleu_prec = sum(
        [math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])]
    ) / 4.  # 计算修正后的n-gram精度的几何平均值(对数形式)
    return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec)  # 计算BLEU分数

# 获取验证集的BLEU分数
def get_bleu(hypotheses, reference):
    """获取验证集的BLEU分数。"""
    stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])  # 初始化一个numpy数组用于累加统计信息
    for hyp, ref in zip(hypotheses, reference):  # 遍历假设序列和参考序列
        stats += np.array(bleu_stats(hyp, ref))  # 累加每个假设序列和参考序列的BLEU统计信息
    return 100 * bleu(stats)  # 返回BLEU分数(乘以100转换为百分比形式)

# 将索引转换为单词
def idx_to_word(x, vocab):
    """将索引列表转换为单词序列。"""
    words = []  # 初始化一个空列表用于存储单词
    for i in x:  # 遍历索引列表
        word = vocab.itos[i]  # 根据索引查找对应的单词
        if '<' not in word:  # 如果单词中不包含特殊符号'<',则添加到单词列表中
            words.append(word)
    words = " ".join(words)  # 将单词列表转换为以空格分隔的字符串
    return words  # 返回单词序列
相关推荐
IT猿手1 小时前
2025最新群智能优化算法:山羊优化算法(Goat Optimization Algorithm, GOA)求解23个经典函数测试集,MATLAB
人工智能·python·算法·数学建模·matlab·智能优化算法
萧鼎2 小时前
深入解析 Umi-OCR:高效的免费开源 OCR 文字识别工具
python·ocr·umi-ocr
Jet45052 小时前
玩转ChatGPT:GPT 深入研究功能
人工智能·gpt·chatgpt·deep research·深入研究
毕加锁2 小时前
chatgpt完成python提取PDF简历指定内容的案例
人工智能·chatgpt
Wis4e4 小时前
基于PyTorch的深度学习3——基于autograd的反向传播
人工智能·pytorch·深度学习
西猫雷婶4 小时前
神经网络|(十四)|霍普菲尔德神经网络-Hebbian训练
人工智能·深度学习·神经网络
梦丶晓羽5 小时前
自然语言处理:文本分类
人工智能·python·自然语言处理·文本分类·朴素贝叶斯·逻辑斯谛回归
SuperCreators5 小时前
DeepSeek与浏览器自动化AI Agent构建指南
人工智能·自动化
美狐美颜sdk6 小时前
什么是美颜SDK?从几何变换到深度学习驱动的美颜算法详解
人工智能·深度学习·算法·美颜sdk·第三方美颜sdk·视频美颜sdk·美颜api
訾博ZiBo6 小时前
AI日报 - 2025年3月10日
人工智能