pytorch 是如何调用 cusolver API 的调用

0,环境

ubuntu 22.04

pytorch 2.3.1

x86

RTX 3080

cuda 12.2

1, 示例代码

以potrs为例;

hello_cholesk.py

复制代码
""" 
hello_cholesky.py
step1, Cholesky decompose;
step2, inverse A;
step3, Cholesky again;
python3 hello_cholesky.py --size 256  --cuda_device_id  0
"""
import torch
import time
import argparse


def cholesky_measure(A, cuda_dev=0):
    dev = torch.device(f"cuda:{cuda_dev}")
    A = A.to(dev)

    print(f'Which device to compute : {dev}')
  
    SY = 100* torch.mm(A, A.t()) +  200*torch.eye(N, device=dev)

    to_start = time.time() 
    SY = torch.linalg.cholesky(SY)
    SY = torch.cholesky_inverse(SY)
    SY = torch.linalg.cholesky(SY, upper=True)
    run_time = time.time() - to_start   
     
    print(f'The device: {dev}, run: {run_time:.3f} second')
    print(f'SY : {SY}')
    print(f'****'*20)

    return run_time

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='dim of A.')
    parser.add_argument('--N', type=int, default=512, required=True, help='dim of A')
    args = parser.parse_args()
    N = args.N

    print(f'A N : {N}')    
    A = torch.randn(N, N)
       
    cuda_dev = 0
    time_dev0 = cholesky_measure(A, cuda_dev)    
    time_dev1 = cholesky_measure(A, cuda_dev+1)    
    print(f'time_dev0 /time_dev1 = {time_dev0/time_dev1:.2f} ')

运行效果:

2,调用栈跟踪

跟踪如下调用关系:

复制代码
Tensor cholesky_inverse(const Tensor &input, bool upper)    aten/src/ATen/native/BatchLinearAlgebra.cpp
	static Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper)
	DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
	REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
	Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
	Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper)
	void _cholesky_inverse_cusolver_potrs_based(Tensor& result, Tensor& infos, bool upper)
	template<typename scalar_t>
	inline static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos)
	at::cuda::solver::potrs<scalar_t>(
      handle, uplo, n_32, nrhs_32,
      A_ptr + i * A_matrix_stride,
      lda_32,
      self_working_copy_ptr + i * self_matrix_stride,
      ldb_32,
      infos_ptr
    );

一些细节:

相关推荐
橙汁味的风1 分钟前
1基础搜索算法
人工智能·图搜索
Dfreedom.5 分钟前
第一阶段:U-net++的概况和核心价值
人工智能·深度学习·神经网络·计算机视觉·图像分割·u-net·u-net++
weixin_462446236 分钟前
使用 Docker Compose 部署 Next-AI-Draw-IO(精简 OpenAI 兼容配置)
人工智能·docker·容器
Dfreedom.6 分钟前
循阶而上,庖丁解牛:系统学习开源 AI 模型的行动指南
人工智能·深度学习·学习·开源·图像算法
亚马逊云开发者8 分钟前
使用 Kiro AI IDE 开发 基于Amazon EMR 的Flink 智能监控系统实践
人工智能
数据光子10 分钟前
【YOLO数据集】自动驾驶
人工智能·yolo·自动驾驶
Elastic 中国社区官方博客12 分钟前
使用 Elasticsearch 中的结构化输出创建可靠的 agents
大数据·人工智能·elk·elasticsearch·搜索引擎·ai·全文检索
北京耐用通信16 分钟前
告别AGV“迷路”“断联”!耐达讯自动化PROFIBUS三路中继器,用少投入解决大麻烦
人工智能·科技·网络协议·自动化·信息与通信
xinyaozixun17 分钟前
闪极loomos系列AI眼镜重磅发布,重构日常佩戴体验,再次引领智能穿戴赛道
人工智能·重构
Hcoco_me19 分钟前
大模型面试题22:从通俗理解交叉熵公式到通用工程实现
人工智能·rnn·自然语言处理·lstm·word2vec