Prompt-Tuning源码分析
源码
这里的代码解析以 Hugging Face PEFT 的源码为主。
从模型类结构可以看到,Prompt Tuning 只在输入层加入 prompt virtual tokens,其他地方均没有变化,具体可查看 PromptEmbedding 的源码。
伪代码示例
python
# Pseudocode: `num_tokens`, `embed_dim`, `concatenate`, `seq_len`, `model`
# and `x` are placeholders, not names defined in this snippet.
soft_prompt = torch.nn.Parameter(  # Make tensor trainable
    torch.rand(num_tokens, embed_dim))  # Initialize soft prompt tensor


def input_with_softprompt(x, soft_prompt):
    """Return ``x`` with ``soft_prompt`` prepended along the sequence axis."""
    x = concatenate([soft_prompt, x],  # Prepend soft prompt to input
                    dim=seq_len)
    return x


# The soft prompt must be passed explicitly: the helper takes two arguments.
model(input_with_softprompt(x, soft_prompt))
peft源码
python
class PromptEmbedding(torch.nn.Module):
    """
    Embedding table holding the prompt-tuning virtual tokens.

    The table has ``config.num_virtual_tokens * config.num_transformer_submodules``
    rows of width ``config.token_dim``. When ``config.prompt_tuning_init`` equals
    ``PromptTuningInit.TEXT`` the rows are initialised from the base model's word
    embeddings of ``config.prompt_tuning_init_text``; otherwise the default
    ``torch.nn.Embedding`` initialisation is kept.

    Args:
        config: Prompt-tuning configuration. The attributes read here are
            ``num_virtual_tokens``, ``num_transformer_submodules``, ``token_dim``,
            ``prompt_tuning_init`` and, for text initialisation,
            ``tokenizer_name_or_path`` and ``prompt_tuning_init_text``.
        word_embeddings: Callable mapping a ``LongTensor`` of token ids to their
            embeddings (the base model's word-embedding module). Only used for
            text initialisation.

    Raises:
        ValueError: If text initialisation is requested but the init text
            tokenises to zero tokens while virtual tokens are required.

    Example:

    ```py
    >>> from peft import PromptEmbedding, PromptTuningConfig
    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )
    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
    ```

    Input Shape: (`batch_size`, `total_virtual_tokens`)

    Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
    """

    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            # Imported lazily so transformers is only needed for text init.
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or repeat until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                if num_text_tokens == 0:
                    # Nothing to repeat; fail clearly instead of dividing by zero.
                    raise ValueError(
                        "prompt_tuning_init_text produced no tokens; cannot "
                        "initialise the virtual tokens from text."
                    )
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            init_token_ids = torch.LongTensor(init_token_ids)
            # The lookup requires the ids on the same device as the embedding
            # table, which may not be the CPU.
            table = getattr(word_embeddings, "weight", None)
            if table is not None:
                init_token_ids = init_token_ids.to(table.device)

            word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        """Look up and return the prompt embeddings for ``indices``."""
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings
输出的模型权重文件如下所示:
bash
/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500] adapter_config.json
├── [ 33K] adapter_model.bin
└── [ 111] README.md
0 directories, 3 files
其中,adapter_config.json 为 Prompt Tuning 配置文件;adapter_model.bin 为 Prompt Tuning 权重文件。
推理
python
from peft import PeftConfig, PeftModel

# The adapter directory is named "<base model>_<PEFT type>_<task type>".
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# Load the PEFT configuration saved with the adapter
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model, then wrap it with the PEFT adapter
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, peft_model_id)

# Encode the prompt with the tokenizer
prompt = f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : '
inputs = tokenizer(prompt, return_tensors="pt")

# Run generation
generation_kwargs = {
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
    "max_new_tokens": 10,
    "eos_token_id": 3,
}
outputs = model.generate(**generation_kwargs)

# Decode the generated ids back to text
decoded = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)
print(decoded)