pytorch 是如何调用 cusolver API 的调用

0,环境

ubuntu 22.04

pytorch 2.3.1

x86

RTX 3080

cuda 12.2

1, 示例代码

以potrs为例;

hello_cholesk.py

复制代码
""" 
hello_cholesky.py
step1, Cholesky decompose;
step2, inverse A;
step3, Cholesky again;
python3 hello_cholesky.py --size 256  --cuda_device_id  0
"""
import torch
import time
import argparse


def cholesky_measure(A, cuda_dev=0):
    dev = torch.device(f"cuda:{cuda_dev}")
    A = A.to(dev)

    print(f'Which device to compute : {dev}')
  
    SY = 100* torch.mm(A, A.t()) +  200*torch.eye(N, device=dev)

    to_start = time.time() 
    SY = torch.linalg.cholesky(SY)
    SY = torch.cholesky_inverse(SY)
    SY = torch.linalg.cholesky(SY, upper=True)
    run_time = time.time() - to_start   
     
    print(f'The device: {dev}, run: {run_time:.3f} second')
    print(f'SY : {SY}')
    print(f'****'*20)

    return run_time

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='dim of A.')
    parser.add_argument('--N', type=int, default=512, required=True, help='dim of A')
    args = parser.parse_args()
    N = args.N

    print(f'A N : {N}')    
    A = torch.randn(N, N)
       
    cuda_dev = 0
    time_dev0 = cholesky_measure(A, cuda_dev)    
    time_dev1 = cholesky_measure(A, cuda_dev+1)    
    print(f'time_dev0 /time_dev1 = {time_dev0/time_dev1:.2f} ')

运行效果:

2,调用栈跟踪

跟踪如下调用关系:

复制代码
Tensor cholesky_inverse(const Tensor &input, bool upper)    aten/src/ATen/native/BatchLinearAlgebra.cpp
	static Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper)
	DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
	REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
	Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
	Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper)
	void _cholesky_inverse_cusolver_potrs_based(Tensor& result, Tensor& infos, bool upper)
	template<typename scalar_t>
	inline static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos)
	at::cuda::solver::potrs<scalar_t>(
      handle, uplo, n_32, nrhs_32,
      A_ptr + i * A_matrix_stride,
      lda_32,
      self_working_copy_ptr + i * self_matrix_stride,
      ldb_32,
      infos_ptr
    );

一些细节:

相关推荐
一 铭36 分钟前
AI领域新趋势:从提示(Prompt)工程到上下文(Context)工程
人工智能·语言模型·大模型·llm·prompt
麻雀无能为力4 小时前
CAU数据挖掘实验 表分析数据插件
人工智能·数据挖掘·中国农业大学
时序之心4 小时前
时空数据挖掘五大革新方向详解篇!
人工智能·数据挖掘·论文·时间序列
.30-06Springfield5 小时前
人工智能概念之七:集成学习思想(Bagging、Boosting、Stacking)
人工智能·算法·机器学习·集成学习
说私域6 小时前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的超级文化符号构建路径研究
人工智能·小程序·开源
永洪科技6 小时前
永洪科技荣获商业智能品牌影响力奖,全力打造”AI+决策”引擎
大数据·人工智能·科技·数据分析·数据可视化·bi
shangyingying_16 小时前
关于小波降噪、小波增强、小波去雾的原理区分
人工智能·深度学习·计算机视觉
书玮嘎7 小时前
【WIP】【VLA&VLM——InternVL系列】
人工智能·深度学习
猫头虎7 小时前
猫头虎 AI工具分享:一个网页抓取、结构化数据提取、网页爬取、浏览器自动化操作工具:Hyperbrowser MCP
运维·人工智能·gpt·开源·自动化·文心一言·ai编程
要努力啊啊啊7 小时前
YOLOv2 正负样本分配机制详解
人工智能·深度学习·yolo·计算机视觉·目标跟踪