0. Environment
Ubuntu 22.04
PyTorch 2.3.1
x86
RTX 3080
CUDA 12.2
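A quick way to confirm this setup from inside Python (a minimal check, the exact version strings naturally depend on the local install):

import torch

print(torch.__version__)              # e.g. 2.3.1
print(torch.version.cuda)             # CUDA version PyTorch was built against
print(torch.cuda.is_available())      # True when a usable GPU is visible
print(torch.cuda.get_device_name(0))  # e.g. NVIDIA GeForce RTX 3080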
1. Example code
Taking potrs as the example. potrs is the LAPACK/cuSOLVER routine that solves A·X = B given a previously computed Cholesky factor of A; as the call stack in section 2 shows, torch.cholesky_inverse on CUDA ends up in a potrs-based kernel.
hello_cholesky.py
"""
hello_cholesky.py
step1, Cholesky decompose;
step2, inverse A;
step3, Cholesky again;
python3 hello_cholesky.py --size 256 --cuda_device_id 0
"""
import torch
import time
import argparse
def cholesky_measure(A, cuda_dev=0):
dev = torch.device(f"cuda:{cuda_dev}")
A = A.to(dev)
print(f'Which device to compute : {dev}')
SY = 100* torch.mm(A, A.t()) + 200*torch.eye(N, device=dev)
to_start = time.time()
SY = torch.linalg.cholesky(SY)
SY = torch.cholesky_inverse(SY)
SY = torch.linalg.cholesky(SY, upper=True)
run_time = time.time() - to_start
print(f'The device: {dev}, run: {run_time:.3f} second')
print(f'SY : {SY}')
print(f'****'*20)
return run_time
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='dim of A.')
parser.add_argument('--N', type=int, default=512, required=True, help='dim of A')
args = parser.parse_args()
N = args.N
print(f'A N : {N}')
A = torch.randn(N, N)
cuda_dev = 0
time_dev0 = cholesky_measure(A, cuda_dev)
time_dev1 = cholesky_measure(A, cuda_dev+1)
print(f'time_dev0 /time_dev1 = {time_dev0/time_dev1:.2f} ')
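For reference, a quick numerical sanity check of the three steps, independent of the timing (a minimal sketch with an arbitrarily chosen small N; torch.linalg.inv is used here only as the reference result):

import torch

dev = torch.device("cuda:0")
N = 8
A = torch.randn(N, N, device=dev)
SY = 100 * A @ A.T + 200 * torch.eye(N, device=dev)   # symmetric positive-definite by construction

L = torch.linalg.cholesky(SY)                  # step 1: SY = L @ L.T
SY_inv = torch.cholesky_inverse(L)             # step 2: (L @ L.T)^-1
print(torch.allclose(SY_inv, torch.linalg.inv(SY), rtol=1e-4, atol=1e-5))  # expected: True

U = torch.linalg.cholesky(SY_inv, upper=True)  # step 3: the inverse of an SPD matrix is SPD again
print(torch.allclose(U.T @ U, SY_inv, rtol=1e-4, atol=1e-5))               # expected: True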
Run output:
2. Call-stack trace
The traced call chain is as follows:
Tensor cholesky_inverse(const Tensor &input, bool upper)    // aten/src/ATen/native/BatchLinearAlgebra.cpp
  -> static Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper)
     DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
     REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
    -> Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
      -> Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper)
        -> void _cholesky_inverse_cusolver_potrs_based(Tensor& result, Tensor& infos, bool upper)
          -> template<typename scalar_t>
             inline static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos)
            -> at::cuda::solver::potrs<scalar_t>(
                   handle, uplo, n_32, nrhs_32,
                   A_ptr + i * A_matrix_stride,
                   lda_32,
                   self_working_copy_ptr + i * self_matrix_stride,
                   ldb_32,
                   infos_ptr
               );
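The last frame, at::cuda::solver::potrs, is the cuSOLVER potrs call: given the Cholesky factor computed by potrf, it solves A·X = B, and the potrs-based inverse essentially uses the identity matrix as the right-hand side B. The same relationship can be reproduced at the Python level with torch.cholesky_solve, which exposes the same potrs-style solve to users (on GPU it may go through cuSOLVER or MAGMA depending on the build). A minimal sketch of the equivalence, not the internal code path:

import torch

dev = torch.device("cuda:0")
N = 16
A = torch.randn(N, N, device=dev)
SY = A @ A.T + N * torch.eye(N, device=dev)    # symmetric positive-definite

L = torch.linalg.cholesky(SY)                  # potrf: SY = L @ L.T

# What the potrs-based kernel computes: solve (L @ L.T) X = I for X.
inv_kernel = torch.cholesky_inverse(L)

# The same solve expressed through the user-facing potrs-style wrapper.
eye = torch.eye(N, device=dev)
inv_potrs = torch.cholesky_solve(eye, L)       # solves SY @ X = I given the factor L

print(torch.allclose(inv_kernel, inv_potrs, rtol=1e-4, atol=1e-5))  # expected: True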
Some details: