MTP(Multi-Token Prediction)

MTP(Multi-Token Prediction)

一、实现原理

  • 训练

    主结构不变,在最后特征向量进行线性转换分类时,多了几个头。

    主头还是和之前一样预测下一个token的置信度,求损失

    第一个副头为预测下下个token的置信度,求损失

    第二个副头为预测下下下个token的置信度,求损失

    以此类推。

  • 推理

    以往推理是计算最后一个token的置信度推理下一个token

    现在推理是各个头都计算最后一个token的置信度,主头推理的是下个token,副头以此类推,下下个,下下下个。。。

    虽然一次性推理了多个token,但是需要检验下副头推理的token有没有用,也就是把已有token和推理后的token拼接在一起,重新放入主模型进行预测一遍。取出各个副头推理的token对应的置信度,看看是否到达了阈值。最终返回主头预测的token和接着主头预测的连续token(前提是达到阈值,并且连续)

二、代码

python 复制代码
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import List, Union, Dict, Any
import json
import os
from torch.utils.tensorboard import SummaryWriter


class Config():
    def __init__(self,
                llm_model_path = '/home/user/Downloads/Qwen2.5-0.5B-Instruct',
                predict_tokens_num = 5,
                **kwargs):
        self.llm_model_path = llm_model_path
        self.predict_tokens_num = predict_tokens_num
        super().__init__(**kwargs)

class MTPModule(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(2 * hidden_size, 4 * hidden_size)
        self.linear2 = nn.Linear(4 * hidden_size, hidden_size)
      
    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
      

class MTP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.main_model = AutoModelForCausalLM.from_pretrained(self.config.llm_model_path).base_model
        # self.main_model.eval()
        # mtp模块
        self.mtp_modules = nn.ModuleList([MTPModule(self.main_model.config.hidden_size) for _ in range(self.config.predict_tokens_num-1)])
      
        # 每个头共享参数
        self.output_head = nn.Linear(self.main_model.config.hidden_size, self.main_model.config.vocab_size)
      
       
    def forward_main(self, input_ids, attention_mask=None, **kwargs):
      
        # with torch.no_grad():
        main_hidden_output = self.main_model(input_ids, attention_mask, **kwargs).last_hidden_state
     
        # [N,T,E] ---> [N,T,vocab_size]
        main_head_output = self.output_head(main_hidden_output)
      
        return main_hidden_output, main_head_output
  
    def forward_mtp(self, input_ids, previous_hidden_output, head_index):
        # [N,T] --> [N,T,E]
        input_embed = self.main_model.get_input_embeddings()(input_ids)
        # [N,T,E]和[N,T,E] -->[N,T,2E]
        mtp_input = torch.cat([previous_hidden_output, input_embed], dim=-1)
        # [N,T,2E] --->[N,T,E]
        mtp_hidden_output = self.mtp_modules[head_index](mtp_input)
        # [N,T,E] ---> [N,T,vocab_size]
        mtp_head_output = self.output_head(mtp_hidden_output)
      
        return mtp_hidden_output, mtp_head_output
  
  
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # 预测的逻辑
        outputs = {}
        main_hidden_output, main_head_output = self.forward_main(input_ids, attention_mask, **kwargs)
        previous_hidden_output = main_hidden_output
        outputs['head_main'] = main_head_output
        for head_index in range(0, self.config.predict_tokens_num-1):
            previous_hidden_output, mtp_head_output = self.forward_mtp(input_ids, previous_hidden_output, head_index)
            outputs[f'mtp_head_{head_index}'] = mtp_head_output
          
        return outputs
  
    def generate(self,input_ids,max_length, **kwargs):
        self.eval()
        seq = input_ids.clone() # 问题; eg:今天天气怎么样?
        b, s = seq.size()
      
        with torch.no_grad():
          
            while seq.size(1) < max_length:
                outputs = self.forward(seq) # 输入问题预测下一个词。
                print(seq.shape)
                speculative_tokens = []
              
                # main模型头生成的token
                logits = outputs['head_main']
                logits = logits[:, -1, :] # 拿到最后一个token的预测值。上述案例中今天天气怎么样?的?所预测的值是什么。此时预测的是下一个词
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.argmax(probs, dim=-1) # 计算最大概率的token
                speculative_tokens.append(next_token) # 添加到推理的列表中
              
                # 汇总mtp头生成的token
                for i in  range(self.config.predict_tokens_num-1):
                    logits = outputs[f'mtp_head_{i}']
                    logits = logits[:, -1, :] # 拿到最后一个token的预测值。上述案例中今天天气怎么样?的?所预测的值是什么。i=0此时预测的是下下个词
                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.argmax(probs, dim=-1)
                  
                    speculative_tokens.append(next_token)
              
                
              
                speculative_tokens = torch.cat(speculative_tokens, dim=-1) # shape: (len)
                speculative_tokens = speculative_tokens.unsqueeze(0) # shape: (1, len)
              
                # 将新生成的tokens和原始序列拼接
                all_tokens = torch.cat([seq, speculative_tokens], dim=-1)
              
                # 将新序列输入main模型(验证模型)进行验证,保留符合条件的token
                # 也就是再次预测一遍。计算在主模型上,副头生成的这几个token能不能用。 如果能用则用,用不了则不用。
                _, all_logits = self.forward_main(all_tokens) # [N,T,vocab_size]
              
                # 取出需要验证的token对应的logits
                validation_logits = all_logits[:, -speculative_tokens.shape[1]:]  # [N,len(speculative_tokens),vocab_size]
              
                # 获取各个token在main模型的输出概率
                accept_probs =  []

                for i in range(speculative_tokens.shape[1]):
                    logits = validation_logits[:, i] # (batch_size, vocab_size)   [N,vocab_size]
                    probs = torch.softmax(logits, dim=-1) # (batch_size, vocab_size)  [N,vocab_size]
                    token = speculative_tokens[:, i] # [N] #
                 
                    token_prob = probs.gather(1, token.unsqueeze(0)) # 从probs获取第1维度下对应token的概率
                    accept_probs.append(token_prob)
           
                # 拼接各个token的生成概率
                accept_probs = torch.cat(accept_probs, dim=-1)
              
                # 保留概率值大于阈值的token, 接受这部分token,否则舍弃(舍弃某个token时,后面的token都要舍弃)
                # 接受token的掩码
                accept_mask = (accept_probs > 1e-6)
                print(f'接受掩码:{accept_mask}')
              
                if accept_mask.any():  # [1, 1, 0, 1]  ~accept_mask: [0, 0, 1, 0]
                    print(f'拒绝掩码:{~accept_mask}')
                    # 获取被拒绝(舍弃)token对应的索引
                    reject_token_index = (~accept_mask).nonzero(as_tuple=True)[1]
                    print(f'拒绝token的索引:{reject_token_index}')
                    # 如果有需要舍弃的token
                    if reject_token_index.shape[0] > 0:
                      
                        # 找出第一个被舍弃的token的索引,其之前的token是需要保留的,之后的全部舍弃
                        # 接受token的数量即是第一个被舍弃的token的索引
                        accept_num = reject_token_index[0]
                  
                    else:
                        # 如果没有需要舍弃的token,则全部接受
                        accept_num = speculative_tokens.shape[1]
                      
                  
              
                else:
                    accept_num = 0    
              
              
                if accept_num > 0:
                  
                   # 取出通过验证的token
                    accept_tokens = speculative_tokens[:, :accept_num]
                 
                    seq = torch.cat([seq, accept_tokens], dim=1)
              
                else:
                    logits = outputs['head_main']
                  
                    logits = logits[:, -1, :]
                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.argmax(probs, dim=-1)
                    next_token = next_token.unsqueeze(0)
              
                  
                    seq = torch.cat([seq, next_token], dim=-1)
                    # print(seq)
                  
              
            return seq
          
          
      
      
                  
      
def train(config, model, dataloader, optimizer, writer, device, epochs, print_step, save_step, save_path):
    steps = 0
    model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()

            # [N,T]
            input_ids = batch['input_ids'].to(device)
            # [N,T], 前面是-100因为是问题,不需要计算损失
            labels = batch['labels'].to(device)
          
            # 主头 [N,T,E] 和 [N,T,E]
            main_hidden_output, main_head_output = model.forward_main(input_ids)
            previous_hidden_output = main_hidden_output
            for index in range(0, config.predict_tokens_num-1):
                # 副头
                previous_hidden_output, mtp_head_output = model.forward_mtp(input_ids, previous_hidden_output, index)

                # index=0时候-(1+index+1)=-2  为1时时候则为-3。 index越大,则取出的token越少,
                # index=0时候,不取最后两个token  index=1的时候 不取最后3个token

                # eg:今天天气怎么样?天气很好。|eos|
                #    input为:今天天气怎么样?天气很好        --->对应的就预测值,
                #    label为:天气怎么样?天气很好。|eos|     --->对应的就是真实值。
                # 第1个副头,表示当前token预测下下个词。 即:?预测的是气。
                # 第2个副头,表示当前token预测下下下个词
                # ...
                mtp_head_output = mtp_head_output[:, :-(1+index+1)] # [batch_size, seq_len, vocab_size]
                mtp_head_output = mtp_head_output.reshape(-1, model.main_model.config.vocab_size) # [batch_size * seq_len, vocab_size]

                # index=0时候 不取前两个token  index=1的时候 不取前三个token
                target = labels[:, 1+index+1:] # [batch_size, seq_len]
                target = target.contiguous().view(-1) # [batch_size * seq_len]
              
                mtp_loss = F.cross_entropy(mtp_head_output, target, ignore_index=-100)
                # 反向传播,计算梯度。不断循环每个,retain_graph=True会进行梯度累加。
                mtp_loss.backward(retain_graph=True)
            # 主loss, 标签值:不取第一个(因为input的第一个token的label,对应的是下个token)  预测值:不取倒数第一个token(因为最后一个token预测的值没有意义了。取到前一个就行)
            # eg:今天天气怎么样?天气很好。|eos|
            #    input为:今天天气怎么样?天气很好。        --->对应的就预测值,
            #    label为:天天气怎么样?天气很好。|eos|     --->对应的就是真实值。
            main_loss = F.cross_entropy(main_head_output[:, :-1].reshape(-1, model.main_model.config.vocab_size), labels[:, 1:].reshape(-1), ignore_index=-100)
            # 反向传播,计算梯度,会把之前累加上,因为之前retain_graph=True
            main_loss.backward()
            # 更新参数值
            optimizer.step()
          
            if (steps+1) % print_step==0:
                writer.add_scalar('main_loss', main_loss.item(), steps)
                writer.add_scalar('mtp_loss', mtp_loss.item(), steps)
                print(f"Epoch {epoch+1}], Step {steps+1}, main_loss: {main_loss.item():.4f}, mtp_loss: {mtp_loss.item():.4f}")
              
            if (steps+1) % save_step==0:
                torch.save(model.state_dict(), f"{save_path}/model_{steps}.pth")
          
            steps += 1  
      
  
class MyDataset(Dataset):
    def __init__(self, data_path, tokenizer):
        super().__init__()
        self.data_path = data_path
      
        self.tokenizer = tokenizer
      
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.datas = f.readlines()

          
    def __len__(self):
        return len(self.datas)
  
    def __getitem__(self, index):
        sample = self.datas[index].strip()
        sample = json.loads(sample)
        conversations = sample['conversations']
        user = conversations[0]['content']
        assistant = conversations[1]['content']
        # 把问题应用聊天模板
        q = self.tokenizer.apply_chat_template([{"role": "user", "content": user}], tokenize=False, add_generation_prompt=True)

        # 把回答加上终止符
        a = assistant + self.tokenizer.eos_token
        # 问题ids
        q_input_ids = self.tokenizer(q)['input_ids']
        # 答案ids
        a_input_ids = self.tokenizer(a)['input_ids']
        # 问题和答案进行拼接
        input_ids = q_input_ids + a_input_ids
        # labels_id 把问题填充-100,计算损失时候不会计算。 把答案作为目标值。
        labels = [-100] * len(q_input_ids) + a_input_ids
      
        return {
            "input_ids": input_ids, # input的长度和labels的长度是一致的,且没有错位
            "labels": labels,
        }
      
class MyDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
  
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # 找到样本中长度最大的input
        max_len = max(len(feature['input_ids']) for feature in features)
        input_ids = []
        labels = []
        # 遍历样本,
        for feature in features:
            #对input进行填充
            input_ids.append(feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['input_ids'])))
            # 对label也进行填充
            labels.append(feature['labels'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['labels'])))

        # 返回
        return {'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'labels': torch.tensor(labels, dtype=torch.long)}
      

          
      
      
if __name__ == '__main__':
    # 日志记录
    writer = SummaryWriter('./runs')
    config = Config()
    model = MTP(config)
    model.cuda()
    print(f'模型参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    tokenizer = AutoTokenizer.from_pretrained(config.llm_model_path)
    dataset = MyDataset('/home/user/wyf/deepseek_learn/MTP_train/lora_medical.jsonl', tokenizer)
    dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True, num_workers=2, collate_fn=MyDataCollator(tokenizer))
  
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    save_path = './mtp'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    train(config, model, dataloader, optimizer, writer, device='cuda', epochs=10, print_step=10, save_step=500, save_path='mtp')
  

相关推荐
hello_ejb32 分钟前
聊聊Spring AI Alibaba的ObsidianDocumentReader
java·人工智能·spring
桥Dopey14 分钟前
Python常用的第三方模块之【jieba库】支持三种分词模式:精确模式、全模式和搜索引擎模式(提高召回率)
人工智能·python·分词模式
W流沙W15 分钟前
bert学习
人工智能·bert
想学好英文的ikun44 分钟前
【MCP】第二篇:IDE革命——用MCP构建下一代智能工具链
ide·人工智能·python·ai·个人开发·mcp
码上飞扬1 小时前
深度剖析:GPT-3.5与GPT-4的主要区别及架构解析
人工智能
whuzhang161 小时前
3DGS之齐次坐标
人工智能·3d·自动驾驶
闭月之泪舞1 小时前
《深度神经网络之数据增强、模型保存、模型调用、学习率调整》
人工智能·学习·dnn
掘金詹姆斯2 小时前
LangChain4j快速入门(一)
人工智能·langchain
快手技术2 小时前
新加坡见!快手 11 篇论文入选人工智能领域顶会 ICLR 2025
人工智能
结冰架构2 小时前
量子金融工程:蒙特卡洛算法误差压缩至0.3%
人工智能·算法·ai·金融·量子计算