Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00
Remove restrictions on N
This commit is contained in:
parent c4a7116e0a
commit 279eb03190
@@ -200,7 +200,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if m == 0:
         return
 
-    aligned_n = (n + 63) // 64 * 64
+    # K must be aligned to 128
     aligned_k = (k + 127) // 128 * 128
 
     # Auto-tuning with compilation
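
The removed line padded N up to a multiple of 64: `(x + 63) // 64 * 64` rounds x up to the next 64 boundary, while K keeps its 128 alignment. A minimal sketch of that ceiling-alignment arithmetic (the `align_up` helper and the sample values are illustrative, not DeepGEMM code):

# Illustrative helper only: round x up to the next multiple of `alignment`,
# the same arithmetic the removed `aligned_n` line inlined with alignment=64.
def align_up(x: int, alignment: int) -> int:
    return (x + alignment - 1) // alignment * alignment

assert align_up(127, 64) == 128    # what the removed line computed for n=127
assert align_up(128, 64) == 128    # already-aligned values are unchanged
assert align_up(129, 128) == 256   # K is still rounded up to a multiple of 128
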
@@ -241,7 +241,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
-        keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'N': n, 'K': aligned_k,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'SWIZZLE_D_MODE': smem_config[1],
               'BLOCK_N_PADDING': smem_config[2],
               'NUM_STAGES': num_stages,
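
With `aligned_n` gone, the tuning key for `gemm_fp8_fp8_bf16_nt` now records the raw `n` instead of the padded value, while K is still keyed by `aligned_k`. A rough before/after illustration, assuming only that kernels are specialized per key dict; the shapes below are made up:

# Hypothetical shapes, chosen so that n is not a multiple of 64.
n, k = 100, 1000
aligned_n = (n + 63) // 64 * 64       # 128: the 'N' value keyed before this commit
aligned_k = (k + 127) // 128 * 128    # 1024: still used for 'K'

old_key = {'N': aligned_n, 'K': aligned_k}   # {'N': 128, 'K': 1024}
new_key = {'N': n, 'K': aligned_k}           # {'N': 100, 'K': 1024}
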
@@ -71,13 +71,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if k == 0:
         return
 
-    aligned_n = (n + 63) // 64 * 64
+    # K must be aligned to 128
    aligned_k = (k + 127) // 128 * 128
 
     # Auto-tuning with compilation
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
-        m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
+        m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
     last_stages = (k + 127) // 128 % num_stages
     block_k = 128
     num_tma_threads = 128
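
The unchanged context also shows how `last_stages` is derived: `(k + 127) // 128` counts the 128-wide K blocks, and the remainder modulo `num_stages` is what later goes into the `LAST_STAGES` key. A quick worked example with made-up values (the real `num_stages` comes from `get_best_configs`):

# Worked example of the `last_stages` expression above; values are illustrative.
k, num_stages = 1000, 3
num_k_blocks = (k + 127) // 128          # 8 blocks of block_k = 128
last_stages = num_k_blocks % num_stages  # 8 % 3 == 2
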
@@ -114,7 +114,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='wgrad_gemm_fp8_fp8_fp32_nt',
-        keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'M': m, 'N': n,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'NUM_STAGES': num_stages,
               'LAST_STAGES': last_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],