Prompt-Tuning源码分析

Prompt-Tuning源码分析

源码

我们这里的代码解析以huggingface peft源码为主

从模型类结构可以看到,Prompt Tuning 只在输入层加入 prompt virtual tokens,其他地方均没有变化,具体可查看 PromptEmbedding 的源码。

伪代码示例

python 复制代码
soft_prompt=torch.nn.Parameter(#Make tensor trainable 
torch.rand(num_tokens,embed_dim))#Initialize soft prompt tensor 
def input_with_softprompt(x,soft_prompt):
	x=concatenate([soft_prompt,x] #Prepend soft prompt to input 
				  dim=seq_len)
	return x 
model(input_with_softprompt(x))

peft源码

bash 复制代码
class PromptEmbedding(torch.nn.Module):
    """

    ```py
    >>> from peft import PromptEmbedding, PromptTuningConfig

    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )

    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
    ```

    Input Shape: (`batch_size`, `total_virtual_tokens`)

    Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
    """

    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        # Just get embeddings
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings

输出的模型权重文件如下所示:

bash 复制代码
/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500]  adapter_config.json
├── [ 33K]  adapter_model.bin
└── [ 111]  README.md

0 directories, 3 files

其中,adapter_config.json 为 Prompt Tuning 配置文件;adapter_model.bin 为 Prompt Tuning 权重文件。

推理

bash 复制代码
from peft import PeftModel, PeftConfig

peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# 加载PEFT配置
config = PeftConfig.from_pretrained(peft_model_id)

# 加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# 加载PEFT模型
model = PeftModel.from_pretrained(model, peft_model_id)

# Tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")

# 模型推理
outputs = model.generate(
        input_ids=inputs["input_ids"], 
        attention_mask=inputs["attention_mask"], 
        max_new_tokens=10, 
        eos_token_id=3
    )

# Tokenizer 解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
相关推荐
云烟成雨TD14 小时前
Agent Scope Java 2.x 系列【3】从零构建 ReActAgent
java·人工智能·agent
❀抽抽14 小时前
证件照制作API接入指南:700+规格一键生成
大数据·网络·人工智能
Promise微笑14 小时前
绝缘油介损(油介损)测试仪的深层机理、技术演进与精准诊断策略
大数据·网络·人工智能
开发者小布14 小时前
Claude Code 国内配置完整指南:通过中转 API 实现稳定访问(macOS / Linux / Windows)
人工智能
大C聊AI14 小时前
通用大模型纷纷收费,垂直场景AI工具的价值正在被重估
大数据·人工智能·机器学习·办公效率·ai 工具·智标领航·ai 辅助办公
苏州邦恩精密14 小时前
2026江苏GOM三维扫描仪定制厂家找哪家?企业数字化转型视角
人工智能·机器学习·3d·自动化·制造
python-码博士14 小时前
PyTorch 从零实现 Flow Matching:训练、采样、画图一条龙
人工智能·pytorch·python
砍光二叉树14 小时前
一文打通 AI 认知:LLM、Agent、MCP、Skill 完整体系
人工智能·llm·agent·skill·mcp
努力写A题的小菜鸡15 小时前
PyTorch 图像预处理 transforms 与 TensorBoard 可视化 (自己学习记录)
人工智能·pytorch·学习