mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Remove restrictions on N
This commit is contained in:
parent
c4a7116e0a
commit
279eb03190
@ -200,7 +200,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
aligned_n = (n + 63) // 64 * 64
|
||||
# K must be aligned to 128
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
@ -241,7 +241,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
keys={'N': n, 'K': aligned_k,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
'BLOCK_N_PADDING': smem_config[2],
|
||||
'NUM_STAGES': num_stages,
|
||||
|
||||
@ -71,13 +71,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
if k == 0:
|
||||
return
|
||||
|
||||
aligned_n = (n + 63) // 64 * 64
|
||||
# K must be aligned to 128
|
||||
aligned_k = (k + 127) // 128 * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||
last_stages = (k + 127) // 128 % num_stages
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
@ -114,7 +114,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
||||
name='wgrad_gemm_fp8_fp8_fp32_nt',
|
||||
keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
keys={'M': m, 'N': n,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages,
|
||||
'LAST_STAGES': last_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user