Prompt-Tuning源码分析

Prompt-Tuning源码分析

源码

我们这里的代码解析以huggingface peft源码为主

从模型类结构可以看到,Prompt Tuning 只在输入层加入 prompt virtual tokens,其他地方均没有变化,具体可查看 PromptEmbedding 的源码。

伪代码示例

python 复制代码
soft_prompt=torch.nn.Parameter(#Make tensor trainable 
torch.rand(num_tokens,embed_dim))#Initialize soft prompt tensor 
def input_with_softprompt(x,soft_prompt):
	x=concatenate([soft_prompt,x] #Prepend soft prompt to input 
				  dim=seq_len)
	return x 
model(input_with_softprompt(x))

peft源码

bash 复制代码
class PromptEmbedding(torch.nn.Module):
    """

    ```py
    >>> from peft import PromptEmbedding, PromptTuningConfig

    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )

    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
    ```

    Input Shape: (`batch_size`, `total_virtual_tokens`)

    Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
    """

    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        # Just get embeddings
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings

输出的模型权重文件如下所示:

bash 复制代码
/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500]  adapter_config.json
├── [ 33K]  adapter_model.bin
└── [ 111]  README.md

0 directories, 3 files

其中,adapter_config.json 为 Prompt Tuning 配置文件;adapter_model.bin 为 Prompt Tuning 权重文件。

推理

bash 复制代码
from peft import PeftModel, PeftConfig

peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# 加载PEFT配置
config = PeftConfig.from_pretrained(peft_model_id)

# 加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# 加载PEFT模型
model = PeftModel.from_pretrained(model, peft_model_id)

# Tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")

# 模型推理
outputs = model.generate(
        input_ids=inputs["input_ids"], 
        attention_mask=inputs["attention_mask"], 
        max_new_tokens=10, 
        eos_token_id=3
    )

# Tokenizer 解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
相关推荐
胖头鱼的鱼缸(尹海文)25 分钟前
数据库管理-第376期 Oracle AI DB 23.26新特性一览(20251016)
数据库·人工智能·oracle
瑞禧生物ruixibio28 分钟前
4-ARM-PEG-Pyrene(2)/Biotin(2),多功能化聚乙二醇修饰荧光标记生物分子的设计与应用探索
arm开发·人工智能
大千AI助手32 分钟前
Huber损失函数:稳健回归的智慧之选
人工智能·数据挖掘·回归·损失函数·mse·mae·huber损失函数
墨利昂44 分钟前
10.17RNN情感分析实验:加载预训练词向量模块整理
人工智能·rnn·深度学习
【建模先锋】1 小时前
一区直接写!CEEMDAN分解 + Informer-LSTM +XGBoost组合预测模型
人工智能·lstm·ceemdan·预测模型·风速预测·时间序列预测模型
fsnine1 小时前
YOLOv2原理介绍
人工智能·计算机视觉·目标跟踪
倔强的石头1061 小时前
AI修图革命:IOPaint+cpolar让废片拯救触手可及
人工智能·cpolar·iopaint
文火冰糖的硅基工坊1 小时前
[人工智能-大模型-15]:大模型典型产品对比 - 数字人
人工智能·大模型·大语言模型
JJJJ_iii1 小时前
【机器学习05】神经网络、模型表示、前向传播、TensorFlow实现
人工智能·pytorch·python·深度学习·神经网络·机器学习·tensorflow
William.csj2 小时前
服务器/Pytorch——对于只调用一次的函数初始化,放在for训练外面和里面的差异
人工智能·pytorch·python