rust-candle学习笔记11-实现一个简单的自注意力

参考:about-pytorch

定义ScaledDotProductAttention结构体:

rust 复制代码
use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};

struct ScaledDotProductAttention {
    wq: Linear,
    wk: Linear,
    wv: Linear,
    d_model: Tensor,
    device: Device,
}

为ScaledDotProductAttention结构体实现new方法:

rust 复制代码
impl ScaledDotProductAttention {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {
        Ok(Self { 
            wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?, 
            wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?, 
            wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,
            d_model: Tensor::new(embedding_dim as f32, &device)?,
            device,
        })
    }
}

为结构体实现Module的forward trait:

rust 复制代码
impl Module for ScaledDotProductAttention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let q = self.wq.forward(xs)?;
        let k = self.wk.forward(xs)?;
        let v = self.wv.forward(xs)?;
        let attn_score = q.matmul(&k.t()?)?;
        let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;
        let dim = attn_score.rank() - 1;
        let attn_weights = ops::softmax(&attn_score, dim)?;
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }
}

融合qkv实现:

定义ScaledDotProductAttentionFusedQKV结构体:

rust 复制代码
struct ScaledDotProductAttentionFusedQKV {
    w_qkv: Linear,
    d_model: Tensor,
    device: Device,
}

为结构体实现new方法:

rust 复制代码
impl ScaledDotProductAttentionFusedQKV {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {
        Ok(Self { 
            w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,
            d_model: Tensor::new(embedding_dim as f32, &device)?,
            device,
        })
    }
}

为结构体实现forward trait:

rust 复制代码
impl Module for ScaledDotProductAttentionFusedQKV {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let qkv = self.w_qkv.forward(xs)?;
        let (batch_size, seq_len, _) = qkv.dims3()?;
        let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;
        let q = qkv.get_on_dim(2, 0)?;
        let q = q.reshape((batch_size, seq_len, ()))?;
        let k = qkv.get_on_dim(2, 1)?;
        let k = k.reshape((batch_size, seq_len, ()))?;
        let v = qkv.get_on_dim(2, 2)?;
        let v = v.reshape((batch_size, seq_len, ()))?;
        let attn_score = q.matmul(&k.t()?)?;
        let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;
        let dim = attn_score.rank() - 1;
        let attn_weights = ops::softmax(&attn_score, dim)?;
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }
}

测试:

rust 复制代码
fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    
    let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 
                                                    0.55, 0.87, 0.66,
                                                    0.57, 0.85, 0.64,
                                                    0.22, 0.58, 0.33,
                                                    0.77, 0.25, 0.10,
                                                    0.05, 0.80, 0.55, 
                                                    0.43, 0.15, 0.89, 
                                                    0.55, 0.87, 0.66,
                                                    0.57, 0.85, 0.64,
                                                    0.22, 0.58, 0.33,
                                                    0.77, 0.25, 0.10,
                                                    0.05, 0.80, 0.55], (2, 6, 3), &device)?;
    // let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;
    let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;
    let output = model.forward(&input)?;
    println!("output: {:?}\n", output);
    println!("output: {:?}\n", output.to_vec3::<f32>()?);
    Ok(())
}
相关推荐
Kiri霧3 小时前
Linux下的Rust 与 C 的互操作性解析
c语言·开发语言·rust
聪明的笨猪猪3 小时前
Java Redis “持久化”面试清单(含超通俗生活案例与深度理解)
java·经验分享·笔记·面试
聪明的笨猪猪3 小时前
Java Redis “核心基础”面试清单(含超通俗生活案例与深度理解)
java·经验分享·笔记·面试
空白到白4 小时前
NLP-注意力机制
人工智能·自然语言处理
繁花与尘埃6 小时前
HTML5简介与基本骨架(本文为个人学习笔记,内容整理自哔哩哔哩UP主【非学者勿扰】的公开课程。 > 所有知识点归属原作者,仅作非商业用途分享)
笔记·学习·html5
东方芷兰6 小时前
LLM 笔记 —— 04 为什么语言模型用文字接龙,图片模型不用像素接龙呢?
人工智能·笔记·深度学习·语言模型·自然语言处理
Rock_yzh7 小时前
AI学习日记——卷积神经网络(CNN):完整实现与可视化分析
人工智能·python·深度学习·神经网络·学习·cnn
大鱼七成饱8 小时前
Rust 多线程编程入门:从 thread::spawn 步入 Rust 并发世界
后端·rust
Test.X8 小时前
学习16天:pytest学习
学习·pytest
XISHI_TIANLAN8 小时前
【多模态学习】Q&A6: 什么是MOE架构?Router Z Loss函数是指什么?负载均衡损失(Load Balancing Loss)又是什么?
学习·算法·语言模型