面试题:推导一下softmax中为啥要除以根号d

我整理好的1000+面试题,请看
大模型面试题总结-CSDN博客

或者

https://gitee.com/lilitom/ai_interview_questions/blob/master/README.md

最好将URL复制到浏览器中打开,不然可能无法直接打开


好了,我们今天针对上面的问题,

推导一下softmax中为啥要除以根号d

先看下具体的详细推导,文末有代码来可视化的证明。

  • 1. 定义与假设 这里我们需要事先做一个假设:

  • 和 是两个独立的随机矩阵,维度均为 。

  • 每个元素 和 是独立同分布(i.i.d.)的随机变量,满足:

    • 均值 ,

    • 方差 。

计算 的任意一个元素 =。

  • 2. 计算均值 由于 和 独立且均值为 0:
  • 3. 计算方差 方差定义为:

由于 ,第二项为 0,因此:

展开平方项:

因此:

由于 和 独立:

  • 对于交叉项 :

    因为 和 独立(),且 ,同理 。

  • 对于平方项:

综上:

  • 4. 缩放因子的作用 为了控制方差,使 的方差为 1(便于 Softmax 计算和梯度稳定),我们对其除以 :

此时,方差变为:

  • 5. 直观解释

  • 未缩放时: 的方差随 线性增长,导致 Softmax 输入值过大,输出接近 one-hot 分布,梯度消失。

  • 缩放后:方差稳定为 1,Softmax 输入分布更平滑,梯度传播更稳定。

让LLM写了一段代码

复制代码
import numpy as np
import matplotlib.pyplot as plt

def verify_scaled_dot_product_variance(d_k=64, num_samples=10000):
    """
    验证缩放点积注意力中QK^T的方差计算
    参数:
        d_k: 键/查询向量的维度
        num_samples: 采样次数
    """
    # 存储未缩放和缩放后的方差结果
    variances_unscaled = []
    variances_scaled = []
    
    for _ in range(num_samples):
        # 生成随机矩阵Q和K (元素服从标准正态分布 N(0,1))
        Q = np.random.randn(1, d_k)  # 1 x d_k
        K = np.random.randn(1, d_k)  # 1 x d_k
        
        # 计算点积 QK^T (这里Q和K都是单行向量,点积即内积)
        dot_product = np.dot(Q, K.T)[0, 0]  # 提取标量值
        
        # 缩放后的点积
        scaled_dot_product = dot_product / np.sqrt(d_k)
        
        variances_unscaled.append(dot_product ** 2)  # 因为E[QK^T]=0,方差=E[(QK^T)^2]
        variances_scaled.append(scaled_dot_product ** 2)
    
    # 转换为numpy数组
    variances_unscaled = np.array(variances_unscaled)
    variances_scaled = np.array(variances_scaled)
    
    # 计算统计方差
    empirical_var_unscaled = np.mean(variances_unscaled)
    empirical_var_scaled = np.mean(variances_scaled)
    
    # 理论值
    theoretical_var_unscaled = d_k
    theoretical_var_scaled = 1.0
    
    print(f"理论方差 (未缩放): {theoretical_var_unscaled:.2f}")
    print(f"实际方差 (未缩放): {empirical_var_unscaled:.2f}")
    print(f"理论方差 (缩放后): {theoretical_var_scaled:.2f}")
    print(f"实际方差 (缩放后): {empirical_var_scaled:.2f}")
    
    # 绘制结果
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.hist(variances_unscaled, bins=50, alpha=0.7)
    plt.axvline(theoretical_var_unscaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"未缩放点积方差 (d_k={d_k})")
    plt.xlabel("方差")
    plt.ylabel("频次")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.hist(variances_scaled, bins=50, alpha=0.7, color='orange')
    plt.axvline(theoretical_var_scaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"缩放后点积方差 (除以√{d_k})")
    plt.xlabel("方差")
    plt.legend()
    
    plt.tight_layout()
    plt.show()

# 运行验证 (默认d_k=64)
verify_scaled_dot_product_variance()

# 测试不同d_k值的影响
for d_k in [16, 64, 256]:
    print(f"\n验证 d_k = {d_k}:")
    verify_scaled_dot_product_variance(d_k=d_k, num_samples=10000)

运行结果如下:

看到没,哥哥

相关推荐
HIT_Weston4 小时前
45、【Agent】【OpenCode】本地代理分析(请求&接收回调)
人工智能·agent·opencode
逻辑君4 小时前
认知神经科学研究报告【20260010】
人工智能·深度学习·神经网络·机器学习
星河耀银海5 小时前
远控体验分享:安全与实用性参考
人工智能·安全·微服务
企业架构师老王5 小时前
2026企业架构演进:科普Agent(龙虾)如何从“极客玩具”走向实在Agent规模化落地?
人工智能·ai·架构
GreenTea5 小时前
一文搞懂Harness Engineering与Meta-Harness
前端·人工智能·后端
鬼先生_sir5 小时前
Spring AI Alibaba 1.1.2.2 完整知识点库
人工智能·ai·agent·源码解析·springai
深念Y5 小时前
豆包AI能力集成方案:基于会话管理的API网关设计
人工智能
龙文浩_5 小时前
Attention Mechanism: From Theory to Code
人工智能·深度学习·神经网络·学习·自然语言处理
ulimate_5 小时前
八卡算力、三个Baseline算法(WALLOSS、pi0、DreamZero)
人工智能
深小乐5 小时前
AI 周刊【2026.04.06-04.12】:Anthropic 藏起最强模型、AI 社会矛盾激化、"欢乐马"登顶
人工智能