diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 61a80ca..038ff3e 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -192,8 +192,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
     // Assign TMA multicast number into A and B
-    constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
-    constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+    // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
+    const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
+    const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+    const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
     DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
 
     // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
@@ -203,35 +205,21 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 // Wait consumer release
                 empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
 
-                // NOTES: There may be additional odd rows/columns or cases where multicast is not possible.
-                // In grouped contiguous GEMM, different m_block_idx values can also lead to the inability to multicast.
-                // We use is_tma_multicast_valid to determine whether multicast is possible.
-                // Issue TMA A
+                // Issue TMA A
                 auto& full_barrier = *full_barriers[s];
                 int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-                if (kNumTMAMulticastOnA > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
-                    tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                  smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                    tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                  smem_scales_a[s], m_block_idx * BLOCK_M,
-                                                  scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
-                }
-                else {
-                    tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                             smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                    tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                             smem_scales_a[s], m_block_idx * BLOCK_M,
-                             scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
-                }
+                tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
+                         num_tma_multicast_a);
+                tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_scales_a[s], m_block_idx * BLOCK_M,
+                         scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
+                         num_tma_multicast_a);
 
                 // Issue TMA B
-                if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
-                    tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                  smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                } else {
-                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                             smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                }
+                tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
+                         num_tma_multicast_b);
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
             }
 
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index b24243c..457eaff 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -23,12 +23,13 @@ struct Scheduler {
     // For normal GEMM
     // Maybe not used in the masked grouped GEMM
     uint32_t num_blocks;
+    uint32_t num_blocks_in_group;
 
     // For grouped GEMM
     int* grouped_layout;
 
+    // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;
-    int num_blocks_in_group;
 
     __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                   int* grouped_layout = nullptr) {
@@ -85,18 +86,17 @@ struct Scheduler {
         if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
             if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
                 num_blocks_in_group = num_blocks_in_group ^ 1;
-            }
-            else {
+            } else {
                 in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
                 first_block_idx += num_blocks_in_group ^ 1;
                 num_blocks_in_group = 1;
             }
         }
+
         if constexpr (kIsTMAMulticastOnA) {
             m_block_idx = in_group_idx / num_blocks_in_group;
             n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
-        }
-        else {
+        } else {
             m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
             n_block_idx = in_group_idx / num_blocks_in_group;
         }
diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh
index 22731a6..18cdb58 100644
--- a/deep_gemm/include/deep_gemm/tma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/tma_utils.cuh
@@ -80,15 +80,14 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }
 
-template <uint32_t kNumTMAMulticast = 1>
 __device__ __forceinline__ void
 tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
-         int32_t const& crd_0, int32_t const& crd_1) {
+         int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
     constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
-    if constexpr (kNumTMAMulticast == 1) {
+    if (num_tma_multicast == 1) {
         cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
     } else if (cute::block_rank_in_cluster() == 0) {
-        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
+        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
     }
 }
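Not part of the patch, but a minimal host-side sketch of the behaviour it introduces: the multicast count is now chosen at runtime from `kIsTMAMulticastOnA` and `scheduler.is_tma_multicast_valid(...)`, and `tma_copy` turns that count into the CTA mask `(1 << num_tma_multicast) - 1` passed to `SM90_TMA_LOAD_MULTICAST_2D`. The helpers `pick_num_tma_multicast` and `multicast_mask` below are illustrative names, not DeepGEMM APIs.

```cpp
// Illustrative sketch only -- mirrors the runtime multicast selection above.
#include <cstdint>
#include <cstdio>

// Same ternary as num_tma_multicast_a / num_tma_multicast_b in fp8_gemm.cuh:
// multicast only on the chosen operand, and only when the scheduler reports
// that pairing the two blocks is valid; otherwise fall back to a plain load.
constexpr uint32_t pick_num_tma_multicast(bool on_this_operand, bool is_valid,
                                          uint32_t num_tma_multicast) {
    return (on_this_operand && is_valid) ? num_tma_multicast : 1;
}

// Same expression as in tma_copy(): a bitmask of the lowest
// `num_tma_multicast` CTA ranks in the cluster (0x1 for 1 CTA, 0x3 for 2).
constexpr uint32_t multicast_mask(uint32_t num_tma_multicast) {
    return (1u << num_tma_multicast) - 1;
}

int main() {
    const uint32_t kNumTMAMulticast = 2;   // 2-CTA cluster
    const bool valid_cases[] = {true, false};
    for (bool valid : valid_cases) {
        const uint32_t a = pick_num_tma_multicast(true,  valid, kNumTMAMulticast);  // multicast on A
        const uint32_t b = pick_num_tma_multicast(false, valid, kNumTMAMulticast);  // B stays unicast
        std::printf("valid=%d  A: count=%u mask=0x%x  B: count=%u mask=0x%x\n",
                    valid ? 1 : 0, a, multicast_mask(a), b, multicast_mask(b));
    }
    return 0;
}
```

Compiled as plain C++, this prints a mask of 0x3 for the multicast operand only when the scheduler allows pairing; otherwise the count drops to 1, matching the kernel's fallback to the non-multicast `SM90_TMA_LOAD_2D` path.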