面试题:推导一下softmax中为啥要除以根号d

我整理好的1000+面试题,请看
大模型面试题总结-CSDN博客

或者

https://gitee.com/lilitom/ai_interview_questions/blob/master/README.md

最好将URL复制到浏览器中打开,不然可能无法直接打开


好了,我们今天针对上面的问题,

推导一下softmax中为啥要除以根号d

先看下具体的详细推导,文末有代码来可视化的证明。

  • 1. 定义与假设 这里我们需要事先做一个假设:

  • 和 是两个独立的随机矩阵,维度均为 。

  • 每个元素 和 是独立同分布(i.i.d.)的随机变量,满足:

    • 均值 ,

    • 方差 。

计算 的任意一个元素 =。

  • 2. 计算均值 由于 和 独立且均值为 0:
  • 3. 计算方差 方差定义为:

由于 ,第二项为 0,因此:

展开平方项:

因此:

由于 和 独立:

  • 对于交叉项 :

    因为 和 独立(),且 ,同理 。

  • 对于平方项:

综上:

  • 4. 缩放因子的作用 为了控制方差,使 的方差为 1(便于 Softmax 计算和梯度稳定),我们对其除以 :

此时,方差变为:

  • 5. 直观解释

  • 未缩放时: 的方差随 线性增长,导致 Softmax 输入值过大,输出接近 one-hot 分布,梯度消失。

  • 缩放后:方差稳定为 1,Softmax 输入分布更平滑,梯度传播更稳定。

让LLM写了一段代码

复制代码
import numpy as np
import matplotlib.pyplot as plt

def verify_scaled_dot_product_variance(d_k=64, num_samples=10000):
    """
    验证缩放点积注意力中QK^T的方差计算
    参数:
        d_k: 键/查询向量的维度
        num_samples: 采样次数
    """
    # 存储未缩放和缩放后的方差结果
    variances_unscaled = []
    variances_scaled = []
    
    for _ in range(num_samples):
        # 生成随机矩阵Q和K (元素服从标准正态分布 N(0,1))
        Q = np.random.randn(1, d_k)  # 1 x d_k
        K = np.random.randn(1, d_k)  # 1 x d_k
        
        # 计算点积 QK^T (这里Q和K都是单行向量,点积即内积)
        dot_product = np.dot(Q, K.T)[0, 0]  # 提取标量值
        
        # 缩放后的点积
        scaled_dot_product = dot_product / np.sqrt(d_k)
        
        variances_unscaled.append(dot_product ** 2)  # 因为E[QK^T]=0,方差=E[(QK^T)^2]
        variances_scaled.append(scaled_dot_product ** 2)
    
    # 转换为numpy数组
    variances_unscaled = np.array(variances_unscaled)
    variances_scaled = np.array(variances_scaled)
    
    # 计算统计方差
    empirical_var_unscaled = np.mean(variances_unscaled)
    empirical_var_scaled = np.mean(variances_scaled)
    
    # 理论值
    theoretical_var_unscaled = d_k
    theoretical_var_scaled = 1.0
    
    print(f"理论方差 (未缩放): {theoretical_var_unscaled:.2f}")
    print(f"实际方差 (未缩放): {empirical_var_unscaled:.2f}")
    print(f"理论方差 (缩放后): {theoretical_var_scaled:.2f}")
    print(f"实际方差 (缩放后): {empirical_var_scaled:.2f}")
    
    # 绘制结果
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.hist(variances_unscaled, bins=50, alpha=0.7)
    plt.axvline(theoretical_var_unscaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"未缩放点积方差 (d_k={d_k})")
    plt.xlabel("方差")
    plt.ylabel("频次")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.hist(variances_scaled, bins=50, alpha=0.7, color='orange')
    plt.axvline(theoretical_var_scaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"缩放后点积方差 (除以√{d_k})")
    plt.xlabel("方差")
    plt.legend()
    
    plt.tight_layout()
    plt.show()

# 运行验证 (默认d_k=64)
verify_scaled_dot_product_variance()

# 测试不同d_k值的影响
for d_k in [16, 64, 256]:
    print(f"\n验证 d_k = {d_k}:")
    verify_scaled_dot_product_variance(d_k=d_k, num_samples=10000)

运行结果如下:

看到没,哥哥

相关推荐
蓝速科技几秒前
蓝速科技视觉 3D 全息舱 AI 数字人一体机带灯与无灯款深度评测
人工智能·科技·3d
武子康8 分钟前
调查研究-213 UBTech U1:当人形机器人从“听懂指令“走向“情绪陪伴“
人工智能·机器人·agent
xyz_CDragon9 分钟前
OpenAI发布首款自研芯片Jalapeño:9个月流片,AI设计芯片的时代来了
人工智能·单片机·深度学习·神经网络·芯片设计
gooxi_hui15 分钟前
海量存力,智驭未来丨国鑫4U60盘位高密度存储服务器SL401-G4重磅上市
运维·服务器·人工智能
太子釢17 分钟前
Claude Code 工具调用机制详解
人工智能
ksueh23 分钟前
AI写小说接入文心一言教程:千帆API+向量记忆系统实现百万字长篇智能创作
人工智能·ai助手
不焦躁的程序员27 分钟前
程序员该补获客能力了
人工智能·程序员
AI科技星28 分钟前
基于32维Cayley_Dickson超复数的全域拓扑统一场论——反重力、真空自持供能、维度瞬移与星际宇宙脑秩序体系
人工智能·学习·算法·机器学习·数据挖掘
星马梦缘32 分钟前
机器学习与模式识别 第十四章 神经网络中的反向传播 模拟卷及答案
人工智能·神经网络·机器学习·微分·反向传播
吴bug35 分钟前
认识 Open-ACE — AI 编程智能体的工作空间
人工智能·ai·ai编程