diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index f482300..4a0c79e 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -50,12 +50,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, using Barrier = cutlass::arch::ClusterTransactionBarrier; // Shared memory + static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); - static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); // Configs constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; @@ -99,9 +100,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); // Fill barriers - DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); - DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers"); - auto barrier_start_ptr = reinterpret_cast(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2)); + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); #pragma unroll for (int i = 0; i < kNumStages; ++ i) { full_barriers[i] = barrier_start_ptr + i; diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index b5c509c..1ba413a 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -50,7 +50,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: smem_size += num_stages * smem_a_per_stage smem_size += num_stages * smem_scales_a_per_stage smem_size += num_stages * smem_b_per_stage - smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2) + smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 smem_size += smem_barrier return smem_size