mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 17:44:21 +00:00
Compatible with padding
This commit is contained in:
parent
4c111418a2
commit
23d4289365
@@ -355,12 +355,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
// TMA checks
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16);
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
||||
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
|
||||
"Swizzling and padding are not compatible");
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
||||
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
|
||||
"Swizzling and padding are not compatible");
|
||||
}
|
||||
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
|
@@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=177,174,940']
|
||||
'--diag-suppress=39,174,177,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
|
Loading…
Reference in New Issue
Block a user