From 094d0421eca0fb7378029f395e7535fb13802d36 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Thu, 13 Mar 2025 07:04:56 +0000 Subject: [PATCH] refine --- deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh | 4 +--- deep_gemm/jit_kernels/gemm_bw.py | 2 +- tests/test_core.py | 5 ++--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh b/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh index e6fa3d0..1f6219e 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm_backward_w.cuh @@ -179,8 +179,6 @@ fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, tma_copy(&tensor_map_scales_b, reinterpret_cast(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N, scheduler.get_global_idx(0, 1, k_idx / BLOCK_K)); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); - // full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); - } // Wait unaligned cases @@ -268,7 +266,7 @@ fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, empty_barrier_arrive(s); #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // WGMMA::kNumAccum = 64, loop 16 steps + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { int src_lane_id = threadIdx.x % 4 + (i % 8) * 4; float scale_b_0 = __shfl_sync(0xffffffff, scale_b[i/8].x, src_lane_id); float scale_b_1 = __shfl_sync(0xffffffff, scale_b[i/8].y, src_lane_id); diff --git a/deep_gemm/jit_kernels/gemm_bw.py b/deep_gemm/jit_kernels/gemm_bw.py index ee65f59..8b82804 100644 --- a/deep_gemm/jit_kernels/gemm_bw.py +++ b/deep_gemm/jit_kernels/gemm_bw.py @@ -75,7 +75,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64 if m <= 64 else 128, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(32, 129, 32)) + block_ns = (128, ) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) diff --git a/tests/test_core.py b/tests/test_core.py index 52df4a7..48dcfa5 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -109,7 +109,6 @@ def test_gemm_backward_w() -> None: diff = calc_diff(out, ref_out) assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' torch.cuda.synchronize() - print(diff) # noinspection PyShadowingNames def test_func(): @@ -195,5 +194,5 @@ if __name__ == '__main__': test_gemm_backward_w() test_gemm() - # test_m_grouped_gemm_contiguous() - # test_m_grouped_gemm_masked() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked()