flash attention 干了什么

template <typename Ty, int kBc = 4, int kBr = 4, int kDim = 128>
__global__ void flash_attention_v2_kernel(Ty* Q, Ty* K, Ty* V, Ty* O,
                                          int seqlen,       // M
                                          int stride_head,  // M*N
                                          Ty smScale) {
  int groupSeq = (seqlen + kBc - 1) / kBc;
  int groupTx = (kDim + kBc - 1) / kBc;
  int groupTy = (kDim + kBr - 1) / kBr;
  __shared__ Ty sQ[kBr][kDim];
  __shared__ Ty sK[kBc][kDim];
  __shared__ Ty sV[kBc][kDim];
  __shared__ Ty sO[kBr][kDim];
  __shared__ Ty sQK[kBr][kBc];
  __shared__ Ty sSafeE[kBr][kBc];
  __shared__ Ty sDenom[kBr];
  __shared__ Ty sMax[kBr];
 
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int base_offset = blockIdx.x * stride_head;
  int row = ty + blockIdx.y * blockDim.y;
  if (row >= seqlen) {
    return;
  }
 
  //   显存地址 (内存是线性的)
  // 0         1000        2000        3000        ...
  // |-----------|-----------|-----------|-----------| ...
  // | Batch0    | Batch0    | Batch0    | Batch1    |
  // | Head0     | Head1     | Head2     | Head0     |
  // | 数据      | 数据      | 数据      | 数据      |
  // |-----------|-----------|-----------|-----------|
  // ^           ^           ^
  // 初始指针Q    偏移后的Q    再偏移...
 
  Q += base_offset;  // 见上图 所以需要偏移
  K += base_offset;
  V += base_offset;
  O += base_offset;
 
  // 每个block都搬移一整行块进来
  for (int i = 0; i < groupTx; i++) {
    sQ[ty][i * kBc + tx] = Q[row * kDim + i * kBc + tx];  // Q0
    sO[ty][i * kBc + tx] = 0;
  }
 
  sMax[ty] = -INFINITY;  // 行最大值 实时修正
  sDenom[ty] = 0;        // 每一行的分母 实时修正
 
  // Outer loop over KV blocks
  // 即M方向最大值 所以在一个block中循环即可
  for (int j = 0; j < groupSeq; j++) {  // KV的M方向循环(要反过来
    // Load K and V
    if ((j * kBc + tx) < seqlen) {
      // 这个搬运与上面P矩阵搬运类似
      for (int i = 0; i < groupTy; i++) {  // KV的N方向循环
        // j = 0:K0, j = 1: K1 ...
        sK[tx][i * kBr + ty] = K[j * kBc * kDim + tx * kDim + i * kBr + ty];
        // j = 0:V0, j = 1: V1 ...
        sV[tx][i * kBr + ty] = V[j * kBc * kDim + tx * kDim + i * kBr + ty];
      }
    }
 
    __syncthreads();
 
    // Compute Q * K^T
    Ty sum = 0.f;
    for (int i = 0; i < kDim; i++) {
      sum += sQ[ty][i] * sK[tx][i];
    }
 
    // 每一个thread都算一下
    // 最终汇聚起来就是Q0*K0^T
    sQK[ty][tx] = sum * smScale;
 
    __syncthreads();
 
    // Find max for numerical stability
    // 寻找QK的行最大值进行softmax
    // 每个thread,每个小块都自己找对应该行的最大值
    Ty localMax = -INFINITY;
    for (int i = 0; i < kBc; i++) {
      localMax = max(localMax, sQK[ty][i]);
    }
    __syncthreads();
    // 历史最大值 Q0K0 Q0K1 Q0K2中
    Ty newMax = max(sMax[ty], localMax);
 
    // Compute Exponentials
    // 计算当前小方块的分子
    sSafeE[ty][tx] = exp(sQK[ty][tx] - newMax);
    __syncthreads();
 
    // Compute Denominator
    Ty localDenom = 0.f;
    for (int i = 0; i < kBc; i++) {
      // 计算当前小方块的分母(求和)
      localDenom += sSafeE[ty][i];
    }
    __syncthreads();
 
    // Update Output (Online Softmax)
    // 分母修正用 乘在分母前
    Ty rescaleOld = exp(sMax[ty] - newMax);
    // 分母:旧数据贡献*修正 + 新数据贡献
    Ty newDenom = sDenom[ty] * rescaleOld + localDenom;
    // i控制v的列循环 也控制O的列循环
    for (int i = 0; i < groupTx; i++) {
      // 修正老的分子
      sO[ty][i * kBc + tx] = (sO[ty][i * kBc + tx] * rescaleOld);
      // 小方块内矩阵计算
      for (int k = 0; k < kBc; k++) {
        sO[ty][i * kBc + tx] += sSafeE[ty][k] * sV[k][i * kBc + tx];
      }
    }
    // 更新QK矩阵的行最大值
    sMax[ty] = newMax;
    // 更新老分母
    sDenom[ty] = newDenom;
    __syncthreads();
  }
 
  // Write output to global memory
  for (int i = 0; i < groupTx; i++) {
    // 全局更新分母
    O[row * kDim + i * kBc + tx] = sO[ty][i * kBc + tx] / sDenom[ty];
  }
}
 
torch::Tensor flash_attention_v2_cuda(torch::Tensor q, torch::Tensor k,
                                      torch::Tensor v) {
  CHECK_INPUT(q);
  CHECK_INPUT(k);
  CHECK_INPUT(v);
 
  // 1, 1, M, N
  int bs = q.size(0);
  int head = q.size(1);
  int seqlen = q.size(2);  // M
  int dim = q.size(3);     // N
  float sm_scale = 1.f / sqrtf(static_cast<float>(dim));
  int stride_head = seqlen * dim;
 
  auto out = torch::zeros_like(q);
 
  const int Br = 4;
  const int Bc = 4;
  int Gc = bs * head;
  int Gr = (seqlen + Br - 1) / Br;
 
  assert(dim % Bc == 0 && seqlen % Br == 0);
 
  dim3 grid = dim3(Gc, Gr);
  dim3 block = dim3(Bc, Br);
 
  using scalar_t = float;
  flash_attention_v2_kernel<scalar_t, Bc, Br, 128><<<grid, block>>>(
      q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(),
      out.data_ptr<scalar_t>(), seqlen, stride_head, sm_scale);
 
  return out;
}

简单的线代基础

如果 A 矩阵再拆分为 2 个 如果 B 矩阵再拆分为 2 个 如果 AB 都拆

self-attention 的公式

  1. 输入定义

输入矩阵为

  • 序列长度
  • 嵌入维度
  1. Attention 公式

线性变换:

  1. 参数说明
  • : 输入矩阵的列数(即
  • : 模型设计时的超参数“头数”(num_heads)
  1. Softmax 计算 (数值稳定性优化)
  • : 表示每一行(输入向量)
  • 分母: 表示对该行所有元素求和
  • 分子: 表示当前元素经过数值稳定处理后的指数值

原始 self-attention 步骤

总结:简单的矩阵分块,一行一起 softmax 三次写入三次写出:1. QKt 写入写出,2. softmax 写入写出,3. PV 乘法写入写出

flash 版本的 self-attention 步骤

总结:矩阵分块计算,softmax 也基本是分块计算