C = alpha*A*B + beta*C
每个线程负责计算 C 的一个元素 C(i, j)。
1, gemm nn
// dgemm_NN: C = alpha*A*B + beta*C, one thread per output element C(i, j).
//
// Indexing used by this kernel (column-major with leading dimensions):
//   A(i, k) is read from  Ad[i + k*lda]
//   B(k, j) is read from  Bd[k + j*ldb]
//   C(i, j) is read from / written to  Cd[i + j*ldc]
//
// Thread mapping: i comes from the x dimension of the launch, j from the y
// dimension. The launch must cover at least M threads in x and N threads in y;
// threads that fall outside the M x N output do nothing.
//
// Parameters:
//   M, N   - number of rows / columns of C that are computed
//   K      - length of the inner (reduction) dimension
//   Ad,lda - matrix A and its leading dimension
//   Bd,ldb - matrix B and its leading dimension
//   Cd,ldc - matrix C (updated in place) and its leading dimension
//   alpha  - scale applied to the product A*B
//   beta   - scale applied to the previous contents of C
__global__
void dgemm_NN( lint M,
               lint N,
               lint K,
               double* Ad,
               lint lda,
               double* Bd,
               lint ldb,
               double* Cd,
               lint ldc,
               double alpha,
               double beta )
{
    lint i = blockIdx.x * blockDim.x + threadIdx.x;
    lint j = blockIdx.y * blockDim.y + threadIdx.y;

    // The grid rarely divides M x N evenly: without this guard the surplus
    // threads would read and write past the ends of A, B and C.
    if (i < M && j < N) {
        double sigma = 0.0;

        // Dot product of row i of A with column j of B.
        for (lint k = 0; k < K; k++)
            sigma += Ad[i + k*lda] * Bd[k + j*ldb];

        Cd[i + j*ldc] = alpha*sigma + beta*Cd[i + j*ldc];
    }
}
2, gemm nt: C = alpha*A*B^T + beta*C
// dgemm_NT: C = alpha*A*B^T + beta*C, one thread per output element C(i, j).
//
// Indexing used by this kernel (column-major with leading dimensions):
//   A(i, k) is read from  Ad[i + k*lda]
//   the second operand is read from  Bd[j + k*ldb], i.e. element (j, k) of
//   the stored B, which is element (k, j) of B^T
//   C(i, j) is read from / written to  Cd[i + j*ldc]
//
// Thread mapping: the row index comes from the x dimension of the launch, the
// column index from the y dimension. Threads outside the M x N output return
// without touching memory.
//
// Parameters:
//   opA, opB - accepted but not read by this kernel
//   M, N     - number of rows / columns of C that are computed
//   K        - length of the inner (reduction) dimension
//   Ad,lda   - matrix A and its leading dimension
//   Bd,ldb   - stored matrix B and its leading dimension
//   Cd,ldc   - matrix C (updated in place) and its leading dimension
//   alpha    - scale applied to the product A*B^T
//   beta     - scale applied to the previous contents of C
__global__
void dgemm_NT( int opA,
               int opB,
               lint M,
               lint N,
               lint K,
               double* Ad,
               lint lda,
               double* Bd,
               lint ldb,
               double* Cd,
               lint ldc,
               double alpha,
               double beta )
{
    const lint row = blockIdx.x * blockDim.x + threadIdx.x;
    const lint col = blockIdx.y * blockDim.y + threadIdx.y;

    // Surplus threads at the edge of the grid have no element to compute.
    if (row >= M || col >= N)
        return;

    // Running offsets: aOff == row + p*lda and bOff == p*ldb + col at step p.
    lint aOff = row;
    lint bOff = col;
    double acc = 0.0;
    for (lint p = 0; p < K; ++p) {
        acc += Ad[aOff] * Bd[bOff];
        aOff += lda;
        bOff += ldb;
    }

    double* out = Cd + (row + col*ldc);
    *out = alpha*acc + beta*(*out);
}