【深度学习】序列生成模型(四):评价方法

文章目录

  • 一、困惑度(Perplexity)
    • [1. 定义](#1. 定义)
    • [2. 计算](#2. 计算)
    • [3. 衡量两个分布之间的差异](#3. 衡量两个分布之间的差异)
    • [4. 意义](#4. 意义)
  • [二、BLEU(Bilingual Evaluation Understudy)](#二、BLEU(Bilingual Evaluation Understudy))
    • [1. 定义](#1. 定义)
    • [2. 意义](#2. 意义)
    • [3. 实例](#3. 实例)
  • [三、ROUGE(Recall-Oriented Understudy for Gisting Evaluation)](#三、ROUGE(Recall-Oriented Understudy for Gisting Evaluation))
    • [1. 定义](#1. 定义)
    • [2. 意义](#2. 意义)
    • [3. 实例](#3. 实例)
  • 四、人工评估

构建序列生成模型后,为了评价其性能,通常采用一些度量方法。本文将介绍一些常见的评价方法:

一、困惑度(Perplexity)

困惑度(Perplexity)是一种用来衡量序列生成模型性能的指标。在给定一个测试文本集合的情况下,一个好的序列生成模型应该使得测试集合中句子的联合概率尽可能高。困惑度是信息论中的一个概念,用来度量一个分布的不确定性。

1. 定义

对于离散随机变量 X ∈ X X \in \mathcal{X} X∈X,其概率分布为 p ( x ) p(x) p(x),困惑度定义如下:
Perplexity ( X , p ) = 2 H ( p ) = 2 − ∑ x ∈ X p ( x ) log ⁡ 2 p ( x ) \text{Perplexity}(\mathcal{X}, p) = 2^{H(p)}=2^{- \sum_{x \in \mathcal{X}} p(x) \log_2 p(x)} Perplexity(X,p)=2H(p)=2−∑x∈Xp(x)log2p(x)这里的熵 H ( p ) H(p) H(p) 衡量了分布 p p p 的不确定性。困惑度可以看作是对观察到的数据集的估计概率的逆

2. 计算

考虑一个序列长度为 (T) 的测试集,模型的困惑度为:

PPL ( θ ) = 2 − 1 T ∑ n = 1 N ∑ t = 1 T n log ⁡ 2 p θ ( x n , t ∣ x n , 1 t − 1 ) \text{PPL}(\theta) = 2^{- \frac{1}{T} \sum_{n=1}^{N} \sum_{t=1}^{T_n} \log_2 p_\theta(x_{n,t} | x_{n,1}^{t-1})} PPL(θ)=2−T1∑n=1N∑t=1Tnlog2pθ(xn,t∣xn,1t−1)

其中 N N N 为测试集中序列的数量, T n T_n Tn 为第 n n n 个序列的长度, p θ p_\theta pθ 是模型对条件概率的估计。困惑度越低,表示模型在给定数据上的拟合越好。

3. 衡量两个分布之间的差异

对于一个未知的数据分布 p true ( x ) p_{\text{true}}(x) ptrue(x) 和一个模型分布 p θ ( x ) p_\theta(x) pθ(x),困惑度可以用来衡量它们之间的差异。两者之间的交叉熵(cross entropy)为:

Cross Entropy ( p true , p θ ) = − 1 T ∑ n = 1 N ∑ t = 1 T n log ⁡ 2 p θ ( x n , t ) \text{Cross Entropy}(p_{\text{true}}, p_\theta) = -\frac{1}{T} \sum_{n=1}^{N} \sum_{t=1}^{T_n} \log_2 p_\theta(x_{n,t}) Cross Entropy(ptrue,pθ)=−T1n=1∑Nt=1∑Tnlog2pθ(xn,t)

困惑度可以表示为交叉熵的形式:

PPL ( θ ) = 2 Cross Entropy ( p true , p θ ) / T \text{PPL}(\theta) = 2^{\text{Cross Entropy}(p_{\text{true}}, p_\theta) / T} PPL(θ)=2Cross Entropy(ptrue,pθ)/T

困惑度越低,表示模型分布与真实数据分布越接近。

4. 意义

困惑度为每个词条件概率的几何平均数的倒数。测试集中所有序列的概率越大,困惑度越小,模型越好。一般情况下,困惑度范围在50到1000之间。在自然语言处理中,困惑度是一个常用的评估指标,用于衡量语言模型的性能。

二、BLEU(Bilingual Evaluation Understudy)

BLEU(BiLingual Evaluation Understudy)算法是一种用于衡量机器翻译模型或其他序列生成任务中生成序列和参考序列之间的相似度的评价指标。该算法通过计算N元词组(N-Gram)的重合度来评估生成序列的质量。

1. 定义

设 𝒙 为模型生成的候选序列, s ( 1 ) , ⋯ , s ( K ) \mathbf{s^{(1)}}, ⋯ , \mathbf{s^{(K)}} s(1),⋯,s(K) 为一组参考序列,𝒲 为从生成的候选序列中提取所有N元组合的集合。BLEU算法的精度(Precision)定义如下:

P N ( x ) = ∑ w ∈ W min ⁡ ( c w ( x ) , max ⁡ k = 1 K c w ( s k ) ) ∑ w ∈ W c w ( x ) P_N(\mathbf{x}) = \frac{\sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), \max_{k=1}^{K} c_w(\mathbf{s}k))}{\sum{w \in \mathcal{W}} c_w(\mathbf{x})} PN(x)=∑w∈Wcw(x)∑w∈Wmin(cw(x),maxk=1Kcw(sk))

其中 c w ( x ) c_w(\mathbf{x}) cw(x) 是N元组合 w w w 在生成序列 x \mathbf{x} x 中出现的次数, c w ( s k ) c_w(\mathbf{s}_k) cw(sk) 是N元组合 w w w 在参考序列 s k \mathbf{s}_k sk 中出现的次数。

为了处理生成序列长度短于参考序列的情况,引入长度惩罚因子 b ( x ) b(\mathbf{x}) b(x):

b ( x ) = { 1 if l x > l s exp ⁡ ( 1 − l s l x ) if l x ≤ l s b(\mathbf{x}) = \begin{cases} 1 & \text{if } l_x > l_s \\ \exp\left(1 - \frac{l_s}{l_x}\right) & \text{if } l_x \leq l_s \end{cases} b(x)={1exp(1−lxls)if lx>lsif lx≤ls

其中 l x l_x lx 是生成序列的长度, l s l_s ls 是参考序列的最短长度。

BLEU算法通过计算不同长度的N元组合的精度,并进行几何加权平均,得到最终的BLEU分数:

BLEU-N ( x ) = b ( x ) × exp ⁡ ( ∑ N = 1 N ′ α N log ⁡ P N ( x ) ) \text{BLEU-N}(\mathbf{x}) = b(\mathbf{x}) \times \exp\left(\sum_{N=1}^{N'} \alpha_N \log P_N(\mathbf{x})\right) BLEU-N(x)=b(x)×exp N=1∑N′αNlogPN(x)

其中 N ′ N' N′ 为最长N元组合的长度, α N \alpha_N αN 是不同N元组合的权重,一般设为 1 / N ′ 1/N' 1/N′。

2. 意义

  • BLEU算法的值域范围是 [0, 1],值越大表示生成的序列与参考序列越相似,质量越高。
  • BLEU只关注精度,不考虑召回率,即不关心参考序列中的N元组合是否在生成序列中出现。

3. 实例

【深度学习】序列生成模型(五):评价方法计算实例:计算BLEU-N得分【理论到程序】

python 复制代码
main_string = 'the cat sat on the mat'
string1 = 'the cat is on the mat'
string2 = 'the bird sat on the bush'

# 计算单词
unique_words = set(main_string.split())
total_occurrences, matching_occurrences = 0, 0

for word in unique_words:
    count_main_string = main_string.count(word)
    total_occurrences += count_main_string
    matching_occurrences += min(count_main_string, max(string1.count(word), string2.count(word)))

similarity_word = matching_occurrences / total_occurrences
print(f"N=1: {similarity_word}")

# 计算双词
word_tokens = main_string.split()
bigrams = set([f"{word_tokens[i]} {word_tokens[i + 1]}" for i in range(len(word_tokens) - 1)])
total_occurrences, matching_occurrences = 0, 0

for bigram in bigrams:
    count_main_string = main_string.count(bigram)
    total_occurrences += count_main_string
    matching_occurrences += min(count_main_string, max(string1.count(bigram), string2.count(bigram)))

similarity_bigram = matching_occurrences / total_occurrences
print(f"N=2: {similarity_bigram}")

三、ROUGE(Recall-Oriented Understudy for Gisting Evaluation)

ROUGE(Recall-Oriented Understudy for Gisting Evaluation)算法最初被应用于文本摘要领域,类似于BLEU算法,但ROUGE算法关注的是召回率(Recall)。

1. 定义

设 x \mathbf{x} x 为从模型分布 p θ p_{\theta} pθ 中生成的一个候选序列, s ( 1 ) , ⋯ , s ( K ) \mathbf{s^{(1)}}, ⋯ , \mathbf{s^{(K)}} s(1),⋯,s(K) 为从真实数据分布中采样得到的一组参考序列, W \mathcal{W} W 为从参考序列中提取N元组合的集合,ROUGE-N算法的定义为:

ROUGE-N ( x ) = ∑ k = 1 K ∑ w ∈ W min ⁡ ( c w ( x ) , c w ( s ( k ) ) ) ∑ k = 1 K ∑ w ∈ W c w ( s ( k ) ) \text{ROUGE-N}(\mathbf{x}) = \frac{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), c_w(\mathbf{s}(k)))}{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} c_w(\mathbf{s}(k))} ROUGE-N(x)=∑k=1K∑w∈Wcw(s(k))∑k=1K∑w∈Wmin(cw(x),cw(s(k)))

其中 c w ( x ) c_w(\mathbf{x}) cw(x) 是N元组合 w w w 在生成序列 x \mathbf{x} x 中出现的次数, c w ( s ( k ) ) c_w(\mathbf{s}(k)) cw(s(k)) 是N元组合 w w w 在参考序列 s ( k ) \mathbf{s}(k) s(k) 中出现的次数。

2. 意义

  • ROUGE算法的评价重点是召回率,即生成序列中有多少N元组合与参考序列中的N元组合相匹配。与BLEU算法不同,ROUGE更注重生成序列覆盖参考序列的内容。
  • ROUGE-N可用于评估模型生成的文本与参考文本之间的相似性,尤其在文本摘要等任务中常被使用。

3. 实例

【深度学习】序列生成模型(六):评价方法计算实例:计算ROUGE-N得分【理论到程序】

python 复制代码
main_string = 'the cat sat on the mat'
string1 = 'the cat is on the mat'
string2 = 'the bird sat on the bush'

words = list(set(string1.split(' ')+string2.split(' ')))  # 去除重复元素

total_occurrences, matching_occurrences = 0, 0
for word in words:
    matching_occurrences += min(main_string.count(word), string1.count(word)) + min(main_string.count(word), string2.count(word))
    total_occurrences += string1.count(word) + string2.count(word)

print(matching_occurrences / total_occurrences)

bigrams = []
split1 = string1.split(' ')
for i in range(len(split1) - 1):
    bigrams.append(split1[i] + ' ' + split1[i + 1])

split2 = string2.split(' ')
for i in range(len(split2) - 1):
    bigrams.append(split2[i] + ' ' + split2[i + 1])

bigrams = list(set(bigrams))  # 去除重复元素

total_occurrences, matching_occurrences = 0, 0
for bigram in bigrams:
    matching_occurrences += min(main_string.count(bigram), string1.count(bigram)) + min(main_string.count(bigram), string2.count(bigram))
    total_occurrences += string1.count(bigram) + string2.count(bigram)

print(matching_occurrences / total_occurrences)

四、人工评估

  • 定义: 通过人工评价来获取生成序列的质量,可以包括流畅性、准确性等方面。

  • 解释: 人工评估是一种直观且综合性的评估方法,但相对来说较为主观。

在实际应用中,通常会综合使用多个评价指标,以全面评估生成模型的性能。

相关推荐
卓_尔_不_凡9 分钟前
Pytorch学习---基于经典网络架构ResNet训练花卉图像分类模型
人工智能·分类·数据挖掘
神奇夜光杯18 分钟前
Python酷库之旅-第三方库Pandas(123)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
SEU-WYL22 分钟前
基于神经网络的光线追踪
人工智能·神经网络·计算机视觉
Bill6623 分钟前
OpenCV GUI常用函数详解
人工智能·opencv·计算机视觉
DisonTangor24 分钟前
OpenAI面向开发者继续提高o1系列模型的调用速率 最高每分钟可调用1000次
人工智能
zhangbin_23725 分钟前
【Python机器学习】NLP信息提取——提取人物/事物关系
开发语言·人工智能·python·机器学习·自然语言处理
王豫翔25 分钟前
OpenAl o1论文:Let’s Verify Step by Step 快速解读
人工智能·深度学习·机器学习·chatgpt
xuehaikj31 分钟前
婴儿接触危险物品检测系统源码分享
人工智能·计算机视觉·目标跟踪
crownyouyou1 小时前
第一次安装Pytorch
人工智能·pytorch·python
qq_435070781 小时前
python乱炖6——sum(),指定维度进行求和
pytorch·python·深度学习