From 742fb1c8a589e8f5ceffdcb3e71080dba380f516 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 13:41:28 +0800 Subject: [PATCH] Compilation-time GCD --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 2 +- deep_gemm/include/deep_gemm/utils.cuh | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index c2572a6..a041d40 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - // TODO: check `ceil_div(BLOCK_N, BLOCK_K) == 1` or `gcd(BLOCK_K, BLOCK_N) == BLOCK_N - BLOCK_K` + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types using WGMMA = typename FP8MMASelector::type; diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 0005907..fe2c016 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -46,3 +46,8 @@ template __device__ __host__ constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; } + +template +__device__ __host__ constexpr T gcd(T a, T b) { + return b == 0 ? a : gcd(b, a % b); +}