PyTorch torch.logsumexp Explained: Mathematical Principles, Use Cases, and Performance Optimization

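In deep learning and probabilistic models we often need numerically stable log-probability computations. In tasks such as softmax normalization, log-likelihood computation, and loss-function optimization, summing the exponentials directly and then taking the log can overflow. torch.logsumexp is designed to solve exactly this problem.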


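This article covers:

  • The mathematical principle behind torch.logsumexp
  • Its practical uses
  • Why it is more stable than computing log(sum(exp(x))) directly
  • How to use torch.logsumexp efficiently in PyTorch code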

1. What is torch.logsumexp?

1.1 Mathematical formula

torch.logsumexp(x, dim) computes the following expression:

$$\log \sum_{i} e^{x_i}$$


where
  • $x_i$ are the elements of the input tensor, and
  • dim specifies the dimension along which the reduction is performed.
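To make the role of dim concrete, here is a small sketch on an assumed all-zeros 2x3 tensor, chosen so that every exponential equals 1 and each result is simply the log of the number of elements reduced; keepdim=True keeps the reduced axis with size 1:

```python
import math

import torch

# All zeros: every exp(x_i) is 1, so each reduction equals log(number of elements)
x = torch.zeros(2, 3)

rows = torch.logsumexp(x, dim=1)                      # reduce over columns -> shape (2,)
cols = torch.logsumexp(x, dim=0)                      # reduce over rows -> shape (3,)
rows_kept = torch.logsumexp(x, dim=1, keepdim=True)   # reduced axis kept -> shape (2, 1)

print(rows.shape, cols.shape, rows_kept.shape)
print(rows, math.log(3))  # each entry of rows is log(3)
print(cols, math.log(2))  # each entry of cols is log(2)
```

The keepdim=True form is the one to use when the result must broadcast back against x, for example when subtracting it to normalize.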

1.2 Why not compute log(sum(exp(x))) directly?

Suppose the input contains large values, say x = 1000. If we compute directly:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # tensor(inf): overflow
```

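Problem: exp(1000) is far too large for floating point (in float32, exp(x) already overflows for x above about 88.7), so it becomes inf and the result overflows.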

The torch.logsumexp solution relies on the identity:
$$\log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{x_i - x_{\max}}$$

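  • Core idea: first subtract the maximum $x_{\max}$ (so the exponentials cannot blow up), then take the log of the sum of exponentials and add $x_{\max}$ back.
  • Every shifted exponent $x_i - x_{\max}$ is at most 0, so each term lies in (0, 1] and the sum is at least 1. This avoids overflow and improves numerical stability.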
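A minimal sketch of the trick written out by hand, so each step of the identity is visible. stable_logsumexp is a hypothetical helper name used only for illustration (it is not a PyTorch API), and it is checked against torch.logsumexp on the overflowing input from the previous example:

```python
import torch


def stable_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: log(sum(exp(x))) along `dim` via the max-shift trick.

    Note: unlike torch.logsumexp, this sketch does not special-case inputs
    whose maximum is +/-inf.
    """
    # Step 1: maximum along `dim`, keeping the axis so it broadcasts against x
    x_max = x.amax(dim=dim, keepdim=True)
    # Step 2: every shifted exponent is <= 0, so each exp() lies in (0, 1]
    shifted_sum = torch.exp(x - x_max).sum(dim=dim, keepdim=True)
    # Step 3: add the maximum back outside the log, then drop the kept axis
    return (x_max + torch.log(shifted_sum)).squeeze(dim)


x = torch.tensor([1000.0, 1001.0, 1002.0])
print(stable_logsumexp(x, dim=0))  # finite, about 1002.4076
print(torch.logsumexp(x, dim=0))   # same value from the built-in
```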

Using torch.logsumexp:

```python
log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # a finite value (about 1002.4076) instead of inf
```

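It does not overflow because the maximum is subtracted before exponentiating; only then is the log taken and the maximum added back.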

2. Practical applications of torch.logsumexp

2.1 Computing softmax

The softmax formula:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

Taking the log gives the log-softmax:
$$\log P(x_i) = x_i - \log \sum_{j} e^{x_j}$$

PyTorch code:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
```

This avoids overflow in the exponentials and is more stable than computing the normalizer as torch.log(torch.sum(torch.exp(x))).
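As a sanity check, the logsumexp-based log-softmax can be compared with PyTorch's built-in torch.log_softmax. The large logits below are an assumed example, chosen so that the naive softmax overflows:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# Log-softmax via the log-sum-exp normalizer
log_softmax_manual = x - torch.logsumexp(x, dim=0)

# Built-in reference implementation
log_softmax_builtin = torch.log_softmax(x, dim=0)

# Naive version: exp(x) overflows to inf, and inf / inf is nan
naive = torch.log(torch.exp(x) / torch.exp(x).sum())

print(log_softmax_manual)
print(log_softmax_builtin)
print(naive)  # tensor([nan, nan, nan])
```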

2.2 Computing cross-entropy loss


$$L = -\sum_{i} y_i \log P(x_i)$$

where $P(x_i)$ is obtained from softmax; torch.logsumexp makes the log of the softmax denominator stable to compute.
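A minimal sketch of cross-entropy for integer class labels built on torch.logsumexp: with a one-hot $y$, the sum over $i$ keeps only the log-probability of the true class. The random logits and labels are illustrative assumptions, and the result is compared with torch.nn.functional.cross_entropy using its default mean reduction:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)           # 8 examples, 10 classes (illustrative)
labels = torch.randint(0, 10, (8,))   # true class index for each example

# log P(true class) = logit of the true class - logsumexp over all classes
true_logits = logits.gather(dim=1, index=labels.unsqueeze(1)).squeeze(1)
log_probs = true_logits - torch.logsumexp(logits, dim=1)

# With one-hot targets, -sum_i y_i log P(x_i) reduces to -log P(true class);
# averaging over the batch matches the default reduction="mean"
manual_loss = -log_probs.mean()
builtin_loss = F.cross_entropy(logits, labels)

print(manual_loss.item(), builtin_loss.item())
```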

2.3 Use in Transformer models

When training Transformer language models such as GPT and BERT, we usually compute token_log_probs as follows:

```python
import torch

logits = torch.randn(4, 5)  # assume batch_size=4, vocab_size=5
input_ids = torch.tensor([1, 2, 3, 4])  # assumed ground-truth token ids (vocabulary indices)

# Compute the log-probability of each token
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_values
```


Here torch.logsumexp(logits, dim=-1) computes the log of the softmax denominator, so the probability computation does not overflow.
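To confirm the gather-then-subtract recipe, this sketch reuses the same shapes (random logits, assumed token ids) and compares it with torch.log_softmax followed by the same gather. It also checks that each row of the full distribution sums to 1:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 5)               # 4 positions, vocab_size=5 (illustrative)
input_ids = torch.tensor([1, 2, 3, 4])   # assumed ground-truth token ids

# Recipe from the text: gathered logit minus the log of the softmax denominator
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
token_log_probs = token_logits - torch.logsumexp(logits, dim=-1)

# Reference: full log-softmax over the vocabulary, then pick the same entries
reference = torch.log_softmax(logits, dim=-1).gather(
    dim=-1, index=input_ids.unsqueeze(-1)
).squeeze(-1)

# Every row of the normalized distribution should sum to 1
row_sums = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).exp().sum(dim=-1)

print(token_log_probs)
print(reference)
print(row_sums)
```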

3. Performance optimization with torch.logsumexp

3.1 Why is torch.logsumexp faster than log(sum(exp(x)))?

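  • One call instead of three: the naive version issues separate exp, sum, and log operations, each materializing its own result, whereas logsumexp is a single operator implemented in C++/CUDA. It still has to find the maximum and subtract it internally, so the actual speedup depends on the workload and hardware.
  • Fewer numerical failures: no intermediate exponential can overflow, so the result stays finite for large inputs, and so does the gradient (which equals softmax(x)) instead of turning into nan.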
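To illustrate the gradient point in the list above: the derivative of log-sum-exp with respect to $x_i$ is softmax$(x)_i$. In this sketch (the inputs are an assumed example), autograd through torch.logsumexp returns that finite softmax, while the naive composition back-propagates through exp(x) = inf and yields nan:

```python
import torch

x_stable = torch.tensor([1000.0, 1001.0, 1002.0], requires_grad=True)
x_naive = torch.tensor([1000.0, 1001.0, 1002.0], requires_grad=True)

# Stable path: the gradient is softmax(x), which is finite
torch.logsumexp(x_stable, dim=0).backward()

# Naive path: exp(x) overflows to inf; backward multiplies 1/inf = 0 by inf, giving nan
torch.log(torch.sum(torch.exp(x_naive))).backward()

print(x_stable.grad)  # matches torch.softmax(x, dim=0)
print(x_naive.grad)   # all nan
```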

3.2 Measured performance

```python
import time

import torch

x = torch.randn(1000000)

start = time.perf_counter()
torch.logsumexp(x, dim=0)
end = time.perf_counter()
print(f"torch.logsumexp: {end - start:.6f} s")

start = time.perf_counter()
torch.log(torch.sum(torch.exp(x)))  # the naive version being compared
end = time.perf_counter()
print(f"log(sum(exp(x))): {end - start:.6f} s")
```


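Output reported by the author (a single run; timings vary with hardware, PyTorch version, and warm-up):

```text
torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s
```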

In this measurement torch.logsumexp was faster, and in every case it avoids the overflow that exp(x) can cause.

4. Summary

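  • torch.logsumexp(x, dim) computes log(sum(exp(x))) with a numerically stable method that prevents overflow.
  • Common applications:
    • Softmax computation
    • Cross-entropy loss
    • Token log-probability computation in language models
  • It is more stable than log(sum(exp(x))), and was faster in the benchmark above, which makes it well suited to large-scale deep learning tasks.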


🚀 Whenever you need log(sum(exp(x))), prefer torch.logsumexp: it greatly improves numerical stability and can improve computational efficiency! 🚀

Understanding torch.logsumexp: Mathematical Foundation, Use Cases, and Performance Optimization

In deep learning, and especially in probabilistic models, computing log probabilities in a numerically stable way is crucial. Directly evaluating log(sum(exp(x))) can lead to numerical instability due to floating-point overflow. torch.logsumexp is designed to solve this problem efficiently.

In this article, we will cover:

  • The mathematical foundation of torch.logsumexp
  • Why it is useful and how it prevents numerical instability
  • Key applications in deep learning
  • Performance optimization compared to naive log(sum(exp(x)))

1. What is torch.logsumexp?

1.1 Mathematical Formula

torch.logsumexp(x, dim) computes the following function:

$$\log \sum_{i} e^{x_i}$$


where
  • $x_i$ represents the elements of the input tensor, and
  • dim specifies the dimension along which to perform the operation.

1.2 Why Not Directly Compute log(sum(exp(x)))?

Consider the example $x = [1000, 1001, 1002]$. If we naively compute:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # Output: inf (overflow)
```


  • exp(1000) is too large: it exceeds the floating-point range, causing an overflow.

Solution: Log-Sum-Exp Trick

To prevent overflow, torch.logsumexp applies the following transformation:

$$\log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{x_i - x_{\max}}$$

where $x_{\max}$ is the maximum value in $x$.

  • By subtracting $x_{\max}$ first, every exponent is at most 0, so the exponentials lie in (0, 1] and cannot overflow.
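The identity above follows from factoring $e^{x_{\max}}$ out of the sum; a short derivation:

```latex
\begin{aligned}
\log \sum_{i} e^{x_i}
  &= \log \Big( e^{x_{\max}} \sum_{i} e^{x_i - x_{\max}} \Big) \\
  &= \log e^{x_{\max}} + \log \sum_{i} e^{x_i - x_{\max}} \\
  &= x_{\max} + \log \sum_{i} e^{x_i - x_{\max}}
\end{aligned}
```

Since $x_i - x_{\max} \le 0$ for every $i$, and the term where $x_i = x_{\max}$ contributes $e^0 = 1$, the remaining sum over $n$ elements lies in $[1, n]$. Its log is therefore in $[0, \log n]$: it can neither overflow nor collapse to $\log 0$.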

Example using torch.logsumexp:

```python
log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # Outputs a valid value without overflow
```

This is more numerically stable.

2. Key Applications of torch.logsumexp

2.1 Computing Softmax in Log Space

The Softmax function is defined as:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

Taking the log:

$$\log P(x_i) = x_i - \log \sum_{j} e^{x_j}$$

Using PyTorch:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
```

This never exponentiates the raw values of x: only the shifted values $x_i - x_{\max}$ are exponentiated internally, which prevents overflow.
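For a batch of logits (2-D input), the normalizer has to broadcast back across the class dimension, which is what keepdim=True is for. The shapes and scale below are an assumed example. Without keepdim, a (4, 3) tensor minus a (4,) tensor fails to broadcast, and for a square batch it would silently broadcast along the wrong axis:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 3) * 100   # batch of 4, 3 classes; large scale on purpose

# keepdim=True gives shape (4, 1), which broadcasts across the 3 classes
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

# Compare with the built-in log_softmax over the same dimension
reference = torch.log_softmax(logits, dim=-1)

print(log_probs.shape)              # torch.Size([4, 3])
print(log_probs.exp().sum(dim=-1))  # each row sums to (about) 1
```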

2.2 Cross-Entropy Loss Computation

Cross-entropy loss:

$$L = -\sum_{i} y_i \log P(x_i)$$

where $P(x_i)$ is computed using softmax.

Using torch.logsumexp, we avoid overflow in the denominator:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.1]])
logsumexp_values = torch.logsumexp(logits, dim=-1)  # log of the softmax denominator, per row
```

The same log-sum-exp trick underlies torch.nn.CrossEntropyLoss, which combines log-softmax and the negative log-likelihood loss in a single criterion.
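The formula $L = -\sum_i y_i \log P(x_i)$ also covers soft labels, where $y$ is a probability distribution rather than one-hot. A minimal sketch, assuming the illustrative target below and a PyTorch version (1.10 or later) whose torch.nn.functional.cross_entropy accepts class-probability targets of the same shape as the logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])
y = torch.tensor([[0.7, 0.2, 0.1]])  # soft target distribution (illustrative)

# log P(x_i) = x_i - logsumexp(x), keeping the dim so it broadcasts over classes
log_p = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

# L = -sum_i y_i log P(x_i), averaged over the batch (batch size 1 here)
manual_loss = -(y * log_p).sum(dim=-1).mean()

# Built-in: probability targets have the same shape as the logits
builtin_loss = F.cross_entropy(logits, y)

print(manual_loss.item(), builtin_loss.item())
```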

2.3 Token Log Probabilities in Transformer Models

In language models like GPT, BERT, LLaMA, computing token log probabilities is crucial:

```python
import torch

logits = torch.randn(4, 5)  # Simulated logits for 4 tokens, vocab size 5
input_ids = torch.tensor([1, 2, 3, 4])  # Ground-truth token ids (vocabulary indices)

# Gather the logits corresponding to the actual tokens
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# Compute log probabilities
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_values
```


Here torch.logsumexp computes the log of the softmax denominator for each row, so the log-probabilities stay finite even when the logits are large.
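Real language-model logits usually carry a sequence dimension as well. A sketch under assumed shapes (2 sequences of 6 tokens, vocabulary of 50; the random logits and token ids are placeholders) that computes per-token log-probabilities along the vocabulary axis and sums them into a per-sequence log-likelihood:

```python
import torch

torch.manual_seed(0)
batch, seq_len, vocab = 2, 6, 50
logits = torch.randn(batch, seq_len, vocab)            # (B, T, V)
token_ids = torch.randint(0, vocab, (batch, seq_len))  # (B, T) placeholder targets

# Logit of each target token: gather along the vocabulary axis -> (B, T)
token_logits = logits.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)

# Per-token log-probabilities; logsumexp reduces only the vocabulary axis -> (B, T)
token_log_probs = token_logits - torch.logsumexp(logits, dim=-1)

# Log-likelihood of each whole sequence = sum of its token log-probabilities -> (B,)
seq_log_likelihood = token_log_probs.sum(dim=-1)

print(token_log_probs.shape, seq_log_likelihood.shape)
```

For causal models such as GPT, the logits at position t are normally scored against the token at position t+1; that shift is left out of this sketch.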

3. Performance Optimization

3.1 Why is torch.logsumexp Faster?

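Instead of the naive version:

```python
torch.log(torch.sum(torch.exp(x)))
```

which

  1. computes exp(x), creating an intermediate tensor,
  2. sums that tensor, and
  3. takes the log of the sum,

torch.logsumexp:

  • issues a single call instead of three separate operations;
  • runs as one operator implemented at the C++/CUDA level (internally it still computes the maximum and subtracts it, so any speedup depends on the workload);
  • improves numerical stability.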

3.2 Performance Benchmark

```python
import time

import torch

x = torch.randn(1000000)

start = time.perf_counter()
torch.logsumexp(x, dim=0)
end = time.perf_counter()
print(f"torch.logsumexp: {end - start:.6f} s")

start = time.perf_counter()
torch.log(torch.sum(torch.exp(x)))  # the naive version being compared
end = time.perf_counter()
print(f"log(sum(exp(x))): {end - start:.6f} s")
```


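Output reported by the author (a single run; timings vary with hardware, PyTorch version, and warm-up):

```text
torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s
```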

In the author's measurement torch.logsumexp was significantly faster, and it is more stable in every case.
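Single time.time() measurements are noisy, especially for sub-millisecond calls. A more careful sketch uses torch.utils.benchmark.Timer, which repeats each statement and reports summary statistics. The tensor size mirrors the example above; no particular speedup is assumed, since the outcome depends on hardware, PyTorch version, and device:

```python
import torch
from torch.utils import benchmark

x = torch.randn(1_000_000)

stable = benchmark.Timer(
    stmt="torch.logsumexp(x, dim=0)",
    globals={"x": x, "torch": torch},
    label="log-sum-exp",
    description="torch.logsumexp",
)
naive = benchmark.Timer(
    stmt="torch.log(torch.sum(torch.exp(x)))",
    globals={"x": x, "torch": torch},
    label="log-sum-exp",
    description="log(sum(exp(x)))",
)

# blocked_autorange() keeps running the statement until min_run_time has elapsed
m_stable = stable.blocked_autorange(min_run_time=0.2)
m_naive = naive.blocked_autorange(min_run_time=0.2)

print(m_stable)
print(m_naive)
print(f"median: logsumexp {m_stable.median * 1e3:.3f} ms, naive {m_naive.median * 1e3:.3f} ms")
```

On a GPU, Timer also handles CUDA synchronization, which a plain time.time() pair does not.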

4. Summary

  • torch.logsumexp(x, dim) computes log(sum(exp(x))) safely, preventing overflow.
  • Used in:
    • Softmax computation
    • Cross-entropy loss
    • Probability calculations in LLMs (e.g., GPT, BERT)
  • Implemented as a single C++/CUDA-backed operator, and faster than the naive log(sum(exp(x))) in the benchmark above.

🚀 Always prefer torch.logsumexp for numerical stability and better performance in deep learning models! 🚀



