rust-candle学习笔记12-实现因果注意力

参考:about-pytorch

定义结构体:

rust 复制代码
struct CausalAttention {
    w_qkv: Linear,
    dropout: Dropout, 
    d_model: Tensor,
    mask: Tensor,
    device: Device,   
}

定义new方法:

rust 复制代码
impl CausalAttention {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, seq_len: usize, dropout: f32, device: Device) -> Result<Self> {
        Ok(Self { 
            w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,
            d_model: Tensor::new(embedding_dim as f32, &device)?,
            mask: Tensor::tril2(seq_len, DType::U32, &device)?,
            dropout: Dropout::new(dropout),
            device
        })
    }
}

定义forward方法:

rust 复制代码
    fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> { 
        let qkv = self.w_qkv.forward(x)?;
        let (batch_size, seq_len, _) = qkv.dims3()?;
        let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;
        let q = qkv.get_on_dim(2, 0)?;
        let q = q.reshape((batch_size, seq_len, ()))?;
        let k = qkv.get_on_dim(2, 1)?;
        let k = k.reshape((batch_size, seq_len, ()))?;
        let v = qkv.get_on_dim(2, 2)?;
        let v = v.reshape((batch_size, seq_len, ()))?;
        let mut attn_score = q.matmul(&k.t()?)?;
        // println!("attn_score: {:?}\n", attn_score.to_vec3::<f32>()?);
        let dim = attn_score.rank() - 1;
        let mask_dim = attn_score.dims()[dim];
        let mask = self.mask.broadcast_as(attn_score.shape())?;
        // println!("mask: {:?}\n", mask);
        // println!("mask: {:?}\n", mask.to_vec3::<u32>()?);
        attn_score = masked_fill(&attn_score, &mask, f32::NEG_INFINITY)?;
        // println!("attn_score: {:?}\n", attn_score);
        // println!("attn_score: {:?}\n", attn_score.to_vec3::<f32>()?);
        let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?; 
        let attn_weights = ops::softmax(&attn_score, dim)?;
        // println!("attn_weights: {:?}\n", attn_weights);
        // println!("attn_weights: {:?}\n", attn_weights.to_vec3::<f32>()?); 
        let attn_weights = self.dropout.forward(&attn_weights, train)?;
        // println!("dropout attn_weights: {:?}\n", attn_weights);
        // println!("dropout attn_weights: {:?}\n", attn_weights.to_vec3::<f32>()?); 
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }

测试:

rust 复制代码
fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    
    let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 
                                                    0.55, 0.87, 0.66,
                                                    0.57, 0.85, 0.64,
                                                    0.22, 0.58, 0.33,
                                                    0.77, 0.25, 0.10,
                                                    0.05, 0.80, 0.55, 
                                                    0.43, 0.15, 0.89, 
                                                    0.55, 0.87, 0.66,
                                                    0.57, 0.85, 0.64,
                                                    0.22, 0.58, 0.33,
                                                    0.77, 0.25, 0.10,
                                                    0.05, 0.80, 0.55], (2, 6, 3), &device)?;
    let model = CausalAttention::new(vb.clone(), 3, 2, 6, 0.5, device.clone())?;
    let output = model.forward(&input, true)?;
    println!("output: {:?}\n", output);
    println!("output: {:?}\n", output.to_vec3::<f32>()?);
    Ok(())
}
相关推荐
KG_LLM图谱增强大模型9 小时前
多智能体大语言模型框架赋能医学等多领域低资源命名实体识别:知识检索、消歧与反思分析的创新实践
人工智能·语言模型·自然语言处理
LIZHUOLONG19 小时前
AI 系统学习路径
人工智能·学习
17(无规则自律)10 小时前
【CSAPP 读书笔记】第一章:计算机系统漫游
linux·c语言·arm开发·嵌入式硬件·学习·ubuntu
曾浩轩10 小时前
C语言学习记录——BC113 数字三角形
c语言·学习
●VON10 小时前
Flutter 与 OpenHarmony 应用功能深化:构建独立任务表单页面与完善编辑体验
学习·flutter·openharmony·von
老鱼说AI10 小时前
论文精读第八期:Quiet-STaR 深度剖析:如何利用并行 Attention 与 REINFORCE 唤醒大模型的“潜意识”?
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
四谎真好看10 小时前
JavaWeb学习笔记(Day08+Day09)之Mybatis入门+基础操作
笔记·学习·学习笔记·javaweb
xqqxqxxq10 小时前
《智能仿真无人机平台(多线程V2.0)技术笔记》(线程进阶: 无人机自动防空平台开发教程)
笔记·无人机·cocos2d
三伏52210 小时前
Cortex-M3权威指南Cn第七章——笔记
笔记·cortex-m3
丝斯201110 小时前
AI学习笔记整理(56)——大模型微调
人工智能·笔记·学习