Remove restrictions on N

Chenggang Zhao 2025-05-14 14:27:04 +08:00
parent c4a7116e0a
commit 279eb03190
2 changed files with 7 additions and 5 deletions


@@ -200,7 +200,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if m == 0:
         return
-    aligned_n = (n + 63) // 64 * 64
+    # K must be aligned to 128
     aligned_k = (k + 127) // 128 * 128
     # Auto-tuning with compilation
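
The alignment expressions above use the usual integer round-up idiom: (x + a - 1) // a * a rounds x up to the next multiple of a. A minimal standalone sketch of that idiom, with values chosen only for illustration:

    def align_up(x: int, a: int) -> int:
        # Round x up to the nearest multiple of a (assumes a > 0).
        return (x + a - 1) // a * a

    # K is still padded to a multiple of 128, exactly as in the hunk above.
    assert align_up(7168, 128) == 7168   # already aligned, unchanged
    assert align_up(4097, 128) == 4224   # rounded up to the next 128 boundary
    # The removed line did the same for N with a = 64; after this commit the
    # raw n is used instead of aligned_n.
    assert align_up(4097, 64) == 4160
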
@@ -241,7 +241,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
-        keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'N': n, 'K': aligned_k,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'SWIZZLE_D_MODE': smem_config[1],
               'BLOCK_N_PADDING': smem_config[2],
               'NUM_STAGES': num_stages,
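
The keys dict determines how the JIT tuner specializes and caches compiled kernels, so this hunk changes the specialization from the 64-aligned N to the exact N. A hypothetical sketch of a cache keyed this way (build_kernel, kernel_cache and get_or_build are illustrative names, not DeepGEMM's actual internals):

    def build_kernel(name, keys):
        # Stub standing in for an expensive JIT compile step.
        return f"{name}:{sorted(keys.items())}"

    kernel_cache = {}

    def get_or_build(name, keys):
        # Identical key dicts reuse a compiled kernel; new ones trigger a build.
        cache_key = (name, tuple(sorted(keys.items())))
        if cache_key not in kernel_cache:
            kernel_cache[cache_key] = build_kernel(name, keys)
        return kernel_cache[cache_key]

    # Before this commit, n=4095 and n=4096 both rounded to aligned_n=4096 and
    # shared one compiled kernel; keying on the raw n gives one kernel per exact N.
    get_or_build('gemm_fp8_fp8_bf16_nt', {'N': 4095, 'K': 7168})
    get_or_build('gemm_fp8_fp8_bf16_nt', {'N': 4096, 'K': 7168})
    assert len(kernel_cache) == 2
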


@@ -71,13 +71,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if k == 0:
         return
-    aligned_n = (n + 63) // 64 * 64
+    # K must be aligned to 128
     aligned_k = (k + 127) // 128 * 128
     # Auto-tuning with compilation
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
-        m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
+        m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
     last_stages = (k + 127) // 128 % num_stages
     block_k = 128
     num_tma_threads = 128
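
For reference, last_stages above is the number of 128-wide K blocks taken modulo the pipeline stage count; a quick worked example with arbitrary shape values:

    k, num_stages = 7168, 6            # arbitrary example values
    num_k_blocks = (k + 127) // 128    # 56 blocks of BLOCK_K = 128
    last_stages = num_k_blocks % num_stages
    assert num_k_blocks == 56 and last_stages == 2
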
@@ -114,7 +114,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='wgrad_gemm_fp8_fp8_fp32_nt',
-        keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'M': m, 'N': n,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'NUM_STAGES': num_stages,
               'LAST_STAGES': last_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
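
Taken together, the wgrad path gets the same treatment as the forward GEMM: the per-shape tuning key now carries the exact N rather than its 64-aligned version, while K padding is unchanged. A small before/after comparison of the 'N' key for one arbitrary shape:

    m, n, k = 4096, 4097, 7168              # arbitrary shape; n is not a multiple of 64
    aligned_n = (n + 63) // 64 * 64         # what the removed line used to compute
    aligned_k = (k + 127) // 128 * 128      # K handling is unchanged

    old_key_n = aligned_n                   # 'N' entry in the tuning keys before this commit
    new_key_n = n                           # 'N' entry after this commit
    assert (old_key_n, new_key_n, aligned_k) == (4160, 4097, 7168)
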