mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Support TMA multicast on B with m_grouped_gemm_contiguous.
This commit is contained in:
@@ -212,8 +212,20 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
if constexpr (kNumTMAMulticastOnB == kNumTMAMulticast) {
|
||||
// NOTES: In grouped contiguous GEMM, different m_block_idx values may correspond to blocks of different groups (b),
|
||||
// requiring additional checks before multicast operations.
|
||||
if (scheduler.is_tma_multicast_valid(m_block_idx))
|
||||
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
else
|
||||
tma_copy<1>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
}
|
||||
else {
|
||||
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
}
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,18 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
// Reports whether a TMA multicast may be issued for the block at `m_block_idx`.
//
//  - Normal GEMM: always true.
//  - Grouped contiguous GEMM: true only when this block and its partner block
//    (`m_block_idx ^ 1`) read the same value from `grouped_layout`; blocks with
//    different values must not share a multicast load.
//  - Any other GEMM type (including grouped masked): false.
//
// The index is taken by const reference (it is only read), matching the
// parameter style of `get_global_idx`.
//
// NOTE(review): `m_block_idx ^ 1` pairs blocks two-by-two, which presumably
// assumes a multicast cluster of exactly 2 blocks - confirm against the launch
// configuration.
// NOTE(review): the partner lookup is not bounds-checked here; presumably the
// caller guarantees the partner block exists - confirm.
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
    if constexpr (kGemmType == GemmType::Normal) {
        return true;
    } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
        // Value stored in `grouped_layout` for the first row of this block and of its partner.
        const auto expert = __ldg(grouped_layout + m_block_idx * BLOCK_M);
        const auto cluster_partner_expert = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
        return expert == cluster_partner_expert;
    } else {
        // Terminal branch so every instantiation returns a value
        // (grouped masked GEMM lands here and never multicasts).
        return false;
    }
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
@@ -72,10 +84,10 @@ struct Scheduler {
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
|
||||
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,10 +146,9 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
best_tma_multicast_config = (1, True)
|
||||
|
||||
# Try to multicast on the larger block side first
|
||||
is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked)
|
||||
is_multicast_legal = {
|
||||
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm,
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
|
||||
}
|
||||
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
|
||||
if m >= 512 and is_multicast_legal[i]:
|
||||
|
||||
Reference in New Issue
Block a user