diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index cad4fb1..ce14171 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -200,7 +200,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if m == 0:
         return
 
-    aligned_n = (n + 63) // 64 * 64
+    # K must be aligned to 128
     aligned_k = (k + 127) // 128 * 128
 
     # Auto-tuning with compilation
@@ -241,7 +241,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
-        keys={'N': aligned_n, 'K': aligned_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'N': n, 'K': aligned_k,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'SWIZZLE_D_MODE': smem_config[1],
               'BLOCK_N_PADDING': smem_config[2],
               'NUM_STAGES': num_stages,
diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py
index fbefb7b..9eb0373 100644
--- a/deep_gemm/jit_kernels/wgrad_gemm.py
+++ b/deep_gemm/jit_kernels/wgrad_gemm.py
@@ -71,13 +71,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     if k == 0:
         return
 
-    aligned_n = (n + 63) // 64 * 64
+    # K must be aligned to 128
     aligned_k = (k + 127) // 128 * 128
 
     # Auto-tuning with compilation
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
-        m, aligned_n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
+        m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
     last_stages = (k + 127) // 128 % num_stages
     block_k = 128
     num_tma_threads = 128
@@ -114,7 +114,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     runtime, best_keys = jit_tuner.compile_and_tune(
         name='wgrad_gemm_fp8_fp8_fp32_nt',
-        keys={'M': m, 'N': aligned_n, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+        keys={'M': m, 'N': n,
+              'BLOCK_M': block_m, 'BLOCK_N': block_n,
               'NUM_STAGES': num_stages, 'LAST_STAGES': last_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
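
For context: the patch drops the old 64-alignment of N, so the raw n is passed to the JIT compile keys (and to get_best_configs in the wgrad path), while K keeps its 128-alignment. Below is a minimal sketch of that ceiling-alignment arithmetic, using a hypothetical ceil_align helper and illustrative values; it is not code from the repository.

# Sketch only: the round-up-to-multiple arithmetic kept for K above.
def ceil_align(x: int, alignment: int) -> int:
    # Round `x` up to the next multiple of `alignment`.
    return (x + alignment - 1) // alignment * alignment

# K is still padded to a multiple of 128, matching `(k + 127) // 128 * 128`.
assert ceil_align(300, 128) == (300 + 127) // 128 * 128 == 384

# N is no longer rounded up to a multiple of 64; after this change the
# unpadded value is used directly in the compile keys.
n = 4097
print(ceil_align(n, 64), n)  # old padded value 4160 vs. raw 4097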