TRSV Optimization, Part 2

TRSV Dependency-Flag Workspace Reuse Optimization

One-sentence conclusion

The essence of this change:

turn "per-call global memory allocation + initialization + kernel launch"
into "one-time allocation + lightweight initialization"

And for n = 16 ~ 2048, this non-arithmetic overhead is often more expensive than the computation itself


1️⃣ Where was the old version slow?

Every call to trsv_blocked_launcher_2D used to contain:

① hipMalloc / hipFree

  • These are globally synchronizing operations

  • They:

    • stall the stream pipeline

    • take internal device-runtime locks

    • are extremely unfriendly to small n / small batch

For n = 16 ~ 256, the computation itself may take only a few microseconds, while malloc/free can cost tens of microseconds


② The trsv_init initialization kernel

Just to set the flags to -1, a kernel was launched:

  • a kernel launch has a fixed cost by itself

  • this kernel has no computational value whatsoever

  • it also occupies a scheduling slot, delaying the main kernel behind it

At small n, you are effectively running:

  • the init kernel

  • the main kernel

    while the init kernel produces no arithmetic result at all


③ Rebuilding the scaffolding on every call

The lifetime of flags actually depends on:

  • only the handle / batch_count

  • none of uplo / trans / diag / n / A / x

Yet it used to be treated as a temporary object:

malloc → init → use → free

In numerical libraries this is a classic anti-pattern; a minimal sketch of both lifecycles follows.
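
Below is a minimal sketch contrasting the two lifecycles. It is illustrative only, under assumed names (trsv_call_old, trsv_get_flags, ws, ws_capacity); the library's actual mechanism appears later as ensure_trsv_flags / get_trsv_flags on the handle.

#include <hip/hip_runtime.h>
#include <cstddef>

// Before: scaffolding rebuilt on every call
void trsv_call_old(hipStream_t stream, size_t bytes)
{
    int* flags = nullptr;
    (void)hipMalloc(&flags, bytes);                  // globally synchronizing
    (void)hipMemsetAsync(flags, 0xFF, bytes, stream);
    // ... launch the TRSV kernel using flags ...
    (void)hipFree(flags);                            // globally synchronizing again
}

// After: one long-lived workspace, grown only when needed
// (hypothetical stand-ins for the handle members shown later)
static int*   ws          = nullptr;
static size_t ws_capacity = 0;

int* trsv_get_flags(size_t bytes)
{
    if(bytes > ws_capacity)                          // rare: first call or growth
    {
        if(ws) (void)hipFree(ws);
        (void)hipMalloc(&ws, bytes);
        ws_capacity = bytes;
    }
    return ws;                                       // common case: zero allocations
}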


2️⃣ What changed? (the essential difference)

Three key things were done:


✅ A. Turn flags into a handle-level workspace (reused)

Before: allocate + free on every call
After:  reused for the lifetime of the handle

Effect:

  • malloc/free drops from O(number of calls) to O(number of capacity growths)

  • for workloads with a stable batch_count, that is effectively zero mallocs


✅ B. Replace the init kernel with hipMemsetAsync

Before: launch the trsv_init kernel
After:  a single async memset

Why is this faster?

  • memset is a built-in fast path of the runtime

  • it needs none of:

    • kernel argument setup

    • grid/block scheduling

    • wavefront execution

For small n this is pure overhead removal; the byte-pattern detail is worth spelling out (see the snippet below).
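
One detail worth noting: hipMemsetAsync writes a byte pattern, not an integer value. Setting every byte to 0xFF makes each 32-bit flag read back as -1 (all bits set in two's complement), which is exactly the "not started" sentinel the kernel spins on. A host-side check:

#include <cstdint>
#include <cstring>
#include <cstdio>

int main()
{
    int32_t flag;
    // Same byte pattern that hipMemsetAsync(ptr, 0xFF, bytes, stream) writes
    std::memset(&flag, 0xFF, sizeof(flag));
    std::printf("%d\n", flag); // prints -1
    return 0;
}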


✅ C. Align flags with the stream to keep the pipeline flowing

The calls are now:

hipMemsetAsync(..., stream);
hipLaunchKernelGGL(..., stream);

This guarantees:

  • init → compute run serialized on the same stream

  • no implicit synchronization

  • a cleaner GPU pipeline


3️⃣ Why it is especially effective in this n range (16 ~ 2048)

Small n (16 ~ 256)

  • the kernel itself is fast

  • launch / malloc / init dominate the cost

Now:

  • one malloc removed

  • one free removed

  • one kernel launch removed

➡️ a 1.5×~3× speedup is commonly observed


Medium n (256 ~ 2048)

  • the kernel starts to dominate

  • but we still save:

    • one kernel launch

    • one global synchronization

➡️ the gain is more modest (5%~20%), but stability improves noticeably (lower tail latency)


4️⃣ Why this does not change numerical correctness

  • flags only handles inter-block dependency synchronization

  • we still:

    • set flags to -1 before every call

    • threadfence + write the flag inside the kernel

  • the numerical path (accesses to A/x, computation order) is completely unchanged

So:

this is a scheduling/resource-level optimization, not a numerical-algorithm optimization


5️⃣ This step paves the way for more aggressive optimizations

Once flags becomes a long-lived workspace, we can immediately pursue:

  1. an epoch scheme (eliminating the memset entirely)

  2. sharing flags across multiple kernels / launches

  3. stream-parallel TRSV (multiple streams no longer stomping on each other's memory)

  4. bypassing the wavefront kernel entirely for very small n

All of these used to be hard, because flags was a temporary object.


Closing summary

This is not about "making the kernel faster",

but about making the GPU stop paying for meaningless preparation work

In BLAS / solver libraries, this is the kind of optimization that is most easily overlooked yet delivers the most reliable gains

Code

// Macro tuning + TRSV dependency-flag workspace reuse

#include "trsv_strided_batched.hpp"
#include "gpublas-auxiliary.h"
#include "handle.hpp"
#define LDA       32
#define DIM_X     32
#define DIM_Y     8

namespace internal {

    enum class trsv_trans_t { none, trans, conj_trans };
    template <typename T>
    __device__ __forceinline__ T zero_val() { return (T)0; }

    template <>
    __device__ __forceinline__ gpublas_float_complex zero_val<gpublas_float_complex>() { return {0.0f, 0.0f}; }

    template <>
    __device__ __forceinline__ gpublas_double_complex zero_val<gpublas_double_complex>() { return {0.0, 0.0}; }

    template <typename T>
    __device__ __forceinline__ T one_val() { return (T)1; }

    template <>
    __device__ __forceinline__ gpublas_float_complex one_val<gpublas_float_complex>() { return {1.0f, 0.0f}; }

    template <>
    __device__ __forceinline__ gpublas_double_complex one_val<gpublas_double_complex>() { return {1.0, 0.0}; }


    // If your file already defines opA_ij / conj_if_needed, do not redefine them here.
    // The code below calls opA_ij / one_val / zero_val directly.
    
    // Solve the A21 section during A inversion (lower-block form)
    // A = [A11 0; A21 A22]
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG>
    __device__ __forceinline__
    void gpublas_invert_solve_A21(const T* __restrict__ A11,
                                 T* __restrict__       A21,
                                 const T* __restrict__ A22,
                                 T* __restrict__       sx)
    {
        const gpublas_int tid      = DIM_X * threadIdx.y + threadIdx.x;
        const gpublas_int ntid     = DIM_X * DIM_Y;
        const gpublas_int tx       = tid % N;
        const gpublas_int ty       = tid / N;
        const gpublas_int col_span = ntid / N;
    
        for(gpublas_int i = 0; i < N; i += col_span)
        {
            gpublas_int col  = i + ty;
            bool        skip = (col >= N);
    
            T val = zero_val<T>();
            if(!skip)
            {
                // val = -A21 * A11^{-1} piece (A11 has already been "inverted" into this storage layout by the recursion)
                for(gpublas_int j = i; j < N; j++)
                {
                    if(j + ty < N)
                        val += A21[(j + ty) * LDA + tx] * A11[col * LDA + (j + ty)];
                }
                val = -val;
            }
    
            // Forward substitution with A22 (diagonal already holds inv if you preprocessed like you do)
            for(gpublas_int j = 0; j < N; j++)
            {
                if(tx == j && !skip)
                {
                    if(!UNIT_DIAG)
                        val *= A22[j * LDA + j]; // diag is inv(diag) in your preprocess style
                    sx[ty] = val;
                }
                __syncthreads();
                if(tx > j && !skip)
                {
                    val += A22[j * LDA + tx] * sx[ty];
                }
                __syncthreads();
            }
    
            if(!skip)
                A21[col * LDA + tx] = -val;
    
            __syncthreads();
        }
    }
    
    // Solve the A12 section during A inversion (upper-block form)
    // A = [A11 A12; 0 A22]
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG>
    __device__ __forceinline__
    void gpublas_invert_solve_A12(const T* __restrict__ A11,
                                 T* __restrict__       A12,
                                 const T* __restrict__ A22,
                                 T* __restrict__       sx)
    {
        const gpublas_int tid      = DIM_X * threadIdx.y + threadIdx.x;
        const gpublas_int ntid     = DIM_X * DIM_Y;
        const gpublas_int tx       = tid % N;
        const gpublas_int ty       = tid / N;
        const gpublas_int col_span = ntid / N;
    
        for(gpublas_int i = N - 1; i >= 0; i -= col_span)
        {
            gpublas_int col  = i - ty;
            bool        skip = (col < 0);
    
            T val = zero_val<T>();
            if(!skip)
            {
                for(gpublas_int j = 0; j < N; j++)
                {
                    if(j <= col)
                        val += A12[j * LDA + tx] * A22[col * LDA + j];
                }
            }
    
            // Back substitution with A11 (A11 not yet fully inverted, but diag is)
            for(gpublas_int j = N - 1; j >= 0; j--)
            {
                if(tx == j && !skip)
                {
                    if(!UNIT_DIAG)
                        val *= A11[j * LDA + j]; // diag is inv(diag) in your preprocess style
                    sx[ty] = -val;
                }
                __syncthreads();
                if(tx < j && !skip)
                {
                    val -= A11[j * LDA + tx] * sx[ty];
                }
                __syncthreads();
            }
    
            if(!skip)
                A12[col * LDA + tx] = val;
    
            __syncthreads();
        }
    }
    
    template <typename T>
    __device__ __forceinline__
    void gpublas_trsv_transpose(const gpublas_int n,
                                const T* __restrict__ A,
                                T* __restrict__       at)
    {
        if(threadIdx.y == 0 && threadIdx.x < n)
        {
            for(gpublas_int i = 0; i < n; i++)
                at[i * LDA + threadIdx.x] = A[threadIdx.x * LDA + i];
        }
    }
    
    template <gpublas_int n>
    static constexpr bool equals_two = false;
    template <>
    constexpr bool equals_two<2> = true;
    
    // Invert (lower) - base case N=2
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert(T* __restrict__ A, T* __restrict__ sx)
    {
        if(threadIdx.x == 0 && threadIdx.y == 0)
        {
            if(UNIT_DIAG)
            {
                A[0]       = one_val<T>();
                A[LDA + 1] = one_val<T>();
            }
            else
            {
                // diag already stores inv(diag), so A[0], A[LDA+1] are done.
                // offdiag: A[1] = A[1] * (inv(d0) * inv(d1)) with sign already folded by your preprocess choice
                A[1] = A[1] * (A[0] * A[LDA + 1]);
            }
    
            if(TRANS)
            {
                A[LDA] = A[1];
            }
        }
    }
    
    // Invert (lower) - recursive
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<!equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert(T* __restrict__ A, T* __restrict__ sx)
    {
        gpublas_trsv_invert<T, N / 2, UNIT_DIAG, TRANS>(A, sx);
        __syncthreads();
    
        gpublas_invert_solve_A21<T, N / 2, UNIT_DIAG>(
            A, &A[N / 2], &A[(LDA + 1) * (N / 2)], sx);
    
        if(TRANS)
        {
            __syncthreads();
            gpublas_trsv_transpose<T>(N / 2, &A[N / 2], &A[(N / 2) * LDA]);
        }
        __syncthreads();
    
        gpublas_trsv_invert<T, N / 2, UNIT_DIAG, TRANS>(
            &A[(LDA + 1) * (N / 2)], sx);
    }
    
    // Invert (upper) - base case N=2
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert_upper(T* __restrict__ A, T* __restrict__ sx)
    {
        if(threadIdx.x == 0 && threadIdx.y == 0)
        {
            if(UNIT_DIAG)
            {
                A[0]       = one_val<T>();
                A[LDA + 1] = one_val<T>();
            }
            else
            {
                // upper offdiag sits at A[LDA]
                A[LDA] = A[LDA] * (A[0] * A[LDA + 1]);
            }
    
            if(TRANS)
            {
                A[1] = A[LDA];
            }
        }
    }
    
    // Invert (upper) - recursive
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<!equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert_upper(T* __restrict__ A, T* __restrict__ sx)
    {
        gpublas_trsv_invert_upper<T, N / 2, UNIT_DIAG, TRANS>(
            &A[(LDA + 1) * (N / 2)], sx);
        __syncthreads();
    
        // A12 solve
        gpublas_invert_solve_A12<T, N / 2, UNIT_DIAG>(
            A, &A[(N / 2) * LDA], &A[(LDA + 1) * (N / 2)], sx);
    
        if(TRANS)
        {
            __syncthreads();
            gpublas_trsv_transpose<T>(N / 2, &A[(N / 2) * LDA], &A[(N / 2)]);
        }
        __syncthreads();
    
        gpublas_trsv_invert_upper<T, N / 2, UNIT_DIAG, TRANS>(A, sx);
    }
    
    // Optional: block_solve using inverse (kept as-is but renamed)
    // NOTE: if you already have sum_sh / xprev_sh usage, integrate as needed.
    template <typename T, bool UPPER>
    __device__ __forceinline__
    void gpublas_trsv_block_solve_inverse(const T* __restrict__ Ainv,
                                          T* __restrict__       sx,
                                          T&                    val,
                                          T* __restrict__       sum)
    {
        Ainv += threadIdx.y * DIM_X + threadIdx.x;
        sx   += threadIdx.y;
    
        if(threadIdx.y == 0)
            sx[threadIdx.x] = val;
    
        __syncthreads();
    
        val = zero_val<T>();
        if(!UPPER)
        {
            for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
            {
                if(threadIdx.x >= threadIdx.y + i)
                    val += Ainv[i * DIM_X] * sx[i];
            }
            sum[threadIdx.y * DIM_X + threadIdx.x] = val;
            __syncthreads();
    
            if(threadIdx.y == 0)
            {
                for(gpublas_int i = 1; i < DIM_Y; i++)
                    val += sum[i * DIM_X + threadIdx.x];
            }
        }
        else
        {
            for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
            {
                if(threadIdx.x <= i + threadIdx.y)
                    val += Ainv[i * DIM_X] * sx[i];
            }
            sum[threadIdx.y * DIM_X + threadIdx.x] = val;
            __syncthreads();
    
            if(threadIdx.y == 0)
            {
                for(gpublas_int i = 1; i < DIM_Y; i++)
                    val += sum[i * DIM_X + threadIdx.x];
            }
        }
    }

template <typename T>
__device__ __forceinline__ T conj_if_needed(const T& v, bool conj)
{
    // real types: conj is no-op
    return v;
}

template <>
__device__ __forceinline__ gpublas_float_complex
conj_if_needed(const gpublas_float_complex& v, bool conj)
{
    if(!conj) return v;
    return {v.x, -v.y};
}

template <>
__device__ __forceinline__ gpublas_double_complex
conj_if_needed(const gpublas_double_complex& v, bool conj)
{
    if(!conj) return v;
    return {v.x, -v.y};
}

// op(A)(i,j)
template <typename T>
__device__ __forceinline__ T opA_ij(
    const T* __restrict__ A, gpublas_int lda,
    gpublas_int i, gpublas_int j,
    trsv_trans_t trans)
{
    if(trans == trsv_trans_t::none)
    {
        return A[i + j * lda];
    }
    else if(trans == trsv_trans_t::trans)
    {
        return A[j + i * lda];
    }
    else // conj_trans
    {
        T v = A[j + i * lda];
        return conj_if_needed(v, true);
    }
}
template <bool UNIT_DIAG, typename T>
__device__ void gpublas_trsv_block_solve_lower(const T* __restrict__ A,
                                                       gpublas_int bs,
                                                       T&          val)
{
    __shared__ T xs;

    // Iterate forwards
    for(gpublas_int i = 0; i < bs; i++)
    {
        // Solve current element
        if(threadIdx.x == i && threadIdx.y == 0)
        {
            if(!UNIT_DIAG)
                val *= A[i * bs + i];
            xs = val;
        }

        __syncthreads();

        // Update future elements with solved one
        if(threadIdx.x > i && threadIdx.y == 0)
        {
            val += A[i * bs + threadIdx.x] * xs;
        }

        __syncthreads();
    }
}

template <bool UNIT_DIAG, typename T>
__device__ void gpublas_trsv_block_solve_upper(const T* __restrict__ A,
                                                       gpublas_int bs,
                                                       T&          val)
{
    __shared__ T xs;

    for(gpublas_int i = bs - 1; i >= 0; i--)
    {
        // Solve current element
        if(threadIdx.x == i && threadIdx.y == 0)
        {
            if(!UNIT_DIAG)
                val *= A[i * bs + i];
            xs = val;
        }

        __syncthreads();

        // Update future elements with solved one
        if(threadIdx.x < i && threadIdx.y == 0 )
        {
            val += A[i * bs + threadIdx.x] * xs;
        }

        __syncthreads();
    }
}
#define INV_AFTER 5
// ------------------------------
// Generic blocked TRSV kernel
// NB=32, 2D threads (DIM_X=32 lanes x DIM_Y rows)
// ------------------------------
template <typename T, bool EFF_LOWER, bool UNIT_DIAG, trsv_trans_t TRANS>
__global__ void trsv_wavefront_kernel_2D(
    gpublas_int n,
    const T* __restrict__ A, gpublas_int lda, long strideA,
    T* __restrict__ x, gpublas_int incx, long stridex,
    gpublas_int num_blocks,
    gpublas_int* __restrict__ w_completed_sec,
    gpublas_int batch_count   // ✅ newly added
)
{
    constexpr gpublas_int NB = 32;
    constexpr bool DO_TRANS = (TRANS != trsv_trans_t::none);
    for(gpublas_int batch = blockIdx.z; batch < batch_count; batch += gridDim.z)
    {
    
    __shared__ T sAdiag[NB * NB];                 // cache for the diagonal block
    __shared__ T xprev_sh[NB];                    // cache of x from already-solved panels, used for the updates
    __shared__ T sum_sh[DIM_X * DIM_Y];           // blockwise reduction buffer
    T sAoff[DIM_X / DIM_Y];// off-diagonal block adjacent to the diagonal, held in registers; each thread owns 4 elements


    gpublas_int logical_i = blockIdx.x;// logical block index of this CTA
    const gpublas_int num_blocks = gridDim.x; // shadows the kernel parameter of the same name

    gpublas_int tx = threadIdx.x;
    gpublas_int ty = threadIdx.y;
    gpublas_int tid = ty * blockDim.x + tx;// linear thread ID within the CTA
    gpublas_int nthreads = blockDim.x * blockDim.y;// total threads per CTA

    auto blk_of = [&](gpublas_int logical_blk) {
        return EFF_LOWER ? logical_blk : (num_blocks - 1 - logical_blk);
    };// map a logical block index to the physical block index

    gpublas_int my_blk = blk_of(logical_i);
    gpublas_int row_start = my_blk * NB;// first row of this block
    gpublas_int row_end   = min(row_start + NB, n);// row end (exclusive)
    gpublas_int bs = row_end - row_start;// actual block size

    const T* A_batch = A + (long)batch * strideA;
    T* x_batch = x + (long)batch * stridex;
    const gpublas_int remainder = n % DIM_X;
    const bool row_is_remainder = ((n - 1) / DIM_X == my_blk && remainder != 0);
    const bool first_blk = EFF_LOWER ? my_blk == 0 : my_blk == num_blocks - 1;
    if(!first_blk)// preload the off-diagonal block adjacent to the diagonal into the registers sAoff
    {
        const gpublas_int block_col = EFF_LOWER ?my_blk - 1 :  my_blk + 1;// pick the adjacent block
        const gpublas_int local_col = DO_TRANS ? my_blk * DIM_X + tx : block_col * DIM_X + ty;// compute the matrix access position
        const gpublas_int local_row = DO_TRANS ? block_col * DIM_X + ty : my_blk * DIM_X + tx;
        const size_t      A_idx     = (local_row) + (local_col)*lda;// column-major access

        for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
        {
            const size_t i_idx = DO_TRANS ? i : i * lda;// step along rows (contiguous) or across columns

            __syncthreads();
            if(DO_TRANS ? (local_row + i < n && local_col < n)
                        : (local_row < n && local_col + i < n))// the last block may be smaller than DIM_X
                sAoff[i / DIM_Y] = A_batch[A_idx + i_idx];// each thread loads one vertical slice (A_batch: read this batch's matrix, not batch 0's)
            else
                sAoff[i / DIM_Y] = zero_val<T>();
        }
    }
#ifdef INV_AFTER
        bool cache_transpose = (DO_TRANS && EFF_LOWER && num_blocks - 1 - my_blk < INV_AFTER)
                               || (DO_TRANS && !EFF_LOWER && my_blk < INV_AFTER)
                               || (DO_TRANS && row_is_remainder);
#else
    bool cache_transpose = DO_TRANS; // works for ALL without inversion method
#endif
    // preprocess the diagonal block
    gpublas_int row = tx;// each thread owns one row
    for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
    {
        const gpublas_int col    = ty + i;// threads cooperate along y to cover the columns; this is the column owned by this thread
        const gpublas_int sA_idx = cache_transpose ? col + bs * row : col * bs + row;
        const size_t      A_idx
            = (my_blk * DIM_X * lda + my_blk * DIM_X) + col * lda + row;// A(my_blk * DIM_X + row, my_blk * DIM_X + col)
        const gpublas_int total_col = my_blk * DIM_X + col;// global coordinates of this element in the full matrix
        const gpublas_int total_row = my_blk * DIM_X + row;

        if(((row > col && EFF_LOWER) || (col > row && !EFF_LOWER)) && row < bs
                   && col < bs)// only touch the region of the matrix that actually exists
        {
            sAdiag[sA_idx] = -opA_ij<T>(A_batch, lda, total_row, total_col, TRANS);
        }
        else if(!UNIT_DIAG && row == col && row < bs)
        {
            // Dividing here so we can just multiply later.
            sAdiag[sA_idx] = one_val<T>() / opA_ij<T>(A_batch, lda, total_row, total_col, TRANS);
        }
        else if(col < DIM_X
                && row < DIM_X) // In off-triangular portion or past end of remainder
        {
            sAdiag[sA_idx] = zero_val<T>();
        }
    }
    __syncthreads();

#ifdef INV_AFTER
        if(((my_blk >= INV_AFTER && EFF_LOWER)
            || (num_blocks - 1 - my_blk >= INV_AFTER && !EFF_LOWER))
            && !row_is_remainder)
        {
            if(EFF_LOWER)
                gpublas_trsv_invert<T, DIM_X, UNIT_DIAG, DO_TRANS>(sAdiag, sum_sh);
            else
                gpublas_trsv_invert_upper<T, DIM_X, UNIT_DIAG, DO_TRANS>(sAdiag, sum_sh);
        }
#endif
    __syncthreads();
    T val{};                 // value-initialization -> (0,0)
// val = intermediate result for the current RHS


    // -------------------------------
    // the first row of threads loads this block's slice of b into registers
    if(ty == 0 && tx < bs)
        val = -x_batch[(row_start + tx) * incx];// fetch this block's slice of b (negated)
    __syncthreads();


    gpublas_int col_done = -1;// last block index already waited on (currently unused)
    // -------------------------------
    // Main loop: update against each earlier panel, until it is this block's turn (solve and publish)
    for(gpublas_int p = 0; p < logical_i; ++p)
    {
        if(tid == 0) {
            volatile gpublas_int* done = w_completed_sec;
            while(done[batch] < p) { }
        }
        __syncthreads();
        __threadfence();            // acquire x (make the producer's writes to x visible to the reads below)
        __syncthreads();
        


        // compute the panel's physical block position and size
        gpublas_int p_blk = blk_of(p);
        gpublas_int p_row_start = p_blk * NB;
        gpublas_int p_row_end   = min(p_row_start + NB, n);
        gpublas_int p_bs = p_row_end - p_row_start;

        // load the already-solved x (panel) into shared memory
        if(tid < p_bs)
            xprev_sh[tid] = x_batch[(p_row_start + tid) * incx];
        __syncthreads();

        // update: accumulate op(A)(row, col_in_panel) * x_prev into val
        if(tx < bs)
        {
            gpublas_int row = row_start + tx;
            const bool   cached
                    = !first_blk
                      && (EFF_LOWER ? p_blk == my_blk - 1 : p_blk == my_blk + 1);
                      // cached: if this block is not the first one and the panel is the adjacent already-solved block (p_blk == my_blk + 1 for backward substitution, p_blk == my_blk - 1 for forward), read from the preloaded registers sAoff instead of reloading from global memory.

            for(gpublas_int j_local = ty; j_local < p_bs; j_local += DIM_Y)
            {
                gpublas_int col = p_row_start + j_local;
                auto a = cached ? sAoff[(j_local - ty) / DIM_Y] : opA_ij<T>(A_batch, lda, row, col, TRANS);
                    val += a * xprev_sh[j_local];// each thread accumulates A_ij * x_j into val; xprev_sh holds the already-solved portion of x
            }
        }
    }
      // stage each thread's partial sum for the cross-row reduction
      sum_sh[ty * DIM_X + tx]  = val;
      __syncthreads();

        // reduce across ty and finalize val
        if(ty == 0 && tx < bs)
        {
            #pragma unroll
            for(gpublas_int r = 1; r < DIM_Y; ++r)
                val += sum_sh[r * DIM_X + tx];
            val = -val;
        }
        __syncthreads();
        #ifdef INV_AFTER
                if(((my_blk >= INV_AFTER && EFF_LOWER)
                    || (num_blocks - 1 - my_blk >= INV_AFTER && !EFF_LOWER))
                   && !row_is_remainder)
                {
                    gpublas_trsv_block_solve_inverse<T, !EFF_LOWER>(sAdiag, xprev_sh, val, sum_sh);
        
                    if(ty == 0 && tx < bs)
                        x_batch[(row_start + tx) * incx] = val;
                }
                else // same as without inversion
                {
                    // Solve the diagonal block
                    if(!EFF_LOWER)
                    gpublas_trsv_block_solve_upper<UNIT_DIAG>(sAdiag, bs, val);
                    else
                    gpublas_trsv_block_solve_lower<UNIT_DIAG>(sAdiag, bs, val);
        
                    // Store solved value into x
                    if(ty == 0 && tx < bs)
                        x_batch[(row_start + tx) * incx] = val;
                }
        #else
            // Solve the diagonal block
            if(!EFF_LOWER)
            gpublas_trsv_block_solve_upper<UNIT_DIAG>(sAdiag, bs, val);
            else
            gpublas_trsv_block_solve_lower<UNIT_DIAG>(sAdiag, bs, val);

                // write this block's solution back to global x
            if(ty == 0 && tx < bs)
                x_batch[(row_start + tx) * incx] = val;
        #endif

    __threadfence();
    __syncthreads(); // for windows instability
    if(tid == 0)
        w_completed_sec[batch]++;

    //__threadfence();
    }
}

template <typename T>
gpublas_status trsv_blocked_launcher_2D(
    gpublas_handle handle,
    gpublas_fill uplo,
    gpublas_operation transA,
    gpublas_diagonal diag,
    gpublas_int n,
    const T* A, gpublas_int lda, gpublas_stride strideA,
    T* x, gpublas_int incx, gpublas_stride stride_x,
    gpublas_int batch_count)
{
    constexpr gpublas_int NB = 32;
    gpublas_int num_blocks = (n + NB - 1) / NB;

    hipStream_t stream = handle->get_stream();

    // ================================
    // Reuse dependency flags (scheme A)
    // ================================
    auto st = handle->ensure_trsv_flags((size_t)batch_count);
    if(st != gpublas_status_success)
        return st;

    gpublas_int* w_completed_sec = handle->get_trsv_flags();

    // init flags to -1
    if(hipMemsetAsync(w_completed_sec, 0xFF,
                      sizeof(gpublas_int) * (size_t)batch_count,
                      stream) != hipSuccess)
        return gpublas_status_internal_error;

    // ================================
    // Kernel launch
    // ================================
    dim3 threads(DIM_X, DIM_Y, 1);

    int batches = std::min(batch_count, 64);   // keep the original capping logic unchanged for now
    dim3 grid(num_blocks, 1, batches);

    trsv_trans_t tmode;
    if(transA == gpublas_operation_none) tmode = trsv_trans_t::none;
    else if(transA == gpublas_operation_transpose) tmode = trsv_trans_t::trans;
    else tmode = trsv_trans_t::conj_trans;

    bool eff_lower = (tmode == trsv_trans_t::none)
                         ? (uplo == gpublas_fill_lower)
                         : (uplo == gpublas_fill_upper);

    bool unit_diag = (diag == gpublas_diagonal_unit);

    auto launch = [&](auto effLowerTag, auto unitTag, auto transTag)
    {
        constexpr bool EFF_LOWER = decltype(effLowerTag)::value;
        constexpr bool UNIT_DIAG = decltype(unitTag)::value;
        constexpr trsv_trans_t TRANS = decltype(transTag)::value;

        hipLaunchKernelGGL(
            (trsv_wavefront_kernel_2D<T, EFF_LOWER, UNIT_DIAG, TRANS>),
            grid, threads, 0, stream,
            (gpublas_int)n,
            A, (gpublas_int)lda, (long)strideA,
            x, (gpublas_int)incx, (long)stride_x,
            (gpublas_int)num_blocks,
            w_completed_sec, (gpublas_int)batch_count
        );
    };

    // ===== dispatch =====
    if(tmode == trsv_trans_t::none)
    {
        if(eff_lower)
        {
            if(unit_diag) launch(std::true_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
            else launch(std::true_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
        }
        else
        {
            if(unit_diag) launch(std::false_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
            else launch(std::false_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
        }
    }
    else if(tmode == trsv_trans_t::trans)
    {
        if(eff_lower)
        {
            if(unit_diag) launch(std::true_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
            else launch(std::true_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
        }
        else
        {
            if(unit_diag) launch(std::false_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
            else launch(std::false_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
        }
    }
    else
    {
        if(eff_lower)
        {
            if(unit_diag) launch(std::true_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
            else launch(std::true_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
        }
        else
        {
            if(unit_diag) launch(std::false_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
            else launch(std::false_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
        }
    }

    return gpublas_status_success;
}


} // namespace internal

// ================================================================
// Public C API
// ================================================================

#define DEFINE_TRSV_INTERFACE(name, type)                                      \
gpublas_status name(                                                            \
    gpublas_handle handle,                                                      \
    gpublas_fill uplo,                                                          \
    gpublas_operation trans,                                                    \
    gpublas_diagonal diag,                                                      \
    gpublas_int n,                                                              \
    const type* A,                                                              \
    gpublas_int lda,                                                            \
    gpublas_stride strideA,                                                     \
    type* x,                                                                    \
    gpublas_int incx,                                                           \
    gpublas_stride stridex,                                                     \
    gpublas_int batch_count)                                               \
{                                                                               \
    return internal::trsv_blocked_launcher_2D<type>(               \
        handle, uplo, trans, diag,                                             \
        n, A, lda, strideA,                                                    \
        x, incx, stridex, batch_count);                                        \
}

DEFINE_TRSV_INTERFACE(gpublas_strsv_strided_batched, float)
DEFINE_TRSV_INTERFACE(gpublas_dtrsv_strided_batched, double)
DEFINE_TRSV_INTERFACE(gpublas_ctrsv_strided_batched, gpublas_float_complex)
DEFINE_TRSV_INTERFACE(gpublas_ztrsv_strided_batched, gpublas_double_complex)

#undef DEFINE_TRSV_INTERFACE

Results

| Type | n | gpublas_time (ms) | rocblas_time (ms) | ratio | Result | max diff |
|------|---|-------------------|-------------------|-------|--------|----------|
| F32 | 16 | 0.0474253 | 0.0509221 | 1.07373 | ok | 0.000000e+00 |
| F32 | 32 | 0.0498166 | 0.0512295 | 1.02836 | ok | 0.00E+00 |
| F32 | 64 | 0.0605633 | 0.0521927 | 0.861787 | ok | 9.83E-07 |
| F32 | 128 | 0.0769935 | 0.0707294 | 0.918641 | ok | 6.26E-06 |
| F32 | 256 | 0.0938193 | 0.10111 | 1.07771 | ok | 8.34E-06 |
| F32 | 512 | 0.119299 | 0.132825 | 1.11338 | wrong | 3.98E-05 |
| F32 | 1024 | 0.170791 | 0.150549 | 0.881482 | wrong | 7.26E-04 |
| F32 | 2048 | 0.269723 | 0.198702 | 0.736688 | wrong | 6.18E-02 |
| F64 | 16 | 0.0439261 | 0.0518829 | 1.18114 | ok | 0.00E+00 |
| F64 | 32 | 0.0510056 | 0.0529284 | 1.0377 | ok | 0.00E+00 |
| F64 | 64 | 0.0597598 | 0.0549134 | 0.918902 | ok | 1.83E-15 |
| F64 | 128 | 0.0825816 | 0.0720333 | 0.872268 | ok | 6.99E-15 |
| F64 | 256 | 0.0995566 | 0.110057 | 1.10547 | ok | 2.01E-14 |
| F64 | 512 | 0.128629 | 0.143944 | 1.11906 | ok | 8.62E-14 |
| F64 | 1024 | 0.184087 | 0.168539 | 0.915536 | ok | 1.48E-12 |
| F64 | 2048 | 0.291341 | 0.224877 | 0.771868 | ok | 8.11E-11 |
| C32 | 16 | 0.047882 | 0.0563785 | 1.17745 | ok | 5.96E-08 |
| C32 | 32 | 0.0531686 | 0.0563176 | 1.05923 | ok | 2.24E-07 |
| C32 | 64 | 0.0617148 | 0.0590301 | 0.956499 | ok | 9.83E-07 |
| C32 | 128 | 0.0874002 | 0.0783516 | 0.89647 | ok | 4.53E-06 |
| C32 | 256 | 0.112152 | 0.119338 | 1.06407 | ok | 8.88E-06 |
| C32 | 512 | 0.146413 | 0.164578 | 1.12407 | wrong | 6.08E-05 |
| C32 | 1024 | 0.227012 | 0.193773 | 0.853582 | wrong | 8.69E-04 |
| C32 | 2048 | 0.403958 | 0.261012 | 0.646138 | wrong | 8.62E-02 |
| C64 | 16 | 0.0466708 | 0.0470375 | 1.00786 | ok | 1.11E-16 |
| C64 | 32 | 0.0559716 | 0.0488645 | 0.873024 | ok | 4.16E-16 |
| C64 | 64 | 0.0660798 | 0.0606659 | 0.91807 | ok | 1.89E-15 |
| C64 | 128 | 0.096767 | 0.0852761 | 0.881251 | ok | 8.10E-15 |
| C64 | 256 | 0.125374 | 0.113984 | 0.909153 | ok | 1.68E-14 |
| C64 | 512 | 0.168339 | 0.143928 | 0.85499 | ok | 1.05E-13 |
| C64 | 1024 | 0.271448 | 0.206979 | 0.762497 | ok | 1.499023e-12 |
| C64 | 2048 | 0.468072 | 0.333167 | 0.711787 | ok | 1.528287e-10 |

Epoch-based Synchronization Optimization

I. Motivation

In a block-level parallel TRSV implementation, a global dependency flag is needed to enforce the cross-block execution order.

In the traditional implementation, this flag only records completion progress, so different calls cannot be distinguished, and the flag must be initialized before every call (cleared or set to -1).

The problems with this approach:

  • every call needs an explicit initialization (memset or an init kernel)

  • for small matrices or high call rates, the initialization overhead becomes the bottleneck

  • the flag's lifetime is incorrectly bound to "a single call"


II. Core idea

The epoch optimization assigns each TRSV call a unique epoch (call generation) and encodes the "call identity" into the dependency flag, achieving:

logical isolation of the flags between calls, with no physical initialization

That is:

  • the flag no longer only says "how far this call has progressed"

  • it also carries "which call it belongs to"


III. Implementation details

1️⃣ Epoch-encoded dependency flag

The original single-integer dependency flag is extended into a composite state:

| epoch (high bits) | completed_step (low bits) |

  • epoch: identifies the current TRSV call

  • completed_step: progress within the current call

The epoch and the progress are updated in a single atomic write; a concrete encoding example follows.
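
As a concrete illustration, here is a host-side version of the packing performed by the pack_flag / flag_epoch / flag_comp device helpers in the listing below:

#include <cstdint>
#include <cstdio>

static uint32_t pack_flag(uint32_t epoch16, uint32_t comp16)
{
    return (epoch16 << 16) | (comp16 & 0xFFFFu); // | epoch | completed_step |
}

int main()
{
    uint32_t v = pack_flag(3, 5);                    // call #3 finished step 5
    std::printf("packed   = 0x%08X\n", v);           // 0x00030005
    std::printf("epoch    = %u\n", v >> 16);         // 3
    std::printf("progress = %u\n", v & 0xFFFFu);     // 5
    return 0;
}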


2️⃣ The changed wait condition

Original wait logic:

wait until completed_step >= p

Wait logic after the epoch optimization:

wait until (epoch == this call's epoch) and (completed_step >= p)

When the epoch does not match, the flag is treated as "not started", equivalent to the old -1 state.


3️⃣ Isolation between calls

  • every TRSV call is assigned a fresh epoch

  • flags left over from the previous call can never be mistaken for completion in the current call

  • therefore the flags no longer need to be cleared or reset before a call

The flags are initialized only once, at first allocation or when the capacity grows.


4️⃣ Memory-consistency guarantees

  • producer side: before writing the dependency flag, commit the associated data with a memory fence

  • consumer side: read the dependency flag with acquire semantics to guarantee data visibility

This guarantees:

when a successor block observes a completion flag with a matching epoch, it is guaranteed to also see the results written back by the predecessor block (a sketch of the producer side follows).
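
The kernel listing below is cut off before the producer side, so here is a hedged sketch of what a matching release-style publish could look like (publish_progress is a hypothetical helper; the flags / epoch16 names follow the listing):

#include <hip/hip_runtime.h>
#include <cstdint>

// Publishes "call `epoch16` has completed step `step`" in one atomic write.
// The release store commits all prior writes (the solved x block) before the
// updated flag can be observed, pairing with the __ATOMIC_ACQUIRE load in
// the wait loop of the kernel below.
__device__ __forceinline__
void publish_progress(uint32_t* flags, int batch,
                      uint32_t epoch16, uint32_t step)
{
    __atomic_store_n(&flags[batch],
                     (epoch16 << 16) | (step & 0xFFFFu),
                     __ATOMIC_RELEASE);
}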


IV. Performance analysis

The epoch optimization reduces the flag-initialization cost from:

O(every call)

down to:

O(first allocation or capacity growth only)

The effect is most pronounced for:

  • small TRSV (small n)

  • batch_count = 1

  • frequently repeated calls (e.g., iterative solvers, time stepping)

It significantly reduces the fixed per-call scheduling overhead without changing any numerical computation path.


V. Summary of properties

  • ✅ does not change the TRSV algorithm or its mathematical result

  • ✅ does not add any kernel launches

  • ✅ does not rely on hardware-specific features; generally applicable

  • ✅ completely eliminates the need to initialize the dependency flags on every call

Handle code

//handle.hpp
#pragma once

#include "definitions.hpp"
#include "gpublas.h"
#include <array>
#include <cstddef>
#include <hip/hip_runtime.h>
#include <memory>
#include <tuple>
#include <type_traits>
#include <unistd.h>
#include <utility>

enum class Processor : int
{
    // matching enum used in hipGcnArch
    // only including supported types
    gfx803  = 803,
    gfx900  = 900,
    gfx906  = 906,
    gfx908  = 908,
    gfx90a  = 910,
};


/*******************************************************************************
 * \brief gpublas_handle is a structure holding the gpublas library context.
 * It must be initialized using gpublas_create_handle() and the returned handle must be passed to subsequent API calls.
 * It should be destroyed at the end using gpublas_destroy_handle().
 * Exactly like CUBLAS, GPUBLAS only uses one stream for one API routine
 ******************************************************************************/
struct _gpublas_handle
{
private:

    // gpublas by default take the system default stream 0 users cannot create
    hipStream_t stream = 0;
    gpublas_int* d_trsv_flags = nullptr; // added (wkj)
    size_t trsv_flags_capacity = 0;      // measured in units of batch_count (added, wkj)
    uint32_t trsv_epoch = 1;             // 0 is reserved to mean "uninitialized/invalid" (wkj)

    // Device ID is created at handle creation time and remains in effect for the life of the handle.
    const int device;

    // Arch ID is created at handle creation time and remains in effect for the life of the handle.
    const int arch;
    int       archMajor;
    int       archMajorMinor;

public:
    uint32_t next_trsv_epoch16();// added (wkj)

    gpublas_status ensure_trsv_flags(size_t need);// added (wkj)

    gpublas_int* get_trsv_flags() const { return d_trsv_flags; }
    _gpublas_handle();
    ~_gpublas_handle();

    _gpublas_handle(const _gpublas_handle&) = delete;
    _gpublas_handle& operator=(const _gpublas_handle&) = delete;


    int getDevice()
    {
        return device;
    }

    int getArch()
    {
        return arch;
    }

    int getArchMajor()
    {
        return archMajor;
    }

    int getArchMajorMinor()
    {
        return archMajorMinor;
    }

    // hipEvent_t pointers (for internal use only)
    hipEvent_t startEvent = nullptr;
    hipEvent_t stopEvent  = nullptr;

    // default pointer_mode is on host
    gpublas_pointer_mode pointer_mode = gpublas_pointer_mode_host;


    friend gpublas_status(::gpublas_set_stream)(_gpublas_handle*, hipStream_t);

    // // Temporarily change pointer mode, returning object which restores old mode when destroyed
    // auto push_pointer_mode(gpublas_pointer_mode mode)
    // {
    //     return _pushed_state<gpublas_pointer_mode>(pointer_mode, mode);
    // }

    // Return the current stream
    hipStream_t get_stream() const
    {
        return stream;
    }



};
//handle.cpp
#include "handle.hpp"
#include <cstdarg>
#include <limits>

static inline int getActiveDevice()
{
    int device;
    THROW_IF_HIP_ERROR(hipGetDevice(&device));
    return device;
}

static Processor getActiveArch(int deviceId)
{
    hipDeviceProp_t deviceProperties;
    hipGetDeviceProperties(&deviceProperties, deviceId);
    // strip out xnack/ecc from name
    std::string deviceFullString(deviceProperties.gcnArchName);
    std::string deviceString = deviceFullString.substr(0, deviceFullString.find(":"));

    if(deviceString.find("gfx906") != std::string::npos)
    {
        return Processor::gfx906;
    }
    
    return static_cast<Processor>(0);
}

// ✅ member function (with class qualifier)
uint32_t _gpublas_handle::next_trsv_epoch16()
{
    // use only the low 16 bits; skip 0, which is reserved as the invalid value
    uint32_t e = (trsv_epoch++ & 0xFFFFu);
    if(e == 0) e = (trsv_epoch++ & 0xFFFFu);
    return e;
}

// ✅ member function (with class qualifier)
gpublas_status _gpublas_handle::ensure_trsv_flags(size_t need)
{
    if(need == 0) return gpublas_status_success;

    if(trsv_flags_capacity >= need && d_trsv_flags)
        return gpublas_status_success;

    int cur = -1;
    hipGetDevice(&cur);
    if(cur != device) hipSetDevice(device);

    size_t new_cap = 1;
    while(new_cap < need) new_cap <<= 1;

    gpublas_int* new_ptr = nullptr;
    if(hipMalloc(&new_ptr, new_cap * sizeof(gpublas_int)) != hipSuccess)
    {
        if(cur != device) hipSetDevice(cur);
        return gpublas_status_internal_error;
    }

    // ✅ key point: newly allocated / grown flags must be zeroed once (epoch = 0)
    hipMemset(new_ptr, 0, new_cap * sizeof(gpublas_int));

    if(d_trsv_flags) hipFree(d_trsv_flags);

    d_trsv_flags = new_ptr;
    trsv_flags_capacity = new_cap;

    if(cur != device) hipSetDevice(cur);
    return gpublas_status_success;
}
/*******************************************************************************
 * constructor
 ******************************************************************************/
_gpublas_handle::_gpublas_handle()
    : device(getActiveDevice()) // active device is handle device
    , arch(static_cast<int>(getActiveArch(device)))
{
    archMajor      = arch / 100; // this may need to switch to string handling in the future
    archMajorMinor = arch / 10;

}

/*******************************************************************************
 * destructor
 ******************************************************************************/
//_gpublas_handle::~_gpublas_handle(){} // old empty destructor, replaced (wkj)

/*******************************************************************************
 * destructor  ✅ keep only this one
 ******************************************************************************/
 _gpublas_handle::~_gpublas_handle()// added (wkj)
 {
     int cur = -1;
     hipGetDevice(&cur);
     if(cur != device) hipSetDevice(device);
 
     if(d_trsv_flags)
         hipFree(d_trsv_flags);
     d_trsv_flags = nullptr;
     trsv_flags_capacity = 0;
 
     if(startEvent) hipEventDestroy(startEvent);
     if(stopEvent)  hipEventDestroy(stopEvent);
     startEvent = nullptr;
     stopEvent  = nullptr;
 
     if(cur != device) hipSetDevice(cur);
 }

Code

// Epoch version

#include "trsv_strided_batched.hpp"
#include "gpublas-auxiliary.h"
#include "handle.hpp"
#define LDA       32
#define DIM_X     32
#define DIM_Y     8

namespace internal {

    enum class trsv_trans_t { none, trans, conj_trans };
    template <typename T>
    __device__ __forceinline__ T zero_val() { return (T)0; }

    template <>
    __device__ __forceinline__ gpublas_float_complex zero_val<gpublas_float_complex>() { return {0.0f, 0.0f}; }

    template <>
    __device__ __forceinline__ gpublas_double_complex zero_val<gpublas_double_complex>() { return {0.0, 0.0}; }

    template <typename T>
    __device__ __forceinline__ T one_val() { return (T)1; }

    template <>
    __device__ __forceinline__ gpublas_float_complex one_val<gpublas_float_complex>() { return {1.0f, 0.0f}; }

    template <>
    __device__ __forceinline__ gpublas_double_complex one_val<gpublas_double_complex>() { return {1.0, 0.0}; }

    __device__ __forceinline__ uint32_t pack_flag(uint32_t epoch16, uint32_t comp16)
    {
        return (epoch16 << 16) | (comp16 & 0xFFFFu);
    }
    __device__ __forceinline__ uint32_t flag_epoch(uint32_t v) { return v >> 16; }
    __device__ __forceinline__ uint32_t flag_comp (uint32_t v) { return v & 0xFFFFu; }
    
    // on epoch mismatch, treat comp as -1 (0xFFFF)
    __device__ __forceinline__ int get_comp_or_minus1(uint32_t v, uint32_t epoch16)
    {
        return (flag_epoch(v) == epoch16) ? (int)flag_comp(v) : -1;
    }
    
    // If your file already defines opA_ij / conj_if_needed, do not redefine them here.
    // The code below calls opA_ij / one_val / zero_val directly.
    
    // Solve the A21 section during A inversion (lower-block form)
    // A = [A11 0; A21 A22]
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG>
    __device__ __forceinline__
    void gpublas_invert_solve_A21(const T* __restrict__ A11,
                                 T* __restrict__       A21,
                                 const T* __restrict__ A22,
                                 T* __restrict__       sx)
    {
        const gpublas_int tid      = DIM_X * threadIdx.y + threadIdx.x;
        const gpublas_int ntid     = DIM_X * DIM_Y;
        const gpublas_int tx       = tid % N;
        const gpublas_int ty       = tid / N;
        const gpublas_int col_span = ntid / N;
    
        for(gpublas_int i = 0; i < N; i += col_span)
        {
            gpublas_int col  = i + ty;
            bool        skip = (col >= N);
    
            T val = zero_val<T>();
            if(!skip)
            {
                // val = -A21 * A11^{-1} piece (A11 has already been "inverted" into this storage layout by the recursion)
                for(gpublas_int j = i; j < N; j++)
                {
                    if(j + ty < N)
                        val += A21[(j + ty) * LDA + tx] * A11[col * LDA + (j + ty)];
                }
                val = -val;
            }
    
            // Forward substitution with A22 (diagonal already holds inv if you preprocessed like you do)
            for(gpublas_int j = 0; j < N; j++)
            {
                if(tx == j && !skip)
                {
                    if(!UNIT_DIAG)
                        val *= A22[j * LDA + j]; // diag is inv(diag) in your preprocess style
                    sx[ty] = val;
                }
                __syncthreads();
                if(tx > j && !skip)
                {
                    val += A22[j * LDA + tx] * sx[ty];
                }
                __syncthreads();
            }
    
            if(!skip)
                A21[col * LDA + tx] = -val;
    
            __syncthreads();
        }
    }
    
    // Solve the A12 section during A inversion (upper-block form)
    // A = [A11 A12; 0 A22]
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG>
    __device__ __forceinline__
    void gpublas_invert_solve_A12(const T* __restrict__ A11,
                                 T* __restrict__       A12,
                                 const T* __restrict__ A22,
                                 T* __restrict__       sx)
    {
        const gpublas_int tid      = DIM_X * threadIdx.y + threadIdx.x;
        const gpublas_int ntid     = DIM_X * DIM_Y;
        const gpublas_int tx       = tid % N;
        const gpublas_int ty       = tid / N;
        const gpublas_int col_span = ntid / N;
    
        for(gpublas_int i = N - 1; i >= 0; i -= col_span)
        {
            gpublas_int col  = i - ty;
            bool        skip = (col < 0);
    
            T val = zero_val<T>();
            if(!skip)
            {
                for(gpublas_int j = 0; j < N; j++)
                {
                    if(j <= col)
                        val += A12[j * LDA + tx] * A22[col * LDA + j];
                }
            }
    
            // Back substitution with A11 (A11 not yet fully inverted, but diag is)
            for(gpublas_int j = N - 1; j >= 0; j--)
            {
                if(tx == j && !skip)
                {
                    if(!UNIT_DIAG)
                        val *= A11[j * LDA + j]; // diag is inv(diag) in your preprocess style
                    sx[ty] = -val;
                }
                __syncthreads();
                if(tx < j && !skip)
                {
                    val -= A11[j * LDA + tx] * sx[ty];
                }
                __syncthreads();
            }
    
            if(!skip)
                A12[col * LDA + tx] = val;
    
            __syncthreads();
        }
    }
    
    template <typename T>
    __device__ __forceinline__
    void gpublas_trsv_transpose(const gpublas_int n,
                                const T* __restrict__ A,
                                T* __restrict__       at)
    {
        if(threadIdx.y == 0 && threadIdx.x < n)
        {
            for(gpublas_int i = 0; i < n; i++)
                at[i * LDA + threadIdx.x] = A[threadIdx.x * LDA + i];
        }
    }
    
    template <gpublas_int n>
    static constexpr bool equals_two = false;
    template <>
    constexpr bool equals_two<2> = true;
    
    // Invert (lower) - base case N=2
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert(T* __restrict__ A, T* __restrict__ sx)
    {
        if(threadIdx.x == 0 && threadIdx.y == 0)
        {
            if(UNIT_DIAG)
            {
                A[0]       = one_val<T>();
                A[LDA + 1] = one_val<T>();
            }
            else
            {
                // diag already stores inv(diag), so A[0], A[LDA+1] are done.
                // offdiag: A[1] = A[1] * (inv(d0) * inv(d1)) with sign already folded by your preprocess choice
                A[1] = A[1] * (A[0] * A[LDA + 1]);
            }
    
            if(TRANS)
            {
                A[LDA] = A[1];
            }
        }
    }
    
    // Invert (lower) - recursive
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<!equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert(T* __restrict__ A, T* __restrict__ sx)
    {
        gpublas_trsv_invert<T, N / 2, UNIT_DIAG, TRANS>(A, sx);
        __syncthreads();
    
        gpublas_invert_solve_A21<T, N / 2, UNIT_DIAG>(
            A, &A[N / 2], &A[(LDA + 1) * (N / 2)], sx);
    
        if(TRANS)
        {
            __syncthreads();
            gpublas_trsv_transpose<T>(N / 2, &A[N / 2], &A[(N / 2) * LDA]);
        }
        __syncthreads();
    
        gpublas_trsv_invert<T, N / 2, UNIT_DIAG, TRANS>(
            &A[(LDA + 1) * (N / 2)], sx);
    }
    
    // Invert (upper) - base case N=2
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert_upper(T* __restrict__ A, T* __restrict__ sx)
    {
        if(threadIdx.x == 0 && threadIdx.y == 0)
        {
            if(UNIT_DIAG)
            {
                A[0]       = one_val<T>();
                A[LDA + 1] = one_val<T>();
            }
            else
            {
                // upper offdiag sits at A[LDA]
                A[LDA] = A[LDA] * (A[0] * A[LDA + 1]);
            }
    
            if(TRANS)
            {
                A[1] = A[LDA];
            }
        }
    }
    
    // Invert (upper) - recursive
    template <typename T,
              gpublas_int N,
              bool        UNIT_DIAG,
              bool        TRANS,
              std::enable_if_t<!equals_two<N>, gpublas_int> = 0>
    __device__ __forceinline__
    void gpublas_trsv_invert_upper(T* __restrict__ A, T* __restrict__ sx)
    {
        gpublas_trsv_invert_upper<T, N / 2, UNIT_DIAG, TRANS>(
            &A[(LDA + 1) * (N / 2)], sx);
        __syncthreads();
    
        // A12 solve
        gpublas_invert_solve_A12<T, N / 2, UNIT_DIAG>(
            A, &A[(N / 2) * LDA], &A[(LDA + 1) * (N / 2)], sx);
    
        if(TRANS)
        {
            __syncthreads();
            gpublas_trsv_transpose<T>(N / 2, &A[(N / 2) * LDA], &A[(N / 2)]);
        }
        __syncthreads();
    
        gpublas_trsv_invert_upper<T, N / 2, UNIT_DIAG, TRANS>(A, sx);
    }
    
    // Optional: block_solve using inverse (kept as-is but renamed)
    // NOTE: if you already have sum_sh / xprev_sh usage, integrate as needed.
    template <typename T, bool UPPER>
    __device__ __forceinline__
    void gpublas_trsv_block_solve_inverse(const T* __restrict__ Ainv,
                                          T* __restrict__       sx,
                                          T&                    val,
                                          T* __restrict__       sum)
    {
        Ainv += threadIdx.y * DIM_X + threadIdx.x;
        sx   += threadIdx.y;
    
        if(threadIdx.y == 0)
            sx[threadIdx.x] = val;
    
        __syncthreads();
    
        val = zero_val<T>();
        if(!UPPER)
        {
            for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
            {
                if(threadIdx.x >= threadIdx.y + i)
                    val += Ainv[i * DIM_X] * sx[i];
            }
            sum[threadIdx.y * DIM_X + threadIdx.x] = val;
            __syncthreads();
    
            if(threadIdx.y == 0)
            {
                for(gpublas_int i = 1; i < DIM_Y; i++)
                    val += sum[i * DIM_X + threadIdx.x];
            }
        }
        else
        {
            for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
            {
                if(threadIdx.x <= i + threadIdx.y)
                    val += Ainv[i * DIM_X] * sx[i];
            }
            sum[threadIdx.y * DIM_X + threadIdx.x] = val;
            __syncthreads();
    
            if(threadIdx.y == 0)
            {
                for(gpublas_int i = 1; i < DIM_Y; i++)
                    val += sum[i * DIM_X + threadIdx.x];
            }
        }
    }

template <typename T>
__device__ __forceinline__ T conj_if_needed(const T& v, bool conj)
{
    // real types: conj is no-op
    return v;
}

template <>
__device__ __forceinline__ gpublas_float_complex
conj_if_needed(const gpublas_float_complex& v, bool conj)
{
    if(!conj) return v;
    return {v.x, -v.y};
}

template <>
__device__ __forceinline__ gpublas_double_complex
conj_if_needed(const gpublas_double_complex& v, bool conj)
{
    if(!conj) return v;
    return {v.x, -v.y};
}

// op(A)(i,j)
template <typename T>
__device__ __forceinline__ T opA_ij(
    const T* __restrict__ A, gpublas_int lda,
    gpublas_int i, gpublas_int j,
    trsv_trans_t trans)
{
    if(trans == trsv_trans_t::none)
    {
        return A[i + j * lda];
    }
    else if(trans == trsv_trans_t::trans)
    {
        return A[j + i * lda];
    }
    else // conj_trans
    {
        T v = A[j + i * lda];
        return conj_if_needed(v, true);
    }
}
template <bool UNIT_DIAG, typename T>
__device__ void gpublas_trsv_block_solve_lower(const T* __restrict__ A,
                                                       gpublas_int bs,
                                                       T&          val)
{
    __shared__ T xs;

    // Iterate forwards
    for(gpublas_int i = 0; i < bs; i++)
    {
        // Solve current element
        if(threadIdx.x == i && threadIdx.y == 0)
        {
            if(!UNIT_DIAG)
                val *= A[i * bs + i];
            xs = val;
        }

        __syncthreads();

        // Update future elements with solved one
        if(threadIdx.x > i && threadIdx.y == 0)
        {
            val += A[i * bs + threadIdx.x] * xs;
        }

        __syncthreads();
    }
}

template <bool UNIT_DIAG, typename T>
__device__ void gpublas_trsv_block_solve_upper(const T* __restrict__ A,
                                                       gpublas_int bs,
                                                       T&          val)
{
    __shared__ T xs;

    for(gpublas_int i = bs - 1; i >= 0; i--)
    {
        // Solve current element
        if(threadIdx.x == i && threadIdx.y == 0)
        {
            if(!UNIT_DIAG)
                val *= A[i * bs + i];
            xs = val;
        }

        __syncthreads();

        // Update future elements with solved one
        if(threadIdx.x < i && threadIdx.y == 0 )
        {
            val += A[i * bs + threadIdx.x] * xs;
        }

        __syncthreads();
    }
}
#define INV_AFTER 5
// ------------------------------
// Generic blocked TRSV kernel
// NB=32, 2D threads (DIM_X=32 lanes x DIM_Y rows)
// ------------------------------
template <typename T, bool EFF_LOWER, bool UNIT_DIAG, trsv_trans_t TRANS>
__global__ void trsv_wavefront_kernel_2D(
    gpublas_int n,
    const T* __restrict__ A, gpublas_int lda, long strideA,
    T* __restrict__ x, gpublas_int incx, long stridex,
    gpublas_int num_blocks,
    gpublas_int* __restrict__ w_completed_sec,
    gpublas_int batch_count,   // ✅ newly added
    uint32_t epoch16
)
{
    constexpr gpublas_int NB = 32;
    constexpr bool DO_TRANS = (TRANS != trsv_trans_t::none);
    for(gpublas_int batch = blockIdx.z; batch < batch_count; batch += gridDim.z)
    {
    
    __shared__ T sAdiag[NB * NB];                 // cache for the diagonal block
    __shared__ T xprev_sh[NB];                    // cache of x from already-solved panels, used for the updates
    __shared__ T sum_sh[DIM_X * DIM_Y];           // blockwise reduction buffer
    T sAoff[DIM_X / DIM_Y];// off-diagonal block adjacent to the diagonal, held in registers; each thread owns 4 elements


    gpublas_int logical_i = blockIdx.x;// logical block index of this CTA
    const gpublas_int num_blocks = gridDim.x; // shadows the kernel parameter of the same name

    gpublas_int tx = threadIdx.x;
    gpublas_int ty = threadIdx.y;
    gpublas_int tid = ty * blockDim.x + tx;// linear thread ID within the CTA
    gpublas_int nthreads = blockDim.x * blockDim.y;// total threads per CTA

    auto blk_of = [&](gpublas_int logical_blk) {
        return EFF_LOWER ? logical_blk : (num_blocks - 1 - logical_blk);
    };// map a logical block index to the physical block index

    gpublas_int my_blk = blk_of(logical_i);
    gpublas_int row_start = my_blk * NB;// first row of this block
    gpublas_int row_end   = min(row_start + NB, n);// row end (exclusive)
    gpublas_int bs = row_end - row_start;// actual block size

    const T* A_batch = A + (long)batch * strideA;
    T* x_batch = x + (long)batch * stridex;
    const gpublas_int remainder = n % DIM_X;
    const bool row_is_remainder = ((n - 1) / DIM_X == my_blk && remainder != 0);
    const bool first_blk = EFF_LOWER ? my_blk == 0 : my_blk == num_blocks - 1;
    if(!first_blk)// preload the off-diagonal block adjacent to the diagonal into the registers sAoff
    {
        const gpublas_int block_col = EFF_LOWER ?my_blk - 1 :  my_blk + 1;// pick the adjacent block
        const gpublas_int local_col = DO_TRANS ? my_blk * DIM_X + tx : block_col * DIM_X + ty;// compute the matrix access position
        const gpublas_int local_row = DO_TRANS ? block_col * DIM_X + ty : my_blk * DIM_X + tx;
        const size_t      A_idx     = (local_row) + (local_col)*lda;// column-major access

        for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
        {
            const size_t i_idx = DO_TRANS ? i : i * lda;// step along rows (contiguous) or across columns

            __syncthreads();
            if(DO_TRANS ? (local_row + i < n && local_col < n)
                        : (local_row < n && local_col + i < n))// the last block may be smaller than DIM_X
                sAoff[i / DIM_Y] = A_batch[A_idx + i_idx];// each thread loads one vertical slice (A_batch: read this batch's matrix, not batch 0's)
            else
                sAoff[i / DIM_Y] = zero_val<T>();
        }
    }
#ifdef INV_AFTER
        bool cache_transpose = (DO_TRANS && EFF_LOWER && num_blocks - 1 - my_blk < INV_AFTER)
                               || (DO_TRANS && !EFF_LOWER && my_blk < INV_AFTER)
                               || (DO_TRANS && row_is_remainder);
#else
    bool cache_transpose = DO_TRANS; // works for ALL without inversion method
#endif
    // preprocess the diagonal block
    gpublas_int row = tx;// each thread owns one row
    for(gpublas_int i = 0; i < DIM_X; i += DIM_Y)
    {
        const gpublas_int col    = ty + i;// threads cooperate along y to cover the columns; this is the column owned by this thread
        const gpublas_int sA_idx = cache_transpose ? col + bs * row : col * bs + row;
        const size_t      A_idx
            = (my_blk * DIM_X * lda + my_blk * DIM_X) + col * lda + row;// A(my_blk * DIM_X + row, my_blk * DIM_X + col)
        const gpublas_int total_col = my_blk * DIM_X + col;// global coordinates of this element in the full matrix
        const gpublas_int total_row = my_blk * DIM_X + row;

        if(((row > col && EFF_LOWER) || (col > row && !EFF_LOWER)) && row < bs
                   && col < bs)//只允许访问真实存在的矩阵区域
        {
            sAdiag[sA_idx] = -opA_ij<T>(A_batch, lda, total_row, total_col, TRANS);
        }
        else if(!UNIT_DIAG && row == col && row < bs)
        {
            // Dividing here so we can just multiply later.
            sAdiag[sA_idx] = one_val<T>() / opA_ij<T>(A_batch, lda, total_row, total_col, TRANS);
        }
        else if(col < DIM_X
                && row < DIM_X) // In off-triangular portion or past end of remainder
        {
            sAdiag[sA_idx] = zero_val<T>();
        }
    }
    __syncthreads();

#ifdef INV_AFTER
    if(((my_blk >= INV_AFTER && EFF_LOWER)
        || (num_blocks - 1 - my_blk >= INV_AFTER && !EFF_LOWER))
       && !row_is_remainder)
    {
        if(EFF_LOWER)
            gpublas_trsv_invert<T, DIM_X, UNIT_DIAG, DO_TRANS>(sAdiag, sum_sh);
        else
            gpublas_trsv_invert_upper<T, DIM_X, UNIT_DIAG, DO_TRANS>(sAdiag, sum_sh);
    }
#endif
    __syncthreads();
    T val{}; // value-initialization -> zero ((0,0) for complex types)
    // val accumulates this row's running RHS value

    // -------------------------------
    // The ty == 0 threads load this block's b entries (negated) into registers
    if(ty == 0 && tx < bs)
        val = -x_batch[(row_start + tx) * incx];
    __syncthreads();


    // -------------------------------
    // Main loop: apply each earlier panel's update, then solve our own block and publish progress
    uint32_t* flags = reinterpret_cast<uint32_t*>(w_completed_sec);

    for(gpublas_int p = 0; p < logical_i; ++p)
    {
        if(tid == 0)
        {
            // Wait until "same epoch and comp >= p".
            // On an epoch mismatch comp reads as -1, so stale flags left over
            // from a previous call can never satisfy the wait.
            while(true)
            {
                uint32_t v = __atomic_load_n(&flags[batch], __ATOMIC_ACQUIRE);
                int comp = get_comp_or_minus1(v, epoch16);
                if(comp >= p) break;

                // Optional back-off to lower L2 pressure (AMD)
                __builtin_amdgcn_s_sleep(1);
            }
        }
        __syncthreads();

        // Strictly speaking this fence could be removed (the acquire load above
        // already provides the ordering), but it is kept as a conservative,
        // minimal change:
        __threadfence();
        __syncthreads();

        // Locate panel p's physical block position and size
        gpublas_int p_blk = blk_of(p);
        gpublas_int p_row_start = p_blk * NB;
        gpublas_int p_row_end   = min(p_row_start + NB, n);
        gpublas_int p_bs = p_row_end - p_row_start;

        // Load the already-solved x(panel) into shared memory
        if(tid < p_bs)
            xprev_sh[tid] = x_batch[(p_row_start + tid) * incx];
        __syncthreads();

        // Apply the panel update: val was seeded with -b, so we accumulate
        // +op(A)*x_prev here and negate once after the reduction below.
        if(tx < bs)
        {
            gpublas_int row = row_start + tx;
            // cached: when panel p is the block directly adjacent to our diagonal
            // block (my_blk - 1 forward / my_blk + 1 backward), its coefficients
            // were prefetched into sAoff and need not be re-read from global memory.
            const bool cached
                = !first_blk
                  && (EFF_LOWER ? p_blk == my_blk - 1 : p_blk == my_blk + 1);

            for(gpublas_int j_local = ty; j_local < p_bs; j_local += DIM_Y)
            {
                gpublas_int col = p_row_start + j_local;
                auto a = cached ? sAoff[(j_local - ty) / DIM_Y]
                                : opA_ij<T>(A_batch, lda, row, col, TRANS);
                val += a * xprev_sh[j_local]; // accumulate A_ij * x_j from the solved panel
            }
        }
    }
    // Park this thread's partial sum, then reduce across ty
    sum_sh[ty * DIM_X + tx] = val;
    __syncthreads();

    if(ty == 0 && tx < bs)
    {
#pragma unroll
        for(gpublas_int r = 1; r < DIM_Y; ++r)
            val += sum_sh[r * DIM_X + tx];
        val = -val; // restore the sign: val now holds b - sum(A * x_solved) for this row
    }
    __syncthreads();
#ifdef INV_AFTER
    if(((my_blk >= INV_AFTER && EFF_LOWER)
        || (num_blocks - 1 - my_blk >= INV_AFTER && !EFF_LOWER))
       && !row_is_remainder)
    {
        gpublas_trsv_block_solve_inverse<T, !EFF_LOWER>(sAdiag, xprev_sh, val, sum_sh);

        if(ty == 0 && tx < bs)
            x_batch[(row_start + tx) * incx] = val;
    }
    else // same as without inversion
    {
        // Solve the diagonal block
        if(!EFF_LOWER)
            gpublas_trsv_block_solve_upper<UNIT_DIAG>(sAdiag, bs, val);
        else
            gpublas_trsv_block_solve_lower<UNIT_DIAG>(sAdiag, bs, val);

        // Store the solved values into x
        if(ty == 0 && tx < bs)
            x_batch[(row_start + tx) * incx] = val;
    }
#else
    // Solve the diagonal block
    if(!EFF_LOWER)
        gpublas_trsv_block_solve_upper<UNIT_DIAG>(sAdiag, bs, val);
    else
        gpublas_trsv_block_solve_lower<UNIT_DIAG>(sAdiag, bs, val);

    // Write this block's solution back to global x
    if(ty == 0 && tx < bs)
        x_batch[(row_start + tx) * incx] = val;
#endif

    // Publish progress: make all writes to x visible before updating the flag
    __threadfence();
    __syncthreads(); // kept from the original

    if(tid == 0)
    {
        // Read the current comp (epoch mismatch -> -1)
        uint32_t oldv = __atomic_load_n(&flags[batch], __ATOMIC_RELAXED);
        int comp = get_comp_or_minus1(oldv, epoch16);

        // Same increment as before: -1 -> 0 -> 1 ...
        uint32_t new_comp = (uint32_t)(comp + 1);

        // Release store: publish this epoch's progress
        __atomic_store_n(&flags[batch],
                         pack_flag(epoch16, new_comp),
                         __ATOMIC_RELEASE);
    }
    }
        
}
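
// For reference, one plausible shape of the epoch helpers used above. Their
// real definitions live earlier in this file; treat this as an illustration
// of the assumed bit layout, not the authoritative code:
//
//   __host__ __device__ inline uint32_t pack_flag(uint32_t epoch16, uint32_t comp)
//   {
//       return (epoch16 << 16) | (comp & 0xFFFFu); // high 16 bits: epoch, low 16: progress
//   }
//
//   __device__ inline int get_comp_or_minus1(uint32_t v, uint32_t epoch16)
//   {
//       return ((v >> 16) == epoch16) ? (int)(v & 0xFFFFu) : -1; // stale epoch -> -1
//   }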

template <typename T>
gpublas_status trsv_blocked_launcher_2D(
    gpublas_handle handle,
    gpublas_fill uplo,
    gpublas_operation transA,
    gpublas_diagonal diag,
    gpublas_int n,
    const T* A, gpublas_int lda, gpublas_stride strideA,
    T* x, gpublas_int incx, gpublas_stride stride_x,
    gpublas_int batch_count)
{
    constexpr gpublas_int NB = 32;
    gpublas_int num_blocks = (n + NB - 1) / NB;

    hipStream_t stream = handle->get_stream();

    // ================================
    // Reuse dependency flags (scheme A: handle-owned workspace)
    // ================================

    auto st = handle->ensure_trsv_flags((size_t)batch_count);
    if(st != gpublas_status_success) return st;

    gpublas_int* w_completed_sec = handle->get_trsv_flags();

    // Grab this call's 16-bit epoch tag. A mismatched epoch makes stale flag
    // values read as comp == -1 in the kernel, so no per-call memset is needed.
    // (Wrap-around after 2^16 calls is assumed benign: no kernel that old is still in flight.)
    uint32_t epoch16 = handle->next_trsv_epoch16();
    

    // ================================
    // Kernel launch
    // ================================
    dim3 threads(DIM_X, DIM_Y, 1);

    int batches = std::min(batch_count, 64); // clamp grid.z (original heuristic kept unchanged)
    dim3 grid(num_blocks, 1, batches);       // the kernel's grid-stride loop covers batches beyond grid.z

    trsv_trans_t tmode;
    if(transA == gpublas_operation_none) tmode = trsv_trans_t::none;
    else if(transA == gpublas_operation_transpose) tmode = trsv_trans_t::trans;
    else tmode = trsv_trans_t::conj_trans;

    bool eff_lower = (tmode == trsv_trans_t::none)
                         ? (uplo == gpublas_fill_lower)
                         : (uplo == gpublas_fill_upper);

    bool unit_diag = (diag == gpublas_diagonal_unit);

    auto launch = [&](auto effLowerTag, auto unitTag, auto transTag)
    {
        constexpr bool EFF_LOWER = decltype(effLowerTag)::value;
        constexpr bool UNIT_DIAG = decltype(unitTag)::value;
        constexpr trsv_trans_t TRANS = decltype(transTag)::value;

        hipLaunchKernelGGL(
            (trsv_wavefront_kernel_2D<T, EFF_LOWER, UNIT_DIAG, TRANS>),
            grid, threads, 0, stream,
            (gpublas_int)n,
            A, (gpublas_int)lda, (long)strideA,
            x, (gpublas_int)incx, (long)stride_x,
            (gpublas_int)num_blocks,
            w_completed_sec, (gpublas_int)batch_count,
            epoch16 // new: epoch tag for the reusable flags
        );
    };

    // ===== dispatch =====
    if(tmode == trsv_trans_t::none)
    {
        if(eff_lower)
        {
            if(unit_diag) launch(std::true_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
            else launch(std::true_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
        }
        else
        {
            if(unit_diag) launch(std::false_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
            else launch(std::false_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::none>{});
        }
    }
    else if(tmode == trsv_trans_t::trans)
    {
        if(eff_lower)
        {
            if(unit_diag) launch(std::true_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
            else launch(std::true_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
        }
        else
        {
            if(unit_diag) launch(std::false_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
            else launch(std::false_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::trans>{});
        }
    }
    else
    {
        if(eff_lower)
        {
            if(unit_diag) launch(std::true_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
            else launch(std::true_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
        }
        else
        {
            if(unit_diag) launch(std::false_type{}, std::true_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
            else launch(std::false_type{}, std::false_type{},
                std::integral_constant<trsv_trans_t, trsv_trans_t::conj_trans>{});
        }
    }

    return gpublas_status_success;
}
} // namespace internal

// ================================================================
// Public C API
// ================================================================

#define DEFINE_TRSV_INTERFACE(name, type)                                      \
gpublas_status name(                                                            \
    gpublas_handle handle,                                                      \
    gpublas_fill uplo,                                                          \
    gpublas_operation trans,                                                    \
    gpublas_diagonal diag,                                                      \
    gpublas_int n,                                                              \
    const type* A,                                                              \
    gpublas_int lda,                                                            \
    gpublas_stride strideA,                                                     \
    type* x,                                                                    \
    gpublas_int incx,                                                           \
    gpublas_stride stridex,                                                     \
    gpublas_int batch_count)                                                    \
{                                                                               \
    return internal::trsv_blocked_launcher_2D<type>(                            \
        handle, uplo, trans, diag,                                             \
        n, A, lda, strideA,                                                    \
        x, incx, stridex, batch_count);                                        \
}

DEFINE_TRSV_INTERFACE(gpublas_strsv_strided_batched, float)
DEFINE_TRSV_INTERFACE(gpublas_dtrsv_strided_batched, double)
DEFINE_TRSV_INTERFACE(gpublas_ctrsv_strided_batched, gpublas_float_complex)
DEFINE_TRSV_INTERFACE(gpublas_ztrsv_strided_batched, gpublas_double_complex)

#undef DEFINE_TRSV_INTERFACE
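
The handle-side workspace API used by the launcher (`ensure_trsv_flags`, `get_trsv_flags`, `next_trsv_epoch16`) lives in `handle.hpp` and is not shown above. Below is a minimal sketch of what those members could look like, assuming a grow-only allocation policy and a plain 16-bit epoch counter; the member names match the calls above, but the struct name, field names, and error handling are placeholders:

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

// Illustrative sketch only -- the real gpublas_handle is defined in handle.hpp.
struct trsv_workspace_sketch
{
    uint32_t* trsv_flags     = nullptr; // one 32-bit flag per batch (the real handle
                                        // stores gpublas_int*, reinterpreted in the kernel)
    size_t    trsv_flags_cap = 0;       // current capacity, in flags
    uint32_t  trsv_epoch     = 0;       // monotonically increasing; low 16 bits handed out

    // Grow-only allocation: hipMalloc/hipFree cost O(#resizes), not O(#calls).
    hipError_t ensure_trsv_flags(size_t batch_count)
    {
        if(batch_count <= trsv_flags_cap)
            return hipSuccess; // fast path: reuse the existing buffer
        if(trsv_flags)
            (void)hipFree(trsv_flags);
        hipError_t err = hipMalloc(reinterpret_cast<void**>(&trsv_flags),
                                   batch_count * sizeof(uint32_t));
        if(err != hipSuccess)
            return err;
        // Zero-fill once per (re)allocation. Epoch 0 is never handed out below,
        // so an all-zero flag always reads as an epoch mismatch -> comp == -1,
        // and no per-call memset is ever needed.
        err = hipMemset(trsv_flags, 0, batch_count * sizeof(uint32_t));
        trsv_flags_cap = (err == hipSuccess) ? batch_count : 0;
        return err;
    }

    uint32_t* get_trsv_flags() { return trsv_flags; }

    // A fresh 16-bit epoch per call; 0 is reserved for "freshly zeroed".
    uint32_t next_trsv_epoch16()
    {
        uint32_t e = (++trsv_epoch) & 0xFFFFu;
        if(e == 0)
            e = (++trsv_epoch) & 0xFFFFu; // skip the reserved zero epoch
        return e;
    }
};
```

With a shape like this, a steady-state workload (stable `batch_count`) pays for `hipMalloc` zero times per call, and the only per-call device work is the solver kernel launch itself.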

Run results

(The rows flagged `wrong` are the F32/C32 cases at n >= 512; their max diffs grow smoothly with n while every F64/C64 case passes, which suggests single-precision accumulation error against a fixed comparison threshold rather than a synchronization bug.)

| Type | n | gpublas_time (ms) | rocblas_time (ms) | ratio | Result | max diff |
|------|---|-------------------|-------------------|-------|--------|----------|
| F32 | 16 | 0.0384554 | 0.0492736 | 1.28132 | ok | 0.000000e+00 |
| F32 | 32 | 0.043688 | 0.0501442 | 1.14778 | ok | 0.000000e+00 |
| F32 | 64 | 0.0521231 | 0.0529281 | 1.01545 | ok | 9.834766e-07 |
| F32 | 128 | 0.0701885 | 0.0688272 | 0.980604 | ok | 6.258488e-06 |
| F32 | 256 | 0.0887099 | 0.100734 | 1.13555 | ok | 8.344650e-06 |
| F32 | 512 | 0.11446 | 0.135409 | 1.18303 | wrong | 3.977120e-05 |
| F32 | 1024 | 0.165897 | 0.153069 | 0.922673 | wrong | 7.258356e-04 |
| F32 | 2048 | 0.267934 | 0.201444 | 0.751843 | wrong | 6.179182e-02 |
| F64 | 16 | 0.0400395 | 0.0521118 | 1.30151 | ok | 0.000000e+00 |
| F64 | 32 | 0.0429384 | 0.0535718 | 1.24764 | ok | 0.000000e+00 |
| F64 | 64 | 0.0536112 | 0.0531339 | 0.991097 | ok | 1.831868e-15 |
| F64 | 128 | 0.0741696 | 0.071985 | 0.970546 | ok | 6.994405e-15 |
| F64 | 256 | 0.0953009 | 0.109126 | 1.14507 | ok | 2.009504e-14 |
| F64 | 512 | 0.123693 | 0.146716 | 1.18613 | ok | 8.615331e-14 |
| F64 | 1024 | 0.178045 | 0.168268 | 0.945085 | ok | 1.481038e-12 |
| F64 | 2048 | 0.290869 | 0.227173 | 0.781013 | ok | 8.114365e-11 |
| C32 | 16 | 0.0396814 | 0.0566436 | 1.42746 | ok | 5.960464e-08 |
| C32 | 32 | 0.044774 | 0.0569825 | 1.27267 | ok | 2.235174e-07 |
| C32 | 64 | 0.0556244 | 0.0577805 | 1.03876 | ok | 9.834766e-07 |
| C32 | 128 | 0.0802432 | 0.0759247 | 0.946183 | ok | 4.529953e-06 |
| C32 | 256 | 0.104033 | 0.117173 | 1.12631 | ok | 8.881092e-06 |
| C32 | 512 | 0.137956 | 0.164481 | 1.19227 | wrong | 6.079674e-05 |
| C32 | 1024 | 0.220238 | 0.19721 | 0.89544 | wrong | 8.689165e-04 |
| C32 | 2048 | 0.397301 | 0.259538 | 0.653252 | wrong | 8.624413e-02 |
| C64 | 16 | 0.0425746 | 0.0461916 | 1.08496 | ok | 1.110223e-16 |
| C64 | 32 | 0.0479257 | 0.047647 | 0.994185 | ok | 4.163336e-16 |
| C64 | 64 | 0.0607162 | 0.0612375 | 1.00859 | ok | 1.887379e-15 |
| C64 | 128 | 0.0902493 | 0.0853255 | 0.945442 | ok | 8.104628e-15 |
| C64 | 256 | 0.11969 | 0.111169 | 0.928813 | ok | 1.676437e-14 |
| C64 | 512 | 0.162828 | 0.144605 | 0.888086 | ok | 1.048051e-13 |
| C64 | 1024 | 0.267098 | 0.206569 | 0.773385 | ok | 1.499023e-12 |
| C64 | 2048 | 0.471593 | 0.332993 | 0.706102 | ok | 1.528287e-10 |
