CUDA GEMM:从基础写法到工程优化
SGEMM 要做的事很简单:
C = alpha * A * B + beta * CA是M x KB是K x NC是M x N
默认约定:
- 矩阵按
row-major存储 - 后面的高性能版本主要面向大矩阵 benchmark,很多 kernel 默认
M/N按 block tile 对齐、K按BK对齐、N满足float4对齐 - 真正工程里如果要通吃任意尺寸,还要补边界判断和尾块处理
参数记号:
BM / BN / BK:thread block 级别 tileTM / TN:每个线程在寄存器里负责的输出 tileWM / WN:每个 warp 负责的输出 tile
优化主线:线程映射 → shared memory → register tiling → vectorize → 解决 bank conflict → autotune → warp tiling → double buffering → async copy
1. 最基础版本:一个线程算一个输出元素
__global__ void sgemm_naive(int M, int N, int K, float alpha,
const float *A, const float *B,
float beta, float *C) {
int tidx = blockIdx.x * blockDim.x + threadIdx.x;
int tidy = blockIdx.y * blockDim.y + threadIdx.y;
if (tidx < M && tidy < N) {
float tmp = 0.0f;
for (int i = 0; i < K; ++i) {
tmp += A[tidx * K + i] * B[i * N + tidy];
}
C[tidx * N + tidy] = alpha * tmp + beta * C[tidx * N + tidy];
}
}这版把 tidx 当行、tidy 当列,不符合 GPU 惯例(通常 x 对应列),这正是导致访存问题的原因
问题也很直接:
- 每次乘加都直接读 global memory,没有数据复用
A[tidx * K + i]:warp 内tidx连续,访问A的不同行(步长为K),不连续,无法合并访问(non-coalesced)B[i * N + tidy]:tidy由threadIdx.y决定,同一 warp 内tidy相同,多个线程读取同一地址,虽然会广播但浪费带宽C[tidx * N + tidy]:同理,同一 warp 的线程访问不同行,不连续- 算力很低,大部分时间都耗在访存上
为什么同一 warp 内
tidy相同? 2D block 内线程的线性编号是threadIdx.y * blockDim.x + threadIdx.x,warp 按线性编号连续取 32 个线程。当blockDim.x >= 32时,一个 warp 内threadIdx.x从 0 跑到 31,但threadIdx.y是同一个值,所以tidy也相同假设 blockDim.x = 32: warp 0: threadIdx.y=0, threadIdx.x=0~31 warp 1: threadIdx.y=1, threadIdx.x=0~31 warp 2: threadIdx.y=2, threadIdx.x=0~31 ...
2. 调整线程映射:先把 global memory 访问做顺
template <const uint BLOCKSIZE>
__global__ void sgemm_global_mem_coalesce(...) {
const int cRow = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
const int cCol = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);
float tmp = 0.0f;
for (int i = 0; i < K; ++i) {
tmp += A[cRow * K + i] * B[i * N + cCol];
}
C[cRow * N + cCol] = alpha * tmp + beta * C[cRow * N + cCol];
}核心变化不是算法,而是线程到输出元素的映射
启动时用的是 1D 线程块,每个 block 有 BLOCKSIZE * BLOCKSIZE 个线程:
dim3 block(BLOCKSIZE * BLOCKSIZE); // 比如 32*32 = 1024 个线程,只有 threadIdx.x
dim3 grid(CEIL_DIV(M, BLOCKSIZE), CEIL_DIV(N, BLOCKSIZE));每个线程只有 threadIdx.x(范围 0 ~ BLOCKSIZE*BLOCKSIZE-1),通过 / % 手动拆出行列:
threadIdx.x / BLOCKSIZE→ 行号(同一 warp 内不变)threadIdx.x % BLOCKSIZE→ 列号(同一 warp 内连续递增)
以 BLOCKSIZE=32(刚好等于 warp 大小)为例:
warp 0 — threadIdx.x: 0 1 2 3 ... 31
cRow = /32: 0 0 0 0 ... 0 ← 全相同
cCol = %32: 0 1 2 3 ... 31 ← 连续递增
这样 warp 内线程变成”同一行、连续列”:
- 读
B[i * N + cCol]:cCol 连续 → 地址连续 → coalesced access - 读
A[cRow * K + i]:cRow 相同、i 也相同 → 32 个线程算出完全一样的地址 → 硬件只做一次读取,广播给整个 warp - 写
C[cRow * N + cCol]:cCol 连续 → coalesced write
如果
BLOCKSIZE不等于 32(比如 16),一个 warp 里会包含多个 cRow,读 A 时不再是单一地址,但 cCol 仍然连续,对 B 和 C 的 coalesced access 依然成立 收益:
-
global memory access 更接近 coalesced- 不改变计算逻辑,代价很小 问题:
-
还是没有 shared memory,
A/B仍然被重复从 global memory 读取
3. Shared Memory Blocking:先把 block tile 缓到片上
template <const int BLOCKSIZE>
__global__ void sgemm_shared_mem_block(...) {
__shared__ float As[BLOCKSIZE * BLOCKSIZE];
__shared__ float Bs[BLOCKSIZE * BLOCKSIZE];
const uint threadCol = threadIdx.x % BLOCKSIZE;
const uint threadRow = threadIdx.x / BLOCKSIZE;
A += cRow * BLOCKSIZE * K;
B += cCol * BLOCKSIZE;
C += cRow * BLOCKSIZE * N + cCol * BLOCKSIZE;
float tmp = 0.0f;
for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) {
As[threadRow * BLOCKSIZE + threadCol] = A[threadRow * K + threadCol];
Bs[threadRow * BLOCKSIZE + threadCol] = B[threadRow * N + threadCol];
__syncthreads();
for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) {
tmp += As[threadRow * BLOCKSIZE + dotIdx] *
Bs[dotIdx * BLOCKSIZE + threadCol];
}
__syncthreads();
A += BLOCKSIZE;
B += BLOCKSIZE * N;
}
}这一步开始进入 GEMM 的标准写法
-
一个 block 负责一个
BLOCKSIZE x BLOCKSIZE的Ctile- 每轮从 global memory 读一块A_tile(BLOCKSIZE x BLOCKSIZE)和B_tile(BLOCKSIZE x BLOCKSIZE)到 shared memory- tile 进 shared memory 后,block 内线程反复复用,不再每次乘加都回 global memory 取数 收益: -
数据复用第一次真正建立起来- 访存带宽压力明显下降 问题:
-
一个线程还是只算一个输出元素,寄存器复用不够-
BLOCKSIZE=32时一个 block 就是1024线程,线程组织仍然偏粗
4. 1D Block Tiling:一个线程开始算多个结果
template <const int BM, const int BN, const int BK, const int TM>
__global__ void sgemm1DBlocktiling(...) {
float threadResults[TM] = {0.0f};
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB];
__syncthreads();
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
float tmpB = Bs[dotIdx * BN + threadCol];
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
threadResults[resIdx] +=
As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB;
}
}
__syncthreads();
}
}这里开始做 register tiling
-
block tile 变成
BM x BN,源码里的默认参数是BM=64, BN=64, BK=8, TM=8- 一个线程不再只算一个C元素,而是负责同一列上的TM个输出-B的一个值先读进寄存器tmpB,再和TM个A值做乘加 收益: -
一个线程做更多计算,访存和计算的比例更平衡- 寄存器开始承担“线程私有缓存”的角色 注意:
-
源码里这里把
blockIdx.x / y的含义翻过来了,目的是让相邻 block 访问B时更连续,L2 locality 更好
5. 2D Block Tiling:一个线程算 TM x TN 个结果
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemm2DBlocktiling(...) {
float threadResults[TM * TN] = {0.0f};
float regM[TM] = {0.0f};
float regN[TN] = {0.0f};
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// GMEM -> SMEM
...
__syncthreads();
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
for (uint i = 0; i < TM; ++i) {
regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
}
for (uint i = 0; i < TN; ++i) {
regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
}
for (uint mi = 0; mi < TM; ++mi) {
for (uint ni = 0; ni < TN; ++ni) {
threadResults[mi * TN + ni] += regM[mi] * regN[ni];
}
}
}
__syncthreads();
}
}这是一个很关键的版本
-
每个线程负责一个
TM x TN的小块,而不是一列或一个点-A的TM个值进regM,B的TN个值进regN- 真正的乘加发生在寄存器里,本质上已经是一个小型 micro-kernel 收益: -
shared memory 的数据能被寄存器层再次复用- 单线程计算密度明显上升 源码默认参数:
-
大矩阵:
BM=128, BN=128, BK=8, TM=8, TN=8 -
小矩阵:退回
64 x 64tile,主要是因为这版没有认真做边界处理
6. 向量化:用 float4 提高搬运效率
float4 tmp =
reinterpret_cast<float4 *>(&A[innerRowA * K + innerColA * 4])[0];
As[(innerColA * 4 + 0) * BM + innerRowA] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA] = tmp.w;
reinterpret_cast<float4 *>(&Bs[innerRowB * BN + innerColB * 4])[0] =
reinterpret_cast<float4 *>(&B[innerRowB * N + innerColB * 4])[0];核心思路:
-
global memory 到 shared memory 的加载改成
128bit宽度,也就是一次搬4个float- 写C回 global memory 时也按float4向量化- 加载A时顺手做了一次转置,把As布局改成BK x BM,后面按dotIdx读取时更顺 收益: -
降低了指令数- 更容易吃满内存带宽 前提:
-
地址要满足
float4对齐- 尺寸最好是4的倍数
7. 处理 Shared Memory Bank Conflict:手动重排 B tile
tmp = reinterpret_cast<float4 *>(&B[innerRowB * N + innerColB * 4])[0];
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 0) * 16 + innerColB / 2] = tmp.x;
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 1) * 16 + innerColB / 2] = tmp.y;
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 2) * 16 + innerColB / 2] = tmp.z;
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 3) * 16 + innerColB / 2] = tmp.w;
for (uint i = 0; i < TN; ++i) {
regN[i] = Bs[(dotIdx * 8 + i) * 16 + threadCol];
}做到这一步以后,瓶颈已经不只是 global memory 了,shared memory 自己也会出问题
-
线程按固定 stride 读
Bs时,可能让一个 warp 里的很多线程撞到同一个 bank- 这版不是简单地按二维数组存Bs,而是把它重排成更适合 warp 读取的线性布局- 读取regN时,新的索引公式专门对应这个重排后的布局 收益: -
减少 shared memory bank conflict- 让寄存器 micro-kernel 吃到更稳定的供数速度 代价:
-
代码可读性明显下降
8. 处理 Shared Memory Bank Conflict:给 B tile 补 padding
const int extraCols = 5;
__shared__ float Bs[BK * (BN + extraCols)];
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 0] = tmp.x;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 1] = tmp.y;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 2] = tmp.z;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 3] = tmp.w;
regN[i] = Bs[dotIdx * (BN + extraCols) + threadCol * TN + i];这版解决的是同一个问题,但写法更常见
- 不再手动做复杂 swizzle- 直接给
Bs的 leading dimension 补几列 padding,打破容易冲突的 stride- 访问公式基本不变,可维护性更高 实际工程里,这种 padding 写法通常更容易保留
9. 参数调优:把 tile 形状和线程数压到更合适的位置
const int K9_NUM_THREADS = 256;
template <const int BM, const int BN, const int BK,
const int TM, const int TN>
__global__ void __launch_bounds__(K9_NUM_THREADS) sgemmAutotuned(...) {
constexpr int WM = TM * 16;
constexpr int WN = TN * 16;
constexpr int WMITER = CEIL_DIV(BM, WM);
constexpr int WNITER = CEIL_DIV(BN, WN);
float threadResults[WMITER * WNITER * TM * TN] = {0.0f};
...
}这版本质上不是“新算法”,而是“给前面的 2D register tiling 选更合适的参数” 源码里默认参数是:
NUM_THREADS = 256BM = 128BN = 128BK = 16TM = 8TN = 8
同时配了很多 static_assert:
- 保证
float4向量化加载不会出现量化残块- 保证 thread tile、block tile、thread 数之间正好整除- 保证编译器能按预期展开 这一类 kernel 的重点不是公式变了,而是“形状选对以后,寄存器、occupancy、shared memory 容量、向量化宽度”能同时落在一个比较舒服的点上
10. Warp Tiling:以 warp 为调度单位组织计算
const uint warp_id = threadIdx.x / WARPSIZE;
const uint warpCol = warp_id % (BN / WN);
const uint warpRow = warp_id / (BN / WN);
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
constexpr uint WSUBM = WM / WMITER;
constexpr uint WSUBN = WN / WNITER;
const uint lane_id = threadIdx.x % WARPSIZE;
const uint threadColInWarp = lane_id % (WSUBN / TN);
const uint threadRowInWarp = lane_id / (WSUBN / TN);前面的版本虽然已经在做 register tiling,但还没有显式把 warp 当成一级结构来设计。这里开始真正进入高性能 GEMM 的主流组织方式
-
一个 thread block 先切成多个 warp- 每个 warp 负责一个
WM x WN的 warp tile- warp tile 还会继续切成更小的WSUBM x WSUBN子块,让 32 个线程的职责分配更均匀 配套的两个核心函数也很清楚: -
loadFromGmem:负责把A/Btile 从 global memory 搬到 shared memory-processFromSmem:负责从 shared memory 取数到寄存器,然后做 warp 级 micro-kernel 收益: -
计算组织更贴近 GPU 的真实调度粒度- warp 内寄存器复用和 shared memory 读取模式都更容易做细
11. Double Buffering:一边算当前 tile,一边准备下一个 tile
__shared__ float As[2 * BM * BK];
__shared__ float Bs[2 * BK * BN];
bool doubleBufferIdx = threadIdx.x >= (NUM_THREADS / 2);
if (doubleBufferIdx == 0) {
loadFromGmem(..., As, Bs, ...);
}
__syncthreads();
for (uint bkIdx = 0; bkIdx < K; bkIdx += 2 * BK) {
if (doubleBufferIdx == 0) {
processFromSmem(..., As, Bs, ...);
...
loadFromGmem(..., As, Bs, ...);
} else {
loadFromGmem(..., As + BM * BK, Bs + BK * BN, ...);
...
processFromSmem(..., As + BM * BK, Bs + BK * BN, ...);
}
}这版开始做 ping-pong buffer
-
As/Bs不再只开一份,而是开两份- 一部分 warp 在算当前 tile,另一部分 warp 去搬下一块数据- 下一轮直接切换到另一块 shared memory buffer 重点不是“少一次同步”,而是尽量把“取数”和“计算”重叠起来,减少流水线空转 这也是现代 GEMM 的一个核心思想: -
没有 double buffering 时,常见模式是“load 完再算,算完再 load”- 有了 double buffering,就能把这两个阶段做成流水
12. Async Copy Double Buffering:把搬运进一步做成异步流水
__shared__ cuda::barrier<cuda::thread_scope_block> frontBarrier;
__shared__ cuda::barrier<cuda::thread_scope_block> backBarrier;
loadFromGmem(..., As + As_offset * BM * BK,
Bs + Bs_offset * BK * BN, ..., *frontBarrierPtr);
for (uint bkIdx = 0; bkIdx < K - BK; bkIdx += BK) {
loadFromGmem(..., As + (1 - As_offset) * BM * BK,
Bs + (1 - Bs_offset) * BK * BN, ..., *backBarrierPtr);
(*frontBarrierPtr).arrive_and_wait();
processFromSmem(..., As + As_offset * BM * BK,
Bs + Bs_offset * BK * BN, ...);
As_offset = 1 - As_offset;
Bs_offset = 1 - Bs_offset;
swap(frontBarrierPtr, backBarrierPtr);
}这版是在上一版的基础上继续前进
-
global memory 到 shared memory 的搬运改成
cuda::memcpy_async- 用cuda::barrier管理前后两个 buffer 的就绪状态- 当前 tile 在计算时,下一块数据已经在后台往 shared memory 里搬loadFromGmem里的两个细节值得记一下: -
A因为要转置存到 shared memory,异步拷贝是按标量拆开的-B保持顺序布局,可以直接按float4异步拷贝 这一版更接近现代 CUDA 在 Ampere 及之后架构上的 GEMM pipeline 写法
小结
naive -> coalesced:先把 global memory 访问顺序理顺-shared memory blocking:GEMM 的第一道真正分水岭-1D/2D register tiling:性能开始明显上台阶-vectorize + bank conflict:把供数路径细化-warp tiling + double buffering:接近高性能 GEMM 的标准模板