CUDA GEMM:从基础写法到工程优化

SGEMM 要做的事很简单:

  • C = alpha * A * B + beta * C
  • AM x K
  • BK x N
  • CM x N

默认约定:

  • 矩阵按 row-major 存储
  • 后面的高性能版本主要面向大矩阵 benchmark,很多 kernel 默认 M/N 按 block tile 对齐、KBK 对齐、N 满足 float4 对齐
  • 真正工程里如果要通吃任意尺寸,还要补边界判断和尾块处理

参数记号:

  • BM / BN / BK:thread block 级别 tile
  • TM / TN:每个线程在寄存器里负责的输出 tile
  • WM / WN:每个 warp 负责的输出 tile

优化主线:线程映射 shared memory register tiling vectorize 解决 bank conflict autotune warp tiling double buffering async copy

1. 最基础版本:一个线程算一个输出元素

__global__ void sgemm_naive(int M, int N, int K, float alpha,
                            const float *A, const float *B,
                            float beta, float *C) {
    int tidx = blockIdx.x * blockDim.x + threadIdx.x;
    int tidy = blockIdx.y * blockDim.y + threadIdx.y;
 
    if (tidx < M && tidy < N) {
        float tmp = 0.0f;
        for (int i = 0; i < K; ++i) {
            tmp += A[tidx * K + i] * B[i * N + tidy];
        }
        C[tidx * N + tidy] = alpha * tmp + beta * C[tidx * N + tidy];
    }
}

这版把 tidx 当行、tidy 当列,不符合 GPU 惯例(通常 x 对应列),这正是导致访存问题的原因 问题也很直接:

  • 每次乘加都直接读 global memory,没有数据复用
  • A[tidx * K + i]:warp 内 tidx 连续,访问 A 的不同行(步长为 K),不连续,无法合并访问(non-coalesced)
  • B[i * N + tidy]tidythreadIdx.y 决定,同一 warp 内 tidy 相同,多个线程读取同一地址,虽然会广播但浪费带宽
  • C[tidx * N + tidy]:同理,同一 warp 的线程访问不同行,不连续
  • 算力很低,大部分时间都耗在访存上

为什么同一 warp 内 tidy 相同? 2D block 内线程的线性编号是 threadIdx.y * blockDim.x + threadIdx.x,warp 按线性编号连续取 32 个线程。当 blockDim.x >= 32 时,一个 warp 内 threadIdx.x 从 0 跑到 31,但 threadIdx.y 是同一个值,所以 tidy 也相同

假设 blockDim.x = 32:
warp 0:  threadIdx.y=0, threadIdx.x=0~31
warp 1:  threadIdx.y=1, threadIdx.x=0~31
warp 2:  threadIdx.y=2, threadIdx.x=0~31
...

2. 调整线程映射:先把 global memory 访问做顺

template <const uint BLOCKSIZE>
__global__ void sgemm_global_mem_coalesce(...) {
    const int cRow = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
    const int cCol = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);
 
    float tmp = 0.0f;
    for (int i = 0; i < K; ++i) {
        tmp += A[cRow * K + i] * B[i * N + cCol];
    }
    C[cRow * N + cCol] = alpha * tmp + beta * C[cRow * N + cCol];
}

核心变化不是算法,而是线程到输出元素的映射

启动时用的是 1D 线程块,每个 block 有 BLOCKSIZE * BLOCKSIZE 个线程:

dim3 block(BLOCKSIZE * BLOCKSIZE);  // 比如 32*32 = 1024 个线程,只有 threadIdx.x
dim3 grid(CEIL_DIV(M, BLOCKSIZE), CEIL_DIV(N, BLOCKSIZE));

每个线程只有 threadIdx.x(范围 0 ~ BLOCKSIZE*BLOCKSIZE-1),通过 / % 手动拆出行列:

  • threadIdx.x / BLOCKSIZE → 行号(同一 warp 内不变)
  • threadIdx.x % BLOCKSIZE → 列号(同一 warp 内连续递增)

BLOCKSIZE=32(刚好等于 warp 大小)为例:

warp 0 — threadIdx.x: 0   1   2   3  ... 31
          cRow = /32:  0   0   0   0  ...  0   ← 全相同
          cCol = %32:  0   1   2   3  ... 31   ← 连续递增

这样 warp 内线程变成”同一行、连续列”:

  • B[i * N + cCol]:cCol 连续 → 地址连续 → coalesced access
  • A[cRow * K + i]:cRow 相同、i 也相同 → 32 个线程算出完全一样的地址 → 硬件只做一次读取,广播给整个 warp
  • C[cRow * N + cCol]:cCol 连续 → coalesced write

如果 BLOCKSIZE 不等于 32(比如 16),一个 warp 里会包含多个 cRow,读 A 时不再是单一地址,但 cCol 仍然连续,对 B 和 C 的 coalesced access 依然成立 收益:

  • global memory access 更接近 coalesced- 不改变计算逻辑,代价很小 问题:

  • 还是没有 shared memory,A/B 仍然被重复从 global memory 读取

3. Shared Memory Blocking:先把 block tile 缓到片上

template <const int BLOCKSIZE>
__global__ void sgemm_shared_mem_block(...) {
    __shared__ float As[BLOCKSIZE * BLOCKSIZE];
    __shared__ float Bs[BLOCKSIZE * BLOCKSIZE];
 
    const uint threadCol = threadIdx.x % BLOCKSIZE;
    const uint threadRow = threadIdx.x / BLOCKSIZE;
 
    A += cRow * BLOCKSIZE * K;
    B += cCol * BLOCKSIZE;
    C += cRow * BLOCKSIZE * N + cCol * BLOCKSIZE;
 
    float tmp = 0.0f;
    for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) {
        As[threadRow * BLOCKSIZE + threadCol] = A[threadRow * K + threadCol];
        Bs[threadRow * BLOCKSIZE + threadCol] = B[threadRow * N + threadCol];
        __syncthreads();
 
        for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) {
            tmp += As[threadRow * BLOCKSIZE + dotIdx] *
                   Bs[dotIdx * BLOCKSIZE + threadCol];
        }
        __syncthreads();
 
        A += BLOCKSIZE;
        B += BLOCKSIZE * N;
    }
}

这一步开始进入 GEMM 的标准写法

  • 一个 block 负责一个 BLOCKSIZE x BLOCKSIZEC tile- 每轮从 global memory 读一块 A_tile(BLOCKSIZE x BLOCKSIZE)B_tile(BLOCKSIZE x BLOCKSIZE) 到 shared memory- tile 进 shared memory 后,block 内线程反复复用,不再每次乘加都回 global memory 取数 收益:

  • 数据复用第一次真正建立起来- 访存带宽压力明显下降 问题:

  • 一个线程还是只算一个输出元素,寄存器复用不够- BLOCKSIZE=32 时一个 block 就是 1024 线程,线程组织仍然偏粗

4. 1D Block Tiling:一个线程开始算多个结果

template <const int BM, const int BN, const int BK, const int TM>
__global__ void sgemm1DBlocktiling(...) {
    float threadResults[TM] = {0.0f};
 
    for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
        As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
        Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB];
        __syncthreads();
 
        for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
            float tmpB = Bs[dotIdx * BN + threadCol];
            for (uint resIdx = 0; resIdx < TM; ++resIdx) {
                threadResults[resIdx] +=
                    As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB;
            }
        }
        __syncthreads();
    }
}

这里开始做 register tiling

  • block tile 变成 BM x BN,源码里的默认参数是 BM=64, BN=64, BK=8, TM=8- 一个线程不再只算一个 C 元素,而是负责同一列上的 TM 个输出- B 的一个值先读进寄存器 tmpB,再和 TMA 值做乘加 收益:

  • 一个线程做更多计算,访存和计算的比例更平衡- 寄存器开始承担“线程私有缓存”的角色 注意:

  • 源码里这里把 blockIdx.x / y 的含义翻过来了,目的是让相邻 block 访问 B 时更连续,L2 locality 更好

5. 2D Block Tiling:一个线程算 TM x TN 个结果

template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemm2DBlocktiling(...) {
    float threadResults[TM * TN] = {0.0f};
    float regM[TM] = {0.0f};
    float regN[TN] = {0.0f};
 
    for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
        // GMEM -> SMEM
        ...
        __syncthreads();
 
        for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
            for (uint i = 0; i < TM; ++i) {
                regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
            }
            for (uint i = 0; i < TN; ++i) {
                regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
            }
            for (uint mi = 0; mi < TM; ++mi) {
                for (uint ni = 0; ni < TN; ++ni) {
                    threadResults[mi * TN + ni] += regM[mi] * regN[ni];
                }
            }
        }
        __syncthreads();
    }
}

这是一个很关键的版本

  • 每个线程负责一个 TM x TN 的小块,而不是一列或一个点- ATM 个值进 regMBTN 个值进 regN- 真正的乘加发生在寄存器里,本质上已经是一个小型 micro-kernel 收益:

  • shared memory 的数据能被寄存器层再次复用- 单线程计算密度明显上升 源码默认参数:

  • 大矩阵:BM=128, BN=128, BK=8, TM=8, TN=8

  • 小矩阵:退回 64 x 64 tile,主要是因为这版没有认真做边界处理

6. 向量化:用 float4 提高搬运效率

float4 tmp =
    reinterpret_cast<float4 *>(&A[innerRowA * K + innerColA * 4])[0];
As[(innerColA * 4 + 0) * BM + innerRowA] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA] = tmp.w;
 
reinterpret_cast<float4 *>(&Bs[innerRowB * BN + innerColB * 4])[0] =
    reinterpret_cast<float4 *>(&B[innerRowB * N + innerColB * 4])[0];

核心思路:

  • global memory 到 shared memory 的加载改成 128bit 宽度,也就是一次搬 4float- 写 C 回 global memory 时也按 float4 向量化- 加载 A 时顺手做了一次转置,把 As 布局改成 BK x BM,后面按 dotIdx 读取时更顺 收益:

  • 降低了指令数- 更容易吃满内存带宽 前提:

  • 地址要满足 float4 对齐- 尺寸最好是 4 的倍数

7. 处理 Shared Memory Bank Conflict:手动重排 B tile

tmp = reinterpret_cast<float4 *>(&B[innerRowB * N + innerColB * 4])[0];
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 0) * 16 + innerColB / 2] = tmp.x;
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 1) * 16 + innerColB / 2] = tmp.y;
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 2) * 16 + innerColB / 2] = tmp.z;
Bs[((innerColB % 2) * 4 + innerRowB * 8 + 3) * 16 + innerColB / 2] = tmp.w;
 
for (uint i = 0; i < TN; ++i) {
    regN[i] = Bs[(dotIdx * 8 + i) * 16 + threadCol];
}

做到这一步以后,瓶颈已经不只是 global memory 了,shared memory 自己也会出问题

  • 线程按固定 stride 读 Bs 时,可能让一个 warp 里的很多线程撞到同一个 bank- 这版不是简单地按二维数组存 Bs,而是把它重排成更适合 warp 读取的线性布局- 读取 regN 时,新的索引公式专门对应这个重排后的布局 收益:

  • 减少 shared memory bank conflict- 让寄存器 micro-kernel 吃到更稳定的供数速度 代价:

  • 代码可读性明显下降

8. 处理 Shared Memory Bank Conflict:给 B tile 补 padding

const int extraCols = 5;
__shared__ float Bs[BK * (BN + extraCols)];
 
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 0] = tmp.x;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 1] = tmp.y;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 2] = tmp.z;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 3] = tmp.w;
 
regN[i] = Bs[dotIdx * (BN + extraCols) + threadCol * TN + i];

这版解决的是同一个问题,但写法更常见

  • 不再手动做复杂 swizzle- 直接给 Bs 的 leading dimension 补几列 padding,打破容易冲突的 stride- 访问公式基本不变,可维护性更高 实际工程里,这种 padding 写法通常更容易保留

9. 参数调优:把 tile 形状和线程数压到更合适的位置

const int K9_NUM_THREADS = 256;
 
template <const int BM, const int BN, const int BK,
          const int TM, const int TN>
__global__ void __launch_bounds__(K9_NUM_THREADS) sgemmAutotuned(...) {
    constexpr int WM = TM * 16;
    constexpr int WN = TN * 16;
    constexpr int WMITER = CEIL_DIV(BM, WM);
    constexpr int WNITER = CEIL_DIV(BN, WN);
 
    float threadResults[WMITER * WNITER * TM * TN] = {0.0f};
    ...
}

这版本质上不是“新算法”,而是“给前面的 2D register tiling 选更合适的参数” 源码里默认参数是:

  • NUM_THREADS = 256
  • BM = 128
  • BN = 128
  • BK = 16
  • TM = 8
  • TN = 8

同时配了很多 static_assert

  • 保证 float4 向量化加载不会出现量化残块- 保证 thread tile、block tile、thread 数之间正好整除- 保证编译器能按预期展开 这一类 kernel 的重点不是公式变了,而是“形状选对以后,寄存器、occupancy、shared memory 容量、向量化宽度”能同时落在一个比较舒服的点上

10. Warp Tiling:以 warp 为调度单位组织计算

const uint warp_id = threadIdx.x / WARPSIZE;
const uint warpCol = warp_id % (BN / WN);
const uint warpRow = warp_id / (BN / WN);
 
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
constexpr uint WSUBM = WM / WMITER;
constexpr uint WSUBN = WN / WNITER;
 
const uint lane_id = threadIdx.x % WARPSIZE;
const uint threadColInWarp = lane_id % (WSUBN / TN);
const uint threadRowInWarp = lane_id / (WSUBN / TN);

前面的版本虽然已经在做 register tiling,但还没有显式把 warp 当成一级结构来设计。这里开始真正进入高性能 GEMM 的主流组织方式

  • 一个 thread block 先切成多个 warp- 每个 warp 负责一个 WM x WN 的 warp tile- warp tile 还会继续切成更小的 WSUBM x WSUBN 子块,让 32 个线程的职责分配更均匀 配套的两个核心函数也很清楚:

  • loadFromGmem:负责把 A/B tile 从 global memory 搬到 shared memory- processFromSmem:负责从 shared memory 取数到寄存器,然后做 warp 级 micro-kernel 收益:

  • 计算组织更贴近 GPU 的真实调度粒度- warp 内寄存器复用和 shared memory 读取模式都更容易做细

11. Double Buffering:一边算当前 tile,一边准备下一个 tile

__shared__ float As[2 * BM * BK];
__shared__ float Bs[2 * BK * BN];
 
bool doubleBufferIdx = threadIdx.x >= (NUM_THREADS / 2);
 
if (doubleBufferIdx == 0) {
    loadFromGmem(..., As, Bs, ...);
}
__syncthreads();
 
for (uint bkIdx = 0; bkIdx < K; bkIdx += 2 * BK) {
    if (doubleBufferIdx == 0) {
        processFromSmem(..., As, Bs, ...);
        ...
        loadFromGmem(..., As, Bs, ...);
    } else {
        loadFromGmem(..., As + BM * BK, Bs + BK * BN, ...);
        ...
        processFromSmem(..., As + BM * BK, Bs + BK * BN, ...);
    }
}

这版开始做 ping-pong buffer

  • As/Bs 不再只开一份,而是开两份- 一部分 warp 在算当前 tile,另一部分 warp 去搬下一块数据- 下一轮直接切换到另一块 shared memory buffer 重点不是“少一次同步”,而是尽量把“取数”和“计算”重叠起来,减少流水线空转 这也是现代 GEMM 的一个核心思想:

  • 没有 double buffering 时,常见模式是“load 完再算,算完再 load”- 有了 double buffering,就能把这两个阶段做成流水

12. Async Copy Double Buffering:把搬运进一步做成异步流水

__shared__ cuda::barrier<cuda::thread_scope_block> frontBarrier;
__shared__ cuda::barrier<cuda::thread_scope_block> backBarrier;
 
loadFromGmem(..., As + As_offset * BM * BK,
             Bs + Bs_offset * BK * BN, ..., *frontBarrierPtr);
 
for (uint bkIdx = 0; bkIdx < K - BK; bkIdx += BK) {
    loadFromGmem(..., As + (1 - As_offset) * BM * BK,
                 Bs + (1 - Bs_offset) * BK * BN, ..., *backBarrierPtr);
 
    (*frontBarrierPtr).arrive_and_wait();
    processFromSmem(..., As + As_offset * BM * BK,
                         Bs + Bs_offset * BK * BN, ...);
 
    As_offset = 1 - As_offset;
    Bs_offset = 1 - Bs_offset;
    swap(frontBarrierPtr, backBarrierPtr);
}

这版是在上一版的基础上继续前进

  • global memory 到 shared memory 的搬运改成 cuda::memcpy_async- 用 cuda::barrier 管理前后两个 buffer 的就绪状态- 当前 tile 在计算时,下一块数据已经在后台往 shared memory 里搬 loadFromGmem 里的两个细节值得记一下:

  • A 因为要转置存到 shared memory,异步拷贝是按标量拆开的- B 保持顺序布局,可以直接按 float4 异步拷贝 这一版更接近现代 CUDA 在 Ampere 及之后架构上的 GEMM pipeline 写法

小结

  • naive -> coalesced:先把 global memory 访问顺序理顺- shared memory blocking:GEMM 的第一道真正分水岭- 1D/2D register tiling:性能开始明显上台阶- vectorize + bank conflict:把供数路径细化- warp tiling + double buffering:接近高性能 GEMM 的标准模板

参考链接