diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 8419f7d..fbe05c7 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -851,18 +851,6 @@ __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M)
-                128) / 32;
-                write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx,
-                                     scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, problem_input.ld_d);
-            } else {
                 cute::tma_store_fence();
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 622be39..dacf5f1 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -391,21 +391,9 @@ struct SchedulerSelector {
     static constexpr auto select_type() {
-        if constexpr (GT == GemmType::Normal)
-            return NormalScheduler();
-        if constexpr (GT == GemmType::GroupedContiguous)
-            return GroupedContiguousScheduler();
-        if constexpr (GT == GemmType::GroupedMasked)
-            return GroupedMaskedScheduler();
         if constexpr (GT == GemmType::GroupedWithOffset)
            return GroupedWithOffsetScheduler();
-        if constexpr (GT == GemmType::StridedBatched)
-            return StridedBatchedScheduler();
     }

     using type = decltype(select_type());
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 2a2cc31..459d1c7 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -104,8 +104,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
         block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
-    block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
-
+    #block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
+    block_ns = tuple(range(16, 129, 8))
     # Avoid bank conflicts for FP32 output
     if is_fp32_out:
         block_ns = [x for x in block_ns if x % 16 == 8]
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 9238468..2a60767 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -260,7 +260,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
     else:
         m_per_expert_threshold = 32 # H100
-    if expected_m>= m_per_expert_threshold:
+    if expected_m> m_per_expert_threshold:
         num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
             expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False)
diff --git a/tests/test_core.py b/tests/test_core.py
index e152a9c..eb3e51e 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -119,10 +119,9 @@ def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int,
 def test_m_grouped_gemm_offset() -> None:
     print('Testing grouped contiguous GEMM:')
-    for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),):
+    for num_groups, expected_m_per_group, k, n in ((9, 32, 7168, 4096),):
         # NOTES: we should mask the unfilled part before calculating difference
-        '''
         x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n)
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group)
@@ -160,6 +159,8 @@ def test_m_grouped_gemm_offset() -> None:
         print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
               f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
               f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
+
+    '''
     print()