From 7ffb118e5453898ef8daf6a9acf37be575db5156 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Tue, 25 Mar 2025 14:56:42 +0800
Subject: [PATCH] Support multicasting on B

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 30 ++++++++++--------
 deep_gemm/include/deep_gemm/scheduler.cuh | 32 +++++++++++++------
 deep_gemm/jit/template.py                 |  5 ++-
 deep_gemm/jit_kernels/gemm.py             | 38 ++++++++++++++---------
 4 files changed, 67 insertions(+), 38 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index a041d40..065bd88 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -31,7 +31,7 @@
 template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
-          uint32_t kNumTMAMulticast,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
 __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
 fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
@@ -146,7 +146,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
-    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
+    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
 
     if (threadIdx.x >= kNumMathThreads) {
         // TMA warp-group for loading data
@@ -161,6 +161,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
         DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
+        // Assign TMA multicast number into A and B
+        constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
+        constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+
         // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
         // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
         #pragma unroll
@@ -168,18 +172,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 // Wait consumer release
                 empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
 
-                // Issue TMA A with broadcasting
+                // Issue TMA A
                 auto& full_barrier = *full_barriers[s];
                 int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-                tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                           smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                           smem_scales_a[s], m_block_idx * BLOCK_M,
-                                           scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
+                tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                              smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+                tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                              smem_scales_a[s], m_block_idx * BLOCK_M,
+                                              scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
 
-                // Issue TMA B without broadcasting
-                tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                         smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
+                // Issue TMA B
+                tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                                              smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
 
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
             }
@@ -347,7 +351,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
 template <uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
-          uint32_t kNumTMAMulticast,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
 class Gemm {
 private:
@@ -369,7 +373,7 @@ public:
         constexpr uint32_t kNumMathThreadsPerGroup = 128;
         auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                                       kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
-                                      kNumTMAMulticast, kGemmType>;
+                                      kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
         DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
 
         // Cluster launch
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index c339b53..6e3cb52 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -12,9 +12,10 @@ enum class GemmType {
 #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
 template <GemmType kGemmType,
           uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
           uint32_t kNumGroups, uint32_t kNumTMAMulticast,
+          bool kIsTMAMulticastOnA,
           uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
-          uint32_t kNumNBlocksPerGroup = 16>
+          uint32_t kNum1DBlocksPerGroup = 16>
 struct Scheduler {
     int current_iter = -1;
     uint32_t num_aligned_m_blocks;
@@ -43,16 +44,27 @@ struct Scheduler {
     }
 
     __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
-        DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
+        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
         // Swizzle for better L2 usages
-        auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
-        auto group_idx = block_idx / num_blocks_per_group;
-        auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
-        auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-        auto in_group_idx = block_idx % num_blocks_per_group;
-        m_block_idx = in_group_idx / num_n_blocks_in_group;
-        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+        // TODO: unify these 2 branches
+        if constexpr (kIsTMAMulticastOnA) {
+            auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
+            auto group_idx = block_idx / num_blocks_per_group;
+            auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
+            auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
+            auto in_group_idx = block_idx % num_blocks_per_group;
+            m_block_idx = in_group_idx / num_n_blocks_in_group;
+            n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+        } else {
+            auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
+            auto group_idx = block_idx / num_blocks_per_group;
+            auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
+            auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
+            auto in_group_idx = block_idx % num_blocks_per_group;
+            m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
+            n_block_idx = in_group_idx / num_m_blocks_in_group;
+        }
     }
 
     template <bool kIgnoreGroupedForGroupedContiguous = true>
diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
index cdca4c4..ead37f5 100644
--- a/deep_gemm/jit/template.py
+++ b/deep_gemm/jit/template.py
@@ -67,7 +67,10 @@ def cpp_format(template: str, keys: Dict[str, Any]) -> str:
     # We don't use `str.format` because it's not safe for C++ {} braces
     new_template = copy.deepcopy(template)
     for key, value in keys.items():
-        new_template = new_template.replace(f'{{{key}}}', f'{value}')
+        value_str = str(value)
+        if isinstance(value, bool):
+            value_str = value_str.lower()
+        new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
     return new_template
 
 
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 6852d5e..d0fd8f5 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -15,9 +15,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
 
 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -31,10 +32,10 @@ GemmType::run(out, rhs_scales, nullptr,
 """
 
 
-def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
+def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
     if num_tma_multicast == 1:
         return True
-    return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
+    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
 
 
 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
@@ -56,7 +57,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
 
 
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
-                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int, int]:
+                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
         # TODO: for some cases, smaller M block is better, add them into tuning space
         block_ms = (64 if m <= 64 else 128, )
@@ -96,20 +97,27 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
             break
     assert best_num_stages is not None
 
-    # Decide the number of TMA multicast
-    best_num_tma_multicast = 1
-    # When using large block tiling, broadcasting B is required to achieve maximum performance gains.
-    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
-        best_num_tma_multicast = 2
+    # Decide the number of TMA multicast and whether broadcast on A
+    best_tma_multicast_config = (1, True)
+
+    # Try to multicast on the larger block side first
+    is_multicast_legal = {
+        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms),
+    }
+    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
+        if m >= 1024 and is_multicast_legal[i] and num_groups == 1:
+            best_tma_multicast_config = (2, i == 'A')
+            break
 
     # Recompute the minimal number of SMs required
     # NOTES: less L2 cache usage and less GPU frequency drop
     num_waves = get_num_waves(best_block_m, best_block_n)
     num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
-    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
-    assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
+    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0]
+    assert num_min_sms <= num_sms
 
-    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
+    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
 
 
 def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -159,12 +167,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms)
     args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
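
Note (not part of the patch): the selection heuristic added to get_best_configs can be read in isolation as the minimal Python sketch below. is_tma_multicast_legal mirrors the patched helper; choose_tma_multicast_config and the example shapes are hypothetical names and values used only for illustration. The returned tuple is (num_tma_multicast, is_multicast_on_A), matching the new best_tma_multicast_config.

def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
    # A 2-way multicast needs the dimension to split evenly into paired blocks
    # and the SM count to be divisible by the cluster size.
    if num_tma_multicast == 1:
        return True
    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0


def choose_tma_multicast_config(m: int, n: int, block_m: int, block_n: int,
                                num_sms: int, num_groups: int = 1) -> tuple:
    # Default: no multicast; the bool means "multicast on A" once multicast is enabled.
    config = (1, True)
    legal = {
        'A': is_tma_multicast_legal(n, block_n, 2, num_sms),  # multicasting A pairs CTAs along N
        'B': is_tma_multicast_legal(m, block_m, 2, num_sms),  # multicasting B pairs CTAs along M
    }
    # Try the side with the larger block first, as in the patched heuristic
    for side in (('A', 'B') if block_m > block_n else ('B', 'A')):
        if m >= 1024 and legal[side] and num_groups == 1:
            config = (2, side == 'A')
            break
    return config


if __name__ == '__main__':
    # A tall GEMM with a single N block cannot multicast on A, but can now multicast on B
    print(choose_tma_multicast_config(m=4096, n=128, block_m=128, block_n=128, num_sms=132))  # -> (2, False)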