__global__ void matMulNaive(float *A, float *B, float *C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
 
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k) {
            // A 读取良好:连续
            // B 读取糟糕:stride 为 N
            sum += A[row * K + k] * B[k * N + col]; 
        }
        C[row * N + col] = sum;
    }
}
// 假设传入的 B_T 已经是 B 的转置矩阵,维度为 [N, K]
__global__ void matMulTransposed(float *A, float *B_T, float *C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
 
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k) {
            // A 读取良好:连续
            // B_T 读取良好:也是连续的!(col 变成了 B_T 的行索引)
            sum += A[row * K + k] * B_T[col * K + k];
        }
        C[row * N + col] = sum;
    }
}

总结: 以前是MxK * KxN 现在是MxK * (NxK)T,但其实本质是MxK*NxK,是B矩阵也是按照行索引 以前是col那一列:k * N + col 现在是col那一行:col*K + k 原始AB: 新AB’