面试题:推导一下softmax中为啥要除以根号d

我整理好的1000+面试题,请看
大模型面试题总结-CSDN博客

或者

https://gitee.com/lilitom/ai_interview_questions/blob/master/README.md

最好将URL复制到浏览器中打开,不然可能无法直接打开


好了,我们今天针对上面的问题,

推导一下softmax中为啥要除以根号d

先看下具体的详细推导,文末有代码来可视化的证明。

  • 1. 定义与假设 这里我们需要事先做一个假设:

  • 和 是两个独立的随机矩阵,维度均为 。

  • 每个元素 和 是独立同分布(i.i.d.)的随机变量,满足:

    • 均值 ,

    • 方差 。

计算 的任意一个元素 =。

  • 2. 计算均值 由于 和 独立且均值为 0:
  • 3. 计算方差 方差定义为:

由于 ,第二项为 0,因此:

展开平方项:

因此:

由于 和 独立:

  • 对于交叉项 :

    因为 和 独立(),且 ,同理 。

  • 对于平方项:

综上:

  • 4. 缩放因子的作用 为了控制方差,使 的方差为 1(便于 Softmax 计算和梯度稳定),我们对其除以 :

此时,方差变为:

  • 5. 直观解释

  • 未缩放时: 的方差随 线性增长,导致 Softmax 输入值过大,输出接近 one-hot 分布,梯度消失。

  • 缩放后:方差稳定为 1,Softmax 输入分布更平滑,梯度传播更稳定。

让LLM写了一段代码

复制代码
import numpy as np
import matplotlib.pyplot as plt

def verify_scaled_dot_product_variance(d_k=64, num_samples=10000):
    """
    验证缩放点积注意力中QK^T的方差计算
    参数:
        d_k: 键/查询向量的维度
        num_samples: 采样次数
    """
    # 存储未缩放和缩放后的方差结果
    variances_unscaled = []
    variances_scaled = []
    
    for _ in range(num_samples):
        # 生成随机矩阵Q和K (元素服从标准正态分布 N(0,1))
        Q = np.random.randn(1, d_k)  # 1 x d_k
        K = np.random.randn(1, d_k)  # 1 x d_k
        
        # 计算点积 QK^T (这里Q和K都是单行向量,点积即内积)
        dot_product = np.dot(Q, K.T)[0, 0]  # 提取标量值
        
        # 缩放后的点积
        scaled_dot_product = dot_product / np.sqrt(d_k)
        
        variances_unscaled.append(dot_product ** 2)  # 因为E[QK^T]=0,方差=E[(QK^T)^2]
        variances_scaled.append(scaled_dot_product ** 2)
    
    # 转换为numpy数组
    variances_unscaled = np.array(variances_unscaled)
    variances_scaled = np.array(variances_scaled)
    
    # 计算统计方差
    empirical_var_unscaled = np.mean(variances_unscaled)
    empirical_var_scaled = np.mean(variances_scaled)
    
    # 理论值
    theoretical_var_unscaled = d_k
    theoretical_var_scaled = 1.0
    
    print(f"理论方差 (未缩放): {theoretical_var_unscaled:.2f}")
    print(f"实际方差 (未缩放): {empirical_var_unscaled:.2f}")
    print(f"理论方差 (缩放后): {theoretical_var_scaled:.2f}")
    print(f"实际方差 (缩放后): {empirical_var_scaled:.2f}")
    
    # 绘制结果
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.hist(variances_unscaled, bins=50, alpha=0.7)
    plt.axvline(theoretical_var_unscaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"未缩放点积方差 (d_k={d_k})")
    plt.xlabel("方差")
    plt.ylabel("频次")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.hist(variances_scaled, bins=50, alpha=0.7, color='orange')
    plt.axvline(theoretical_var_scaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"缩放后点积方差 (除以√{d_k})")
    plt.xlabel("方差")
    plt.legend()
    
    plt.tight_layout()
    plt.show()

# 运行验证 (默认d_k=64)
verify_scaled_dot_product_variance()

# 测试不同d_k值的影响
for d_k in [16, 64, 256]:
    print(f"\n验证 d_k = {d_k}:")
    verify_scaled_dot_product_variance(d_k=d_k, num_samples=10000)

运行结果如下:

看到没,哥哥

相关推荐
Elwin Wong12 小时前
浅析OpenClaw:从“贾维斯”梦想看下一代 AI 操作系统的架构演进
人工智能·agent·clawdbot·moltbot·openclaw
Rorsion12 小时前
PyTorch实现线性回归
人工智能·pytorch·线性回归
AI资源库12 小时前
OpenClaw:159K Star的开源AI助手正在重新定义“个人AI“的边界
人工智能·语言模型
凯子坚持 c12 小时前
StreamingLLM:无需训练即可支持无限上下文的推理技术
人工智能
Tfly__12 小时前
在PX4 gazebo仿真中加入Mid360(最新)
linux·人工智能·自动驾驶·ros·无人机·px4·mid360
LLWZAI12 小时前
让朱雀AI检测无法判断的AI公众号文章,当创作者开始与算法「躲猫猫」
大数据·人工智能·深度学习
深圳市九鼎创展科技12 小时前
瑞芯微 RK3399 开发板 X3399 评测:高性能 ARM 平台的多面手
linux·arm开发·人工智能·单片机·嵌入式硬件·边缘计算
HELLO程序员12 小时前
Claude Code 2.1 发布:2026 年 AI 智能体开发的范式革命
人工智能
DFCED12 小时前
OpenClaw部署实战:5分钟搭建你的专属AI数字员工(附避坑指南)
人工智能·大模型·agent·openclaw
Java新手村12 小时前
基于 Vue 3 + Spring Boot 3 的 AI 面试辅助系统:实时语音识别 + 大模型智能回答
vue.js·人工智能·spring boot