Compilation-time GCD

This commit is contained in:
Chenggang Zhao 2025-03-25 13:41:28 +08:00
parent b922e64cb2
commit 742fb1c8a5
2 changed files with 6 additions and 1 deletions

View File

@ -43,7 +43,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
// TODO: check `ceil_div(BLOCK_N, BLOCK_K) == 1` or `gcd(BLOCK_K, BLOCK_N) == BLOCK_N - BLOCK_K`
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;

View File

@ -46,3 +46,8 @@ template <typename T>
__device__ __host__ constexpr T ceil_div(T a, T b) {
return (a + b - 1) / b;
}
template <typename T>
__device__ __host__ constexpr T gcd(T a, T b) {
return b == 0 ? a : gcd(b, a % b);
}