LLaMA-Adapter源码解析

LLaMA-Adapter源码解析

伪代码

python 复制代码
def transformer_block_with_llama_adapter(x, gating_factor, soft_prompt):
    """Pseudocode for one transformer block extended with LLaMA-Adapter.

    The helpers used here (``zero_init_attention``, ``self_attention``,
    ``LayerNorm``, ``FullyConnectedLayers``, ``AdapterLayers``) are not
    defined in this snippet; they stand in for the corresponding sub-layers.

    Args:
        x: Input to the block.
        gating_factor: Multiplier applied to the soft-prompt attention
            output before it is added to the self-attention output.
        soft_prompt: Prompt passed to ``zero_init_attention`` together
            with ``x``.

    Returns:
        The block output after the attention and feed-forward sub-layers,
        each followed by a residual add and ``LayerNorm``.
    """
    block_input = x

    # Attention sub-layer: the soft-prompt branch is evaluated first, then
    # the regular self-attention, and the two are combined through the gate.
    prompt_attn = zero_init_attention(soft_prompt, x)
    self_attn = self_attention(x)
    gated_attn = self_attn + gating_factor * prompt_attn
    attn_out = LayerNorm(gated_attn + block_input)

    # Feed-forward sub-layer followed by the adapter layers; the residual
    # is taken from the normalized attention output.
    ff_out = AdapterLayers(FullyConnectedLayers(attn_out))
    return LayerNorm(ff_out + attn_out)

源码

python 复制代码
class Attention(nn.Module):
    """Multi-head attention with a key/value cache and an optional gated adapter branch.

    Query/key/value projections are ``ColumnParallelLinear`` layers and the
    output projection is a ``RowParallelLinear`` layer; the number of heads
    handled locally is ``args.n_heads`` divided by the model-parallel world
    size. When an ``adapter`` tensor is passed to ``forward``, it is projected
    with the same key/value weights and attended to separately; that branch is
    scaled by the learnable ``gate``, which is created as zeros.
    """

    def __init__(self, args: ModelArgs):
        """Build the projections, the key/value caches and the adapter gate.

        Args:
            args: Model configuration; this constructor reads ``n_heads``,
                ``dim``, ``max_batch_size`` and ``max_seq_len`` from it.
        """
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        def passthrough_init(weight):
            # Initializer that hands the weight back unchanged.
            return weight

        projected_dim = args.n_heads * self.head_dim
        column_options = dict(
            bias=False,
            gather_output=False,
            init_method=passthrough_init,
        )

        # Registration order (wq, wk, wv, wo) is kept as-is.
        self.wq = ColumnParallelLinear(args.dim, projected_dim, **column_options)
        self.wk = ColumnParallelLinear(args.dim, projected_dim, **column_options)
        self.wv = ColumnParallelLinear(args.dim, projected_dim, **column_options)
        self.wo = RowParallelLinear(
            projected_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=passthrough_init,
        )

        # Plain tensors (not parameters) holding keys/values per position.
        cache_shape = (
            args.max_batch_size,
            args.max_seq_len,
            self.n_local_heads,
            self.head_dim,
        )
        self.cache_k = torch.zeros(cache_shape).cuda()
        self.cache_v = torch.zeros(cache_shape).cuda()

        # Scale of the adapter branch; zero at construction time.
        self.gate = torch.nn.Parameter(torch.zeros(1))

    def _adapter_keys_values(self, adapter, bsz):
        """Project ``adapter`` to per-head keys and values repeated ``bsz`` times.

        Returns:
            ``(adapter_k, adapter_v)``, each reshaped to
            ``(1, adapter_len, n_local_heads, head_dim)``, repeated along the
            first axis and then transposed on axes 1 and 2.
        """
        adapter_len = adapter.shape[1]
        # assumes adapter has a leading axis of size 1 (the view below
        # requires it) — TODO confirm against the caller
        single_shape = (1, adapter_len, self.n_local_heads, self.head_dim)
        adapter_k = self.wk(adapter).view(single_shape).repeat(bsz, 1, 1, 1)
        adapter_v = self.wv(adapter).view(single_shape).repeat(bsz, 1, 1, 1)
        return adapter_k.transpose(1, 2), adapter_v.transpose(1, 2)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
        """Attend over cached and current positions, plus the adapter if given.

        Args:
            x: 3-D input; its first two axes are read as batch size and
                sequence length.
            start_pos: Index in the cache where the current keys/values are
                written.
            freqs_cis: Forwarded to ``apply_rotary_emb`` for queries and keys.
            mask: Added to the attention scores before softmax when not None.
            adapter: Optional tensor attended to through the gated branch.

        Returns:
            The result of the output projection ``wo``.
        """
        bsz, seqlen, _ = x.shape
        per_head_shape = (bsz, seqlen, self.n_local_heads, self.head_dim)

        xq = self.wq(x).view(per_head_shape)
        xk = self.wk(x).view(per_head_shape)
        xv = self.wv(x).view(per_head_shape)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Align the caches with the queries' dtype/device before writing.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        end_pos = start_pos + seqlen
        self.cache_k[:bsz, start_pos:end_pos] = xk
        self.cache_v[:bsz, start_pos:end_pos] = xv

        # Everything cached so far, including the positions just written.
        keys = self.cache_k[:bsz, :end_pos]
        values = self.cache_v[:bsz, :end_pos]

        has_adapter = adapter is not None
        if has_adapter:
            adapter_k, adapter_v = self._adapter_keys_values(adapter, bsz)

        # Move the head axis in front of the sequence axis.
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        root_head_dim = math.sqrt(self.head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / root_head_dim
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        # Softmax is evaluated in float and cast back to the query dtype.
        weights = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(weights, values)  # (bs, n_local_heads, slen, head_dim)

        if has_adapter:
            # Separate softmax over the adapter positions, scaled by the gate.
            adapter_logits = torch.matmul(xq, adapter_k.transpose(2, 3)) / root_head_dim
            adapter_weights = self.gate * F.softmax(adapter_logits.float(), dim=-1).type_as(xq)
            output = output + torch.matmul(adapter_weights, adapter_v)

        # Merge the heads back into a single feature axis.
        merged = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(merged)
相关推荐
JovaZou19 小时前
n8n 本地部署及实践应用,实现零成本自动化运营 Telegram 频道(保证好使)
运维·人工智能·docker·ai·自然语言处理·自动化·llama
openownworld1 天前
LLaMA-Factory双卡4090微调DeepSeek-R1-Distill-Qwen-14B医学领域
llama
Panesle1 天前
英伟达Llama-3.1-Nemotron-Ultra-253B-v1语言模型论文快读:FFN Fusion
人工智能·语言模型·llama·nvidia
福大大架构师每日一题1 天前
transformers v4.51.1正式发布!Llama 4多项关键修复,深度学习玩家速更!
人工智能·深度学习·llama
OpenBayes2 天前
OpenBayes 一周速览|1分钟生成完整音乐,DiffRhythm人声伴奏一键搞定; Stable Virtual Camera重塑3D视频创作
人工智能·深度学习·数据集·llama·视频生成·推理·蛋白质突变
x-cmd3 天前
[250411] Meta 发布 Llama 4 系列 AI 模型 | Rust 1.86 引入重大语言特性
人工智能·rust·llama
m0_540507783 天前
Meta LLaMA 4:对抗 GPT-4o 与 Claude 的开源王牌
llama
百年孤独百年3 天前
Ollama调用多GPU实现负载均衡
分布式·大模型·负载均衡·llama·ollama·deepseek
Jackilina_Stone4 天前
【微调大模型】轻松微调百余种大模型:LLaMA-Factory
大模型·微调·llama
是店小二呀4 天前
Llama 4革命性发布与绿色AI前沿研究
人工智能·llama