From 406c63070996d012fc38203f23545b8fa65b1b2f Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Mon, 14 Apr 2025 11:55:37 +0800 Subject: [PATCH] Stricter assertions --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index bf5e0ff..cfcb569 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -65,6 +65,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Types using WGMMA = typename FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); // Shared memory static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); @@ -356,14 +357,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; - if constexpr (kSwizzleDMode > 0) { - DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom"); - DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, - "Unaligned TMA store or too many TMA store instructions"); - DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); - DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, - "Swizzling and padding are not compatible"); - } + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, + "Swizzling and padding are not compatible"); // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");