腾讯送命题:手写多头注意力机制。。。

最近这一两周不少公司已开启春招和实习招聘。

不同以往的是,当前职场环境已不再是那个双向奔赴时代了。求职者在变多,HC 在变少,岗位要求还更高了。

最近,我们又陆续整理了很多大厂的面试题,帮助一些球友解惑答疑,分享技术面试中的那些弯弯绕绕。

总结如下:


今天就来聊聊那些年我们一起踩过的大模型面试坑。

有个球友遇到了这个面试题:原题:手写实现多头注意力机制(MHA),并加入键值缓存(KV cache)

看到这题的时候,他内心是崩溃的:您这是要考代码能力还是要考背书能力?

不过冷静下来想想,多头注意力 其实就是把单头注意力做了个"克隆"操作,然后把结果拼起来。

核心思想分解

想象你在开会,需要同时关注多个方面的信息:

  • 头1专门关注技术细节

  • 头2专门关注商业逻辑

  • 头3专门关注时间节点

  • 头4专门关注资源配置

每个"头"都有自己的Q、K、V矩阵,就像每个人都有自己的关注点和思维方式。

手写实现(简化版)

复制代码
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # 为什么要除以num_heads?因为最后要concat
        assert d_model % num_heads == 0
        
        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model) 
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # KV Cache - 这是重点!
        self.cache = {}
        
    def forward(self, query, key, value, mask=None, use_cache=False, cache_key="default"):
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        # 1. 线性变换
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # 2. 重塑为多头形状
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # 现在形状是 (batch, num_heads, seq_len, d_k)
        
        # 3. KV Cache逻辑 - 面试加分项!
        if use_cache and cache_key in self.cache:
            # 从缓存中获取之前的K,V
            cached_K, cached_V = self.cache[cache_key]
            # 拼接新的K,V
            K = torch.cat([cached_K, K], dim=2)
            V = torch.cat([cached_V, V], dim=2)
        
        if use_cache:
            # 更新缓存
            self.cache[cache_key] = (K, V)
        
        # 4. 计算注意力
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask isnotNone:
            attention_scores.masked_fill_(mask == 0, -1e9)
            
        attention_weights = torch.softmax(attention_scores, dim=-1)
        
        # 5. 应用注意力权重
        attended_values = torch.matmul(attention_weights, V)
        
        # 6. 重新整合多头结果
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        # 7. 最终线性变换
        output = self.W_o(attended_values)
        
        return output, attention_weights

# 吴师兄提醒:面试时一定要解释每一步在做什么!

KV Cache的精髓

很多同学写到这里就卡住了,KV Cache到底是个什么鬼?

简单理解:在生成式任务中,每次只生成一个新token,但之前所有token的K和V都要重新计算,这太浪费了!

KV Cache就是把之前计算过的K、V存起来,新来的token只需要计算自己的K、V,然后和历史的拼接就行。

这样时间复杂度从 O(n²) 降到了 O(n),这就是为什么现在大模型推理这么快的原因之一。

二、字节跳动经典题:Transformer中的d_k有什么玩意

原题:Transformer中的Attention为什么要除以sqrt(d_k)?

这个问题看似简单,实际上考的是你对数学原理的理解深度。

不除以sqrt(d_k)会怎样?

我们先用数学直觉来理解:

假设Q和K的维度是d_k,那么它们的点积结果的方差会随着d_k线性增长。

具体来说,如果Q和K的每个元素都是独立的标准正态分布N(0,1),那么:

  • d_k = 64 时,Q·K的方差约为 64

  • d_k = 512 时,Q·K的方差约为 512

  • d_k = 2048 时,Q·K的方差约为 2048

问题在哪里?

方差太大,Softmax就"崩"了!

复制代码
import torch
import numpy as np

# 演示不同d_k下的softmax行为
d_k_values = [64, 256, 512, 1024]

for d_k in d_k_values:
    # 模拟点积结果
    scores = torch.randn(10, 10) * math.sqrt(d_k)  # 模拟未缩放的情况
    
    print(f"d_k={d_k}:")
    print(f"  分数范围: [{scores.min():.2f}, {scores.max():.2f}]")
    
    # 计算softmax
    softmax_result = torch.softmax(scores, dim=-1)
    print(f"  最大注意力权重: {softmax_result.max():.4f}")
    print(f"  最小注意力权重: {softmax_result.min():.4f}")
    print()

运行结果大概是这样的:

复制代码
d_k=64:
  分数范围: [-15.23, 18.45]
  最大注意力权重: 0.0234
  最小注意力权重: 0.0001

d_k=1024:
  分数范围: [-67.89, 71.23]  
  最大注意力权重: 0.9999
  最小注意力权重: 0.0000

看到没?d_k越大,softmax的输出越"极端" ,几乎所有权重都集中到一个位置上,梯度就消失了

除以sqrt(d_k)的数学原理

通过除以sqrt(d_k),我们把点积结果的方差重新缩放到1,这样:

  1. 保持softmax输出的多样性

  2. 避免梯度消失

  3. 让模型训练更稳定

这就是所谓的"缩放点积注意力"(Scaled Dot-Product Attention)。

三、阿里送分题:投机解码是怎么工作的?

原题:投机解码(Speculative Decoding)是如何工作的?请详细说明其原理和优势。

说实话,第一次听到"投机解码"这个词,我以为是什么高深的算法。研究了一下发现,这玩意儿的核心思想特别朴素:

"让小模型先猜,大模型再验证"

核心思想

传统的大模型生成是这样的:

复制代码
输入 -> 大模型 -> token1 -> 大模型 -> token2 -> 大模型 -> token3 -> ...

每次都要走一遍大模型,慢得要死。

投机解码的思路:

复制代码
输入 -> 小模型快速生成N个token -> 大模型一次性验证这N个token -> 接受/拒绝

具体工作流程

  1. Draft阶段:小模型(比如7B)快速生成k个候选token

  2. Verify阶段 :大模型(比如70B)对这k个token进行并行验证

  3. Accept/Reject:根据概率分布决定接受多少个token

为什么能加速?

关键在于"并行验证"!

大模型验证k个token的时间 ≈ 生成1个token的时间(因为都是一次forward pass)

如果k个token中有3个被接受,那么我们用生成1个token的时间 ,得到了3个token的结果,加速比达到3x!

简化代码示例

复制代码
def speculative_decoding(draft_model, target_model, input_ids, k=4):
    """
    投机解码的简化实现
    """
    accepted_tokens = []
    current_input = input_ids
    
    while len(accepted_tokens) < max_length:
        # 1. Draft阶段:小模型快速生成k个token
        draft_tokens = []
        draft_input = current_input
        
        for _ in range(k):
            with torch.no_grad():
                draft_logits = draft_model(draft_input)
                next_token = sample_token(draft_logits)
                draft_tokens.append(next_token)
                draft_input = torch.cat([draft_input, next_token.unsqueeze(0)], dim=-1)
        
        # 2. Verify阶段:大模型并行验证
        verify_input = torch.cat([current_input] + draft_tokens, dim=-1)
        with torch.no_grad():
            target_logits = target_model(verify_input)
        
        # 3. Accept/Reject决策
        accepted_count = 0
        for i in range(k):
            # 比较大小模型的概率分布
            draft_prob = get_prob(draft_model_logits[i], draft_tokens[i])
            target_prob = get_prob(target_logits[i], draft_tokens[i])
            
            # 如果大模型概率 >= 小模型概率,接受
            if target_prob >= draft_prob:
                accepted_tokens.append(draft_tokens[i])
                accepted_count += 1
            else:
                # 概率采样决定是否接受
                accept_prob = target_prob / draft_prob
                if random.random() < accept_prob:
                    accepted_tokens.append(draft_tokens[i])
                    accepted_count += 1
                break# 一旦拒绝,后续都不要了
        
        current_input = torch.cat([current_input] + accepted_tokens[-accepted_count:], dim=-1)
    
    return accepted_tokens

实际效果

在实践中,投机解码通常能带来1.5x - 3x的加速,具体取决于:

  • 小模型和大模型的能力差距

  • 任务的难度(越简单的任务,小模型猜得越准)

  • k值的选择(太大了容易被拒绝,太小了加速不明显)

四、美团实战题:Loss变成NaN了怎么办?

原题:如果训练过程中出现Loss NaN,可能有哪些原因?如何排查?

这个问题太接地气了!相信每个训练过大模型的同学都遇到过这个问题。

常见原因分析

1. 梯度爆炸

现象 :Loss突然从正常值跳到NaN
原因:梯度太大,参数更新过头了

复制代码
# 检查梯度范数
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm()
        if grad_norm > 100:  # 阈值可调
            print(f"梯度爆炸警告: {name}, 梯度范数: {grad_norm}")

解决方案

  • 使用梯度裁剪:torch.nn.utils.clip_grad_norm_()

  • 降低学习率

  • 检查网络初始化

2. 学习率过大

现象 :训练开始没多久就NaN
原因:步子迈得太大,直接跳到了loss landscape的悬崖边

复制代码
# 学习率调试技巧
initial_lr = 1e-4  # 从小开始
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, 
                                            start_factor=0.1,  # 前10%时间用更小的lr
                                            total_iters=int(0.1 * total_steps))
3. 数值下溢/上溢

现象 :特定操作后出现NaN
原因:FP16精度不够,或者某些中间结果超出了数值范围

复制代码
# 混合精度训练的正确姿势
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    
    with torch.cuda.amp.autocast():
        outputs = model(batch)
        loss = criterion(outputs, targets)
    
    # 检查loss是否为NaN
    if torch.isnan(loss):
        print("检测到NaN loss,跳过这个batch")
        continue
        
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
4. 数据问题

现象 :某些batch后出现NaN
原因:训练数据中包含异常值

复制代码
# 数据清洗检查
def check_data_quality(dataloader):
    for batch_idx, batch in enumerate(dataloader):
        # 检查输入是否包含NaN/Inf
        if torch.isnan(batch['input_ids']).any():
            print(f"Batch {batch_idx} 包含NaN输入")
        
        if torch.isinf(batch['input_ids']).any():
            print(f"Batch {batch_idx} 包含Inf输入")
        
        # 检查标签
        if'labels'in batch:
            if torch.isnan(batch['labels']).any():
                print(f"Batch {batch_idx} 包含NaN标签")

完整的排查流程

复制代码
class NaNDetector:
    def __init__(self, model):
        self.model = model
        self.step_count = 0
        
    def check_and_log(self, loss, batch_idx):
        self.step_count += 1
        
        # 1. 检查loss
        if torch.isnan(loss):
            print(f"Step {self.step_count}: Loss is NaN!")
            self.diagnose()
            returnTrue
        
        # 2. 检查梯度
        if self.step_count % 100 == 0:
            self.check_gradients()
            
        returnFalse
    
    def diagnose(self):
        print("开始NaN诊断...")
        
        # 检查模型参数
        for name, param in self.model.named_parameters():
            if torch.isnan(param).any():
                print(f"参数 {name} 包含NaN")
            if torch.isinf(param).any():
                print(f"参数 {name} 包含Inf")
        
        # 检查梯度
        for name, param in self.model.named_parameters():
            if param.grad isnotNone:
                if torch.isnan(param.grad).any():
                    print(f"梯度 {name} 包含NaN")
                if torch.isinf(param.grad).any():
                    print(f"梯度 {name} 包含Inf")
    
    def check_gradients(self):
        total_norm = 0
        param_count = 0
        
        for name, param in self.model.named_parameters():
            if param.grad isnotNone:
                param_norm = param.grad.norm()
                total_norm += param_norm ** 2
                param_count += 1
                
                if param_norm > 10.0:  # 可调阈值
                    print(f"大梯度警告: {name}, 范数: {param_norm:.4f}")
        
        total_norm = total_norm ** (1. / 2)
        print(f"总梯度范数: {total_norm:.4f}")

# 使用方法
detector = NaNDetector(model)

for batch_idx, batch in enumerate(dataloader):
    loss = training_step(batch)
    
    if detector.check_and_log(loss, batch_idx):
        # 检测到NaN,可以选择停止训练或跳过
        break

总结与求职建议

整理了这么多题目,发现大厂面试的套路基本是:

  1. 基础概念要烂熟于心(Transformer、Attention机制)

  2. 数学原理要能自圆其说(为什么要除以sqrt(d_k))

  3. 实践经验要有案例支撑(Loss NaN怎么排查)

  4. 前沿技术要跟上节奏(投机解码、KV Cache)

最重要的是 ,面试时不要光背答案,要讲出原理和直觉 。面试官问为什么要除以sqrt(d_k),你不能只说"防止梯度消失",还要能解释为什么不除就会梯度消失

这样面试官才会觉得你是真的理解,而不是死记硬背。

相关推荐
讨厌吃蛋黄酥4 小时前
🔥 JavaScript异步之谜:单线程如何实现“同时”做多件事?99%的人都理解错了!
前端·javascript·面试
Rock_yzh4 小时前
AI学习日记——PyTorch深度学习快速入门:神经网络构建与训练实战
人工智能·pytorch·python·深度学习·神经网络·学习
前端小刘哥4 小时前
现场直播的技术革新者:视频直播点播平台EasyDSS在现场直播场景中的技术应用
算法
violet-lz4 小时前
数据结构八大排序:堆排序-从二叉树到堆排序实现
数据结构·算法
十八岁讨厌编程4 小时前
【算法训练营 · 补充】LeetCode Hot100(上)
算法·leetcode
razelan4 小时前
第一例:石头剪刀布的机器学习(xedu,示例15)
人工智能·机器学习
渣哥4 小时前
Spring Boot 本质揭秘:约定优于配置 + 自动装配
javascript·后端·面试
一条星星鱼4 小时前
从0到1:如何用统计学“看透”不同睡眠PSG数据集的差异(域偏差分析实战)
人工智能·深度学习·算法·概率论·归一化·睡眠psg
浮灯Foden4 小时前
算法-每日一题(DAY18)多数元素
开发语言·数据结构·c++·算法·leetcode·面试