RUST笔记:candle使用基础

candle介绍

  • candle是huggingface开源的Rust的极简 ML 框架。

candle-矩阵乘法示例

cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
cargo build # 测试,或执行 cargo ckeck
  • main.rs

    use candle_core::{Device, Tensor};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

      let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
      let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
    
      let c = a.matmul(&b)?;
      println!("{c}");
      Ok(())
    

    }

  • 项目输出

    ~/myrust$ cargo new myapp
    Created binary (application) myapp package
    ~/myrust$ cd myapp
    ~/myrust/myapp$ cargo add --git https://github.com/huggingface/candle.git candle-core
    Updating git repository https://github.com/huggingface/candle.git
    Updating git submodule https://github.com/NVIDIA/cutlass.git
    Adding candle-core (git) to dependencies.
    Features:
    - accelerate
    - cuda
    - cudarc
    - cudnn
    - metal
    - mkl
    Updating git repository https://github.com/huggingface/candle.git
    Updating crates.io index
    ~/myrust/myapp$ cargo build
    Downloaded serde_derive v1.0.195
    Downloaded either v1.9.0
    Downloaded autocfg v1.1.0
    Downloaded zerofrom v0.1.3
    Downloaded zerofrom-derive v0.1.3
    Downloaded synstructure v0.13.0
    Downloaded crossbeam-deque v0.8.5
    Downloaded yoke-derive v0.7.3
    Downloaded half v2.3.1
    Downloaded bytemuck v1.14.1
    Downloaded rand_core v0.6.4
    Downloaded paste v1.0.14
    Downloaded proc-macro2 v1.0.78
    Downloaded itoa v1.0.10
    Downloaded memmap2 v0.9.4
    Downloaded syn v2.0.48
    Downloaded crossbeam-epoch v0.9.18
    Downloaded cfg-if v1.0.0
    Downloaded bitflags v1.3.2
    Downloaded num_cpus v1.16.0
    Downloaded gemm-f32 v0.17.0
    Downloaded reborrow v0.5.5
    Downloaded stable_deref_trait v1.2.0
    Downloaded rayon-core v1.12.1
    Downloaded seq-macro v0.3.5
    Downloaded thiserror-impl v1.0.56
    Downloaded dyn-stack v0.10.0
    Downloaded thiserror v1.0.56
    Downloaded unicode-xid v0.2.4
    Downloaded rand_chacha v0.3.1
    Downloaded ppv-lite86 v0.2.17
    Downloaded bytemuck_derive v1.5.0
    Downloaded getrandom v0.2.12
    Downloaded once_cell v1.19.0
    Downloaded unicode-ident v1.0.12
    Downloaded byteorder v1.5.0
    Downloaded crc32fast v1.3.2
    Downloaded num-complex v0.4.4
    Downloaded gemm-common v0.17.0
    Downloaded crossbeam-utils v0.8.19
    Downloaded quote v1.0.35
    Downloaded ryu v1.0.16
    Downloaded num-traits v0.2.17
    Downloaded zip v0.6.6
    Downloaded rand_distr v0.4.3
    Downloaded serde v1.0.195
    Downloaded rand v0.8.5
    Downloaded raw-cpuid v10.7.0
    Downloaded libm v0.2.8
    Downloaded serde_json v1.0.111
    Downloaded rayon v1.8.1
    Downloaded libc v0.2.152
    Downloaded gemm-c64 v0.17.0
    Downloaded gemm-c32 v0.17.0
    Downloaded safetensors v0.4.2
    Downloaded gemm-f64 v0.17.0
    Downloaded gemm v0.17.0
    Downloaded gemm-f16 v0.17.0
    Downloaded yoke v0.7.3
    Downloaded pulp v0.18.6
    Downloaded 60 crates (3.1 MB) in 14.91s
    Compiling proc-macro2 v1.0.78
    Compiling unicode-ident v1.0.12
    Compiling libc v0.2.152
    Compiling cfg-if v1.0.0
    Compiling libm v0.2.8
    Compiling autocfg v1.1.0
    Compiling crossbeam-utils v0.8.19
    Compiling ppv-lite86 v0.2.17
    Compiling rayon-core v1.12.1
    Compiling reborrow v0.5.5
    Compiling paste v1.0.14
    Compiling either v1.9.0
    Compiling bitflags v1.3.2
    Compiling seq-macro v0.3.5
    Compiling once_cell v1.19.0
    Compiling unicode-xid v0.2.4
    Compiling raw-cpuid v10.7.0
    Compiling serde v1.0.195
    Compiling crc32fast v1.3.2
    Compiling serde_json v1.0.111
    Compiling stable_deref_trait v1.2.0
    Compiling itoa v1.0.10
    Compiling ryu v1.0.16
    Compiling thiserror v1.0.56
    Compiling byteorder v1.5.0
    Compiling num-traits v0.2.17
    Compiling zip v0.6.6
    Compiling crossbeam-epoch v0.9.18
    Compiling quote v1.0.35
    Compiling syn v2.0.48
    Compiling crossbeam-deque v0.8.5
    Compiling getrandom v0.2.12
    Compiling memmap2 v0.9.4
    Compiling num_cpus v1.16.0
    Compiling rand_core v0.6.4
    Compiling rand_chacha v0.3.1
    Compiling rayon v1.8.1
    Compiling rand v0.8.5
    Compiling rand_distr v0.4.3
    Compiling synstructure v0.13.0
    Compiling bytemuck_derive v1.5.0
    Compiling serde_derive v1.0.195
    Compiling zerofrom-derive v0.1.3
    Compiling thiserror-impl v1.0.56
    Compiling yoke-derive v0.7.3
    Compiling bytemuck v1.14.1
    Compiling num-complex v0.4.4
    Compiling dyn-stack v0.10.0
    Compiling half v2.3.1
    Compiling zerofrom v0.1.3
    Compiling yoke v0.7.3
    Compiling pulp v0.18.6
    Compiling gemm-common v0.17.0
    Compiling gemm-f32 v0.17.0
    Compiling gemm-c64 v0.17.0
    Compiling gemm-f64 v0.17.0
    Compiling gemm-c32 v0.17.0
    Compiling gemm-f16 v0.17.0
    Compiling gemm v0.17.0
    Compiling safetensors v0.4.2
    Compiling candle-core v0.3.3 (https://github.com/huggingface/candle.git#fd7c8565)
    Compiling myapp v0.1.0 (/home/pdd/myrust/myapp)
    Finished dev [unoptimized + debuginfo] target(s) in 32.90s

candle_test的简单测试项目

Cargo.toml 文件

csharp 复制代码
[package]
name = "candle_test"
version = "0.1.0"
edition = "2021" #  Rust 版本

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.2.1", features = ["cuda"] }
# `candle-core`:项目依赖的包的名称。`git` 字段指定了包的源代码仓库地址。`version` 字段指定了使用的包的版本。`features` 字段是一个数组,指定了启用的功能。在这里,启用了 "cuda" 功能。
# 可以通过以下命令添加,取消可注释掉"cuda",再cargo build
# cargo add --git https://github.com/huggingface/candle.git candle-core
# cargo add candle-core --features cuda

main.rs

rust 复制代码
use candle_core::{DType, Device, Result, Tensor};

// 定义一个模型结构体
struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    // 定义模型的前向传播方法
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?; // 输入乘以第一层权重
        let x = x.relu()?; // 使用 ReLU 激活函数
        x.matmul(&self.second) // 结果乘以第二层权重
    }
}

fn main() -> Result<()> {
    // 初始化设备,如果 GPU 可用则使用 GPU,否则使用 CPU
    let device = match Device::new_cuda(0) {
        Ok(device) => device,
        Err(_) => Device::Cpu,
    };

    // 创建模型的第一层和第二层权重张量
    let first = Tensor::zeros((784, 100), DType::F32, &device)
        .unwrap()
        .contiguous()?;
    let second = Tensor::zeros((100, 10), DType::F32, &device)
        .unwrap()
        .contiguous()?;
    
    // 初始化模型
    let model = Model { first, second };

    // 创建一个用于测试的虚拟图像张量
    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)
        .unwrap()
        .contiguous()?;

    // 调用模型的前向传播方法获取预测结果
    let digit = model.forward(&dummy_image)?;

    // 打印预测结果
    println!("Digit {digit:?} digit");

    Ok(())
}

知识点总结

candle_core:: Result

// Result定义在/home/pdd/.cargo/git/checkouts/candle-0c2b4fa9e5801351/e8e3375/candle-core/src/error.rs
pub type Result<T> = std::result::Result<T, Error>; // 定义了一个 `Result` 类型,这是一个 `Result<T, Error>` 类型的别名。其中 `T` 是成功时的返回类型,而 `Error` 是失败时的错误类型。
rust 复制代码
// Ok(()) 定义在 /home/pdd/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/result.rs
// 这是 Rust 标准库中的 `Result` 公共的枚举类型,它有两个泛型参数 `T` 和 `E`。`T` 代表成功时返回的值的类型,`E` 代表错误时返回的错误类型。
// #[]是属性(attribute),提供额外信息
pub enum Result<T, E> {
    /// Contains the success value
    #[lang = "Ok"]
    #[stable(feature = "rust1", since = "1.0.0")]
    Ok(#[stable(feature = "rust1", since = "1.0.0")] T),// `Ok(T)`: 这是 `Result` 枚举的一个变体,用于表示成功的情况
                                                        // (): 是 Rust 中的单元类型(unit type),类似于其他语言中的 void。

    /// Contains the error value
    #[lang = "Err"]
    #[stable(feature = "rust1", since = "1.0.0")]
    Err(#[stable(feature = "rust1", since = "1.0.0")] E),// `Err(E)`: 这是 `Result` 枚举的另一个变体,用于表示错误的情况。
}

?符号

  • 在 Rust 中,? 符号用于处理 ResultOption 类型的返回值。这个符号的作用是将可能的错误或 None 值快速传播到调用链的最上层,使得代码更加简洁和易读。
rust 复制代码
fn forward(&self, image: &Tensor) -> Result<Tensor> {
    let x = image.matmul(&self.first)?; // 如果matmul返回Err,则整个forward函数返回Err
    let x = x.relu()?; // 如果relu返回Err,则整个forward函数返回Err
    x.matmul(&self.second) // 如果matmul返回Err,则整个forward函数返回Err;否则返回Ok(Tensor)
}

语句和表达式:语句以分号结尾,而表达式通常不需要分号。

  • 函数体:函数体是一个块表达式,其值是最后一个表达式的值。

    rust 复制代码
    fn add(x: i32, y: i32) -> i32 {
        x + y // 表达式
    }

CG

相关推荐
cuisidong199715 分钟前
5G学习笔记三之物理层、数据链路层、RRC层协议
笔记·学习·5g
乌恩大侠17 分钟前
5G周边知识笔记
笔记·5g
筱源源41 分钟前
Elasticsearch-linux环境部署
linux·elasticsearch
‍。。。1 小时前
使用Rust实现http/https正向代理
http·https·rust
Source.Liu1 小时前
【用Rust写CAD】第二章 第四节 函数
开发语言·rust
monkey_meng1 小时前
【Rust中的迭代器】
开发语言·后端·rust
余衫马1 小时前
Rust-Trait 特征编程
开发语言·后端·rust
monkey_meng1 小时前
【Rust中多线程同步机制】
开发语言·redis·后端·rust
咔叽布吉2 小时前
【论文阅读笔记】CamoFormer: Masked Separable Attention for Camouflaged Object Detection
论文阅读·笔记·目标检测
johnny2332 小时前
《大模型应用开发极简入门》笔记
笔记·chatgpt