LLaMA-Adapter源码解析

LLaMA-Adapter源码解析

伪代码

python 复制代码
def transformer_block_with_llama_adapter(x, gating_factor, soft_prompt):
	residual =x
	y= zero_init_attention(soft_prompt, x) # llama-adapter: prepend prefix
	x= self_attention(x)
	x = x+ gating_factor * y  # llama-adapter: apply zero_init_attention
	x = LayerNorm(x+residual)
	residual = x
	x = FullyConnectedLayers(x)
	x = AdapterLayers(x)
	x = LayerNorm(x + residual)
	return x

源码

python 复制代码
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.gate = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        if adapter is not None:
           adapter_len = adapter.shape[1]
           adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
           adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
           adapter_k = adapter_k.transpose(1, 2)
           adapter_v = adapter_v.transpose(1, 2)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        if adapter is not None:
            adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
            adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
            output = output + torch.matmul(adapter_scores, adapter_v)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)
相关推荐
晨尘光2 天前
在Windows下编译出llama_cpp_python的DLL后,在虚拟环境中使用方法
python·llama
风筝超冷5 天前
LLaMA-Factory - 批量推理(inference)的脚本
llama
bluebonnet276 天前
【agent开发】部署LLM(一)
python·llama
阿牛大牛中7 天前
LLaDa——基于 Diffusion 的大语言模型 打平 LLama 3
人工智能·语言模型·llama
Lilith的AI学习日记7 天前
【AI面试秘籍】| 第25期:RAG的关键痛点及解决方案深度解析
人工智能·深度学习·机器学习·chatgpt·aigc·llama
LChuck10 天前
【大模型微调】魔搭社区GPU进行LLaMA-Factory微调大模型自我认知
人工智能·语言模型·自然语言处理·nlp·llama·魔搭社区·modelscope
燕双嘤10 天前
Fine-tuning:微调技术,训练方式,LLaMA-Factory,ms-swift
llama
装不满的克莱因瓶12 天前
【小白AI教程】大模型知识扫盲通识
人工智能·数学建模·ai·大模型·llm·llama·rag
TGITCIC14 天前
英伟达破局1000 Token/秒!Llama 4以光速重塑AI推理边界
人工智能·大模型·llama·英伟达·大模型速度·ai赛道·大模型基座
天天爱吃肉821815 天前
【 大模型技术驱动智能网联汽车革命:关键技术解析与未来趋势】
语言模型·汽车·llama