Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-05 12:04:22 +00:00
Indivisible TMA (#90)
Fix indivisible shapes for TMA multicast.

Co-authored-by: yukuai <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Parent: 891f35adf5
Commit: 95e81b3dd6
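The change makes TMA multicast usable for shapes whose block count is not divisible by the multicast width: instead of fixing the multicast degree at compile time, the scheduler now decides per block whether a paired (peer) block exists, and the odd leftover block falls back to a degree of 1. Below is a rough host-side model of that per-block decision in plain Python (hypothetical helper, not repository code), assuming a multicast width of 2.

# Hypothetical model of the per-block multicast decision (not repository code).
def multicast_degree_per_block(num_blocks_in_group: int, max_degree: int = 2) -> list:
    """Blocks are paired for multicast; an odd trailing block runs with degree 1."""
    degrees = []
    for idx in range(num_blocks_in_group):
        paired = (num_blocks_in_group % max_degree == 0) or (idx < num_blocks_in_group - 1)
        degrees.append(max_degree if paired and num_blocks_in_group > 1 else 1)
    return degrees

# Example: a swizzle group with 5 blocks -> the first 4 multicast in pairs, the last runs alone.
assert multicast_degree_per_block(5) == [2, 2, 2, 2, 1]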
@@ -16,7 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
 - [x] MoE scheduler with TMA multicast compatibility
-- [ ] Fix TMA multicast compatibility for indivisible shapes
+- [x] Fix TMA multicast compatibility for indivisible shapes
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)
@@ -192,8 +192,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
     // Assign TMA multicast number into A and B
-    constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
-    constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+    // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
+    const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
+    const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+    const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+    DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
 
     // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
     // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
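With the `kNumTMAMulticastOnA`/`kNumTMAMulticastOnB` constants gone, the multicast degree for A and B is picked at run time from the compile-time side selection plus the per-block validity flag. A hypothetical Python mirror of the two `const uint32_t` lines above (names are illustrative only, not repository code):

# Hypothetical mirror of the num_tma_multicast_a / num_tma_multicast_b selection (not repository code).
def select_multicast_degrees(multicast_on_a: bool, is_tma_multicast_valid: bool,
                             num_tma_multicast: int = 2) -> tuple:
    a = num_tma_multicast if (multicast_on_a and is_tma_multicast_valid) else 1
    b = num_tma_multicast if (not multicast_on_a and is_tma_multicast_valid) else 1
    return a, b

# Multicast configured on B, but the current block has no valid peer: both sides fall back to 1.
assert select_multicast_degrees(multicast_on_a=False, is_tma_multicast_valid=False) == (1, 1)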
@@ -205,23 +208,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         // Issue TMA A
         auto& full_barrier = *full_barriers[s];
         int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-        tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                 smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-        tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                 smem_scales_a[s], m_block_idx * BLOCK_M,
-                 scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
+        tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                 smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
+                 num_tma_multicast_a);
+        tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                 smem_scales_a[s], m_block_idx * BLOCK_M,
+                 scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
+                 num_tma_multicast_a);
 
         // Issue TMA B
-        if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
-            // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
-            // requiring additional checks before multicast operations.
-            DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
-            tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                     smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-        } else {
-            tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                     smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-        }
+        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                 smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
+                 num_tma_multicast_b);
         full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
     }
@@ -281,7 +279,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
        if constexpr (kNumTMAMulticast == 1) {
            lane_idx == 0 ? empty_barriers[s]->arrive() : void();
        } else {
-           lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
+           auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
+           lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
        }
    };
@@ -23,9 +23,12 @@ struct Scheduler {
     // For normal GEMM
     // Maybe not used in the masked grouped GEMM
     uint32_t num_blocks;
+    uint32_t num_blocks_in_group;
+    bool is_peer_cta_alive = true;
 
     // For grouped GEMM
     int* grouped_layout;
 
     // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;
@@ -43,15 +46,20 @@ struct Scheduler {
        }
    }
 
-   __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
-       if constexpr (kGemmType == GemmType::Normal) {
-           return true;
-       } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-           auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
-           auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
-           return group_idx == peer_group_idx;
-       } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+   __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
+       if (num_blocks_in_group == 1)
            return false;
-       }
+       if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
+           return true;
+       } else {
+           DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
+           if constexpr (kIsTMAMulticastOnA) {
+               return true;
+           } else {
+               auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+               auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+               return group_idx == peer_group_idx;
+           }
+       }
    }
 
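For grouped contiguous GEMM, B can only be multicast when the block and its peer (index `m_block_idx ^ 1`) belong to the same expert group, since the paired CTAs must load the same B tile. A hypothetical Python rendering of that check, assuming `grouped_layout` maps each row to its group index (not repository code):

# Hypothetical model of the grouped-contiguous validity check (not repository code).
def grouped_contiguous_multicast_valid(grouped_layout: list, m_block_idx: int, block_m: int) -> bool:
    group_idx = grouped_layout[m_block_idx * block_m]
    peer_group_idx = grouped_layout[(m_block_idx ^ 1) * block_m]
    return group_idx == peer_group_idx

# Two M-blocks of 2 rows each: rows 0-1 are group 0, rows 2-3 are group 1 -> no multicast for this pair.
assert grouped_contiguous_multicast_valid([0, 0, 1, 1], m_block_idx=0, block_m=2) is False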
@@ -59,23 +67,32 @@ struct Scheduler {
        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
        // Swizzle for better L2 usages
+       // TODO: unify these 2 branches
+       auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
+       auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
+       auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
+       auto group_idx = block_idx / num_blocks_per_group;
+       auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
+       auto in_group_idx = block_idx % num_blocks_per_group;
+       num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
+
+       // Fix unaligned TMA multicast
+       if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
+           if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
+               num_blocks_in_group = num_blocks_in_group ^ 1;
+           } else {
+               in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
+               first_block_idx += num_blocks_in_group ^ 1;
+               num_blocks_in_group = 1;
+           }
+       }
+
+       // Convert to final M/N block indices
        if constexpr (kIsTMAMulticastOnA) {
-           auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
-           auto group_idx = block_idx / num_blocks_per_group;
-           auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
-           auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-           auto in_group_idx = block_idx % num_blocks_per_group;
-           m_block_idx = in_group_idx / num_n_blocks_in_group;
-           n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+           m_block_idx = in_group_idx / num_blocks_in_group;
+           n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
        } else {
-           auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
-           auto group_idx = block_idx / num_blocks_per_group;
-           auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
-           auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
-           auto in_group_idx = block_idx % num_blocks_per_group;
-           m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
-           n_block_idx = in_group_idx / num_m_blocks_in_group;
+           m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
+           n_block_idx = in_group_idx / num_blocks_in_group;
        }
    }
 
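The heart of the fix is the "Fix unaligned TMA multicast" branch above: a swizzle group with an odd number of blocks along the primary direction is split into an even head, which still multicasts in pairs, and a single trailing block, which does not. A small Python model of that split, assuming a multicast width of 2 (hypothetical function, not repository code); note that `n ^ 1 == n - 1` when `n` is odd:

# Hypothetical model of the odd-group split in get_swizzled_block_idx (not repository code).
def split_odd_group(in_group_idx: int, num_blocks_in_group: int, secondary_num_blocks: int,
                    first_block_idx: int):
    """Return (in_group_idx, first_block_idx, num_blocks_in_group) after the unaligned fix."""
    if num_blocks_in_group % 2 != 0:
        even_part = (num_blocks_in_group ^ 1) * secondary_num_blocks
        if in_group_idx < even_part:
            num_blocks_in_group ^= 1                      # stay in the even head
        else:
            in_group_idx -= even_part                     # move into the 1-block tail
            first_block_idx += num_blocks_in_group ^ 1
            num_blocks_in_group = 1
    return in_group_idx, first_block_idx, num_blocks_in_group

# Group of 3 primary blocks and 4 secondary blocks: indices 0-7 stay in the even head (2 blocks,
# multicast still possible), indices 8-11 land in the single trailing block.
assert split_odd_group(5, 3, 4, first_block_idx=0) == (5, 0, 2)
assert split_odd_group(9, 3, 4, first_block_idx=0) == (1, 2, 1)

The trailing block ends up with `num_blocks_in_group == 1`, which is exactly the condition `is_tma_multicast_valid` uses to disable multicast for it.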
@@ -102,7 +119,7 @@ struct Scheduler {
            if (curr_group_idx == kNumGroups)
                return false;
 
-           // Within current group
+           // Within the current group
            num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
@@ -117,6 +134,10 @@ struct Scheduler {
            if (next_block_idx >= num_blocks)
                return false;
 
+           // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
+           is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or          // Always aligned on N (constant bypass)
+                               num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
+                               (next_block_idx ^ 1) < num_blocks;              // Peer CTA in bound
            get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
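`is_peer_cta_alive` exists because, with an odd total block count, the CTA that picks up the last block has no live peer in its cluster; the empty-barrier arrival shown in an earlier hunk then targets the CTA's own rank instead. A hypothetical Python rendering of the condition (not repository code):

# Hypothetical model of the is_peer_cta_alive condition (not repository code).
def peer_cta_alive(num_n_blocks: int, num_aligned_m_blocks: int, next_block_idx: int,
                   num_blocks: int, num_tma_multicast: int = 2) -> bool:
    return (num_n_blocks % num_tma_multicast == 0             # always aligned on N
            or num_aligned_m_blocks % num_tma_multicast == 0  # always aligned on M
            or (next_block_idx ^ 1) < num_blocks)             # peer block index is in bounds

# 3x3 blocks with multicast 2: the very last block (index 8) has no in-bound peer (8 ^ 1 == 9).
assert peer_cta_alive(3, 3, next_block_idx=8, num_blocks=9) is False
assert peer_cta_alive(3, 3, next_block_idx=6, num_blocks=9) is True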
@@ -80,15 +80,14 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
    return tensor_map;
}
 
-template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
-        int32_t const& crd_0, int32_t const& crd_1) {
+        int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
    constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
-   if constexpr (kNumTMAMulticast == 1) {
+   if (num_tma_multicast == 1) {
        cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
    } else if (cute::block_rank_in_cluster() == 0) {
-       cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
+       cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
    }
}
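Since `tma_copy` now takes the multicast count as a runtime argument, a single compiled kernel can issue either a plain or a multicast TMA load per stage; the CTA mask handed to the multicast load is simply the low `num_tma_multicast` bits. A tiny illustration (hypothetical helper, not repository code):

# Hypothetical illustration of the multicast CTA mask (not repository code).
def multicast_cta_mask(num_tma_multicast: int) -> int:
    # (1 << n) - 1 sets the low n bits, i.e. CTAs 0..n-1 in the cluster receive the copy.
    return (1 << num_tma_multicast) - 1

assert multicast_cta_mask(1) == 0b1
assert multicast_cta_mask(2) == 0b11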
@@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr,
 """
 
 
-def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
-    if num_tma_multicast == 1:
-        return True
-    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
+def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
+                           require_divisible: bool = False) -> bool:
+    divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
+    return divisible and num_sms % num_tma_multicast == 0
 
 
 def get_swizzle_mode(block_n: int) -> int:
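The relaxed host-side check accepts shapes whose block count is not a multiple of the multicast width unless the caller explicitly requires divisibility, which grouped masked GEMM does. A self-contained sketch of the new rule with an indivisible example; the inline `ceil_div` helper and the 132-SM figure are illustrative assumptions, not repository code:

# Hypothetical stand-alone rendering of the relaxed check (not repository code).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
                           require_divisible: bool = False) -> bool:
    divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
    return divisible and num_sms % num_tma_multicast == 0

# n = 4224 with BLOCK_N = 128 gives 33 N-blocks: indivisible by 2, yet now legal
# unless divisibility is explicitly required (e.g. for grouped masked GEMM).
assert is_tma_multicast_legal(4224, 128, 2, 132) is True
assert is_tma_multicast_legal(4224, 128, 2, 132, require_divisible=True) is False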
@@ -146,9 +146,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     best_tma_multicast_config = (1, True)
 
     # Try to multicast on the larger block side first
+    # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
     is_multicast_legal = {
-        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
-        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
+        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
     }
     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
         if m >= 512 and is_multicast_legal[i]:
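The config picker itself is unchanged apart from the legality inputs: it still prefers multicasting along the larger block side and only for sufficiently large M. A hypothetical sketch of that ordering (not repository code):

# Hypothetical sketch of the candidate ordering used above (not repository code).
def multicast_candidates(is_multicast_legal: dict, m: int,
                         best_block_m: int, best_block_n: int) -> list:
    # Prefer multicasting along the larger block side; require a reasonably large M.
    order = ('A', 'B') if best_block_m > best_block_n else ('B', 'A')
    return [i for i in order if m >= 512 and is_multicast_legal[i]]

assert multicast_candidates({'A': True, 'B': False}, m=4096,
                            best_block_m=128, best_block_n=64) == ['A']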