Support more shapes

This commit is contained in:
Chenggang Zhao 2025-02-28 10:04:59 +08:00
parent b69f630b91
commit 6c5da03ba9
2 changed files with 4 additions and 5 deletions

View File

@ -50,12 +50,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Shared memory
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
@ -99,9 +100,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
// Fill barriers
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;

View File

@ -50,7 +50,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2)
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
return smem_size