From a51629ddf968ada2cb4e8e7365ffb14294b629d2 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 18 Apr 2025 16:43:00 +0800
Subject: [PATCH] Code polishing

---
 README.md                                 |  4 ++--
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 24 ++++++++++--------------
 deep_gemm/include/deep_gemm/scheduler.cuh |  6 +++---
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index b7924e1..7265a04 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,10 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 
 ## Roadmap
 
-- [ ] More correctness tests for grouped-contiguous layout
+- [x] More correctness tests for grouped-contiguous layout
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
-- [ ] MoE scheduler with TMA multicast compatibility
+- [x] MoE scheduler with TMA multicast compatibility
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)
diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 6e9c9e6..d7a1f69 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Prefetch TMA descriptors at very beginning
     if (threadIdx.x == kNumMathThreads) {
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
+        cute::prefetch_tma_descriptor(&tensor_map_a);
+        cute::prefetch_tma_descriptor(&tensor_map_b);
+        cute::prefetch_tma_descriptor(&tensor_map_scales_a);
 
         // `tensor_map_d` is only used in swizzling mode
         // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
         if constexpr (kSwizzleDMode > 0)
-            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
+            cute::prefetch_tma_descriptor(&tensor_map_d);
     }
     __syncwarp();
 
@@ -212,19 +212,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                          scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
 
                 // Issue TMA B
-                if constexpr (kNumTMAMulticastOnB == kNumTMAMulticast) {
-                    // NOTES: In grouped contiguous GEMM, different m_block_idx values may correspond to blocks of different groups (b),
+                if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
+                    // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
                     //  requiring additional checks before multicast operations.
-                    if (scheduler.is_tma_multicast_valid(m_block_idx))
-                        tma_copy<kNumTMAMulticast>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                   smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                    else
-                        tma_copy<1>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                    smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                }
-                else {
+                    DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
                     tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
+                } else {
+                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                             smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                 }
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
             }
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 2b0d343..78d9d88 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -47,9 +47,9 @@ struct Scheduler {
         if constexpr (kGemmType == GemmType::Normal) {
             return true;
         } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-            auto expert = __ldg(grouped_layout + m_block_idx * BLOCK_M);
-            auto cluster_partner_expert = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
-            return (expert == cluster_partner_expert);
+            auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+            auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+            return group_idx == peer_group_idx;
         } else if constexpr (kGemmType == GemmType::GroupedMasked) {
             return false;
         }
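
// ----------------------------------------------------------------------------------------------
// A minimal standalone sketch (not taken from the repository) of the pairing rule that the
// `is_tma_multicast_valid` change above encodes, assuming a 2-CTA cluster in which block
// `m_block_idx` is paired with block `m_block_idx ^ 1`. The names `can_multicast_b`,
// `grouped_layout_host` and `block_m` are illustrative only.
#include <cstdint>

// `grouped_layout_host[row]` gives the group index owning that row of the contiguous M layout.
// Multicasting B across the cluster is valid only when both paired M-blocks start in the same
// group, so a single multicast load of B serves both CTAs; otherwise each CTA issues its own
// non-multicast TMA load, matching the `else` branch in the kernel change above.
static bool can_multicast_b(const int* grouped_layout_host, uint32_t m_block_idx, uint32_t block_m) {
    const int group_idx      = grouped_layout_host[m_block_idx * block_m];
    const int peer_group_idx = grouped_layout_host[(m_block_idx ^ 1) * block_m];
    return group_idx == peer_group_idx;
}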