From f4014953ada2d2e1db61a935258ec842a103e6d6 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 22 Apr 2025 17:24:02 +0800 Subject: [PATCH] Several code lints x2 --- README.md | 2 +- deep_gemm/include/deep_gemm/scheduler.cuh | 3 +++ deep_gemm/jit_kernels/gemm.py | 14 +++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index ccc46e0..5d925da 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] Shared memory swizzling for output - [ ] Larger block size on N (up to 256) - [x] MoE scheduler with TMA multicast compatibility -- [ ] Fix TMA multicast compatibility for indivisible shapes +- [x] Fix TMA multicast compatibility for indivisible shapes - [ ] Weight gradient kernels for dense models - [ ] Weight gradient kernels for MoE models - [ ] Utility kernels for MoE models (as a pre-built CUDA library) diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 457eaff..6a0b9dc 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -83,6 +83,8 @@ struct Scheduler { auto first_block_idx = group_idx * kNum1DBlocksPerGroup; auto in_group_idx = block_idx % num_blocks_per_group; num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) { if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { num_blocks_in_group = num_blocks_in_group ^ 1; @@ -93,6 +95,7 @@ struct Scheduler { } } + // Convert to final M/N block indices if constexpr (kIsTMAMulticastOnA) { m_block_idx = in_group_idx / num_blocks_in_group; n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index cb75432..cb438b7 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr, """ -def is_tma_multicast_legal(shape_dim: int, multicast_block_dim: int, num_tma_multicast: int, num_sms: int) -> bool: - if num_tma_multicast == 1: - return True - return shape_dim % multicast_block_dim == 0 and num_sms % num_tma_multicast == 0 +def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, + require_divisible: bool = False) -> bool: + divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible + return divisible and num_sms % num_tma_multicast == 0 def get_swizzle_mode(block_n: int) -> int: @@ -146,10 +146,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, best_tma_multicast_config = (1, True) # Try to multicast on the larger block side first - # NOTES: Currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the n-direction to be even. + # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even is_multicast_legal = { - 'A': is_tma_multicast_legal(n, best_block_n * (2 if is_grouped_masked else 1), 2, num_sms), - 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked), + 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, } for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): if m >= 512 and is_multicast_legal[i]: