面试题:推导一下softmax中为啥要除以根号d

我整理好的1000+面试题,请看
大模型面试题总结-CSDN博客

或者

https://gitee.com/lilitom/ai_interview_questions/blob/master/README.md

最好将URL复制到浏览器中打开,不然可能无法直接打开


好了,我们今天针对上面的问题,

推导一下softmax中为啥要除以根号d

先看下具体的详细推导,文末有代码来可视化的证明。

  • 1. 定义与假设 这里我们需要事先做一个假设:

  • 和 是两个独立的随机矩阵,维度均为 。

  • 每个元素 和 是独立同分布(i.i.d.)的随机变量,满足:

    • 均值 ,

    • 方差 。

计算 的任意一个元素 =。

  • 2. 计算均值 由于 和 独立且均值为 0:
  • 3. 计算方差 方差定义为:

由于 ,第二项为 0,因此:

展开平方项:

因此:

由于 和 独立:

  • 对于交叉项 :

    因为 和 独立(),且 ,同理 。

  • 对于平方项:

综上:

  • 4. 缩放因子的作用 为了控制方差,使 的方差为 1(便于 Softmax 计算和梯度稳定),我们对其除以 :

此时,方差变为:

  • 5. 直观解释

  • 未缩放时: 的方差随 线性增长,导致 Softmax 输入值过大,输出接近 one-hot 分布,梯度消失。

  • 缩放后:方差稳定为 1,Softmax 输入分布更平滑,梯度传播更稳定。

让LLM写了一段代码

复制代码
import numpy as np
import matplotlib.pyplot as plt

def verify_scaled_dot_product_variance(d_k=64, num_samples=10000):
    """
    验证缩放点积注意力中QK^T的方差计算
    参数:
        d_k: 键/查询向量的维度
        num_samples: 采样次数
    """
    # 存储未缩放和缩放后的方差结果
    variances_unscaled = []
    variances_scaled = []
    
    for _ in range(num_samples):
        # 生成随机矩阵Q和K (元素服从标准正态分布 N(0,1))
        Q = np.random.randn(1, d_k)  # 1 x d_k
        K = np.random.randn(1, d_k)  # 1 x d_k
        
        # 计算点积 QK^T (这里Q和K都是单行向量,点积即内积)
        dot_product = np.dot(Q, K.T)[0, 0]  # 提取标量值
        
        # 缩放后的点积
        scaled_dot_product = dot_product / np.sqrt(d_k)
        
        variances_unscaled.append(dot_product ** 2)  # 因为E[QK^T]=0,方差=E[(QK^T)^2]
        variances_scaled.append(scaled_dot_product ** 2)
    
    # 转换为numpy数组
    variances_unscaled = np.array(variances_unscaled)
    variances_scaled = np.array(variances_scaled)
    
    # 计算统计方差
    empirical_var_unscaled = np.mean(variances_unscaled)
    empirical_var_scaled = np.mean(variances_scaled)
    
    # 理论值
    theoretical_var_unscaled = d_k
    theoretical_var_scaled = 1.0
    
    print(f"理论方差 (未缩放): {theoretical_var_unscaled:.2f}")
    print(f"实际方差 (未缩放): {empirical_var_unscaled:.2f}")
    print(f"理论方差 (缩放后): {theoretical_var_scaled:.2f}")
    print(f"实际方差 (缩放后): {empirical_var_scaled:.2f}")
    
    # 绘制结果
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.hist(variances_unscaled, bins=50, alpha=0.7)
    plt.axvline(theoretical_var_unscaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"未缩放点积方差 (d_k={d_k})")
    plt.xlabel("方差")
    plt.ylabel("频次")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.hist(variances_scaled, bins=50, alpha=0.7, color='orange')
    plt.axvline(theoretical_var_scaled, color='r', linestyle='--', label='理论方差')
    plt.title(f"缩放后点积方差 (除以√{d_k})")
    plt.xlabel("方差")
    plt.legend()
    
    plt.tight_layout()
    plt.show()

# 运行验证 (默认d_k=64)
verify_scaled_dot_product_variance()

# 测试不同d_k值的影响
for d_k in [16, 64, 256]:
    print(f"\n验证 d_k = {d_k}:")
    verify_scaled_dot_product_variance(d_k=d_k, num_samples=10000)

运行结果如下:

看到没,哥哥

相关推荐
数智工坊11 小时前
【SAM-DETR论文阅读】:基于语义对齐匹配的DETR极速收敛检测框架
网络·论文阅读·人工智能·深度学习·transformer
风落无尘11 小时前
LangChain 完全入门指南:从基础到实战(附面试题)
人工智能·langchain
IT_陈寒11 小时前
Vue的这个响应式陷阱,我debug了一整天才爬出来
前端·人工智能·后端
zz_lzh11 小时前
arm版AI牛马:armbian(rk3588)设备部署openclaw
arm开发·人工智能·arm
AI医影跨模态组学11 小时前
如何通过影像组学模型无创预测三阴性乳腺癌中的三级淋巴结构(TLSs),并借助病理组学揭示其与治疗响应、预后及细胞侵袭性表型的机制联系
人工智能·论文·医学·医学影像·影像组学·医学科研
Awesome Baron12 小时前
skill、tool calling、MCP区别
开发语言·人工智能·python
十铭忘12 小时前
FLUX.1架构的理解5——Transformer之前的预处理
人工智能
weixin_4462608512 小时前
城市智能化的底层基石:基于腾讯地图服务生态的移动定位与导航架构指引
大数据·人工智能·架构
Smilezyl12 小时前
一个独立开发者,靠一份 markdown 驱动 Claude Code, 用 20 天跑通 9 个包的 monorepo 工程
前端·人工智能·github