mirror of https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-06 13:34:22 +00:00

Code polishing

parent 2752a67aad
commit a51629ddf9
@@ -12,10 +12,10 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert

 ## Roadmap

-- [ ] More correctness tests for grouped-contiguous layout
+- [x] More correctness tests for grouped-contiguous layout
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
-- [ ] MoE scheduler with TMA multicast compatibility
+- [x] MoE scheduler with TMA multicast compatibility
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)

@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,

     // Prefetch TMA descriptors at very beginning
     if (threadIdx.x == kNumMathThreads) {
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
+        cute::prefetch_tma_descriptor(&tensor_map_a);
+        cute::prefetch_tma_descriptor(&tensor_map_b);
+        cute::prefetch_tma_descriptor(&tensor_map_scales_a);

         // `tensor_map_d` is only used in swizzling mode
         // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
         if constexpr (kSwizzleDMode > 0)
-            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
+            cute::prefetch_tma_descriptor(&tensor_map_d);
     }
     __syncwarp();

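For context, a minimal standalone sketch (hypothetical kernel, not part of the repository) of the prefetch pattern the hunk above simplifies: one thread of the producer (TMA) warp prefetches the grid-constant tensor-map descriptors before the main loop, so later TMA issues do not stall on a cold descriptor fetch. It assumes an SM90 target and CUDA 12+, where `cute::TmaDescriptor` is an alias of `CUtensorMap`, which is why the `reinterpret_cast` can be dropped.

```cuda
// Hypothetical demo kernel, not from this commit.
#include <cstdint>
#include <cuda.h>                        // CUtensorMap
#include <cute/arch/copy_sm90_desc.hpp>  // cute::prefetch_tma_descriptor

template <uint32_t kNumMathThreads>
__global__ void prefetch_demo_kernel(const __grid_constant__ CUtensorMap tensor_map_a,
                                     const __grid_constant__ CUtensorMap tensor_map_b) {
    // Only the leading thread of the TMA (producer) warp touches the descriptors;
    // with CUDA 12+, `cute::TmaDescriptor` aliases `CUtensorMap`, so `&tensor_map_a`
    // is already the expected pointer type.
    if (threadIdx.x == kNumMathThreads) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
    }
    __syncwarp();

    // ... a pipelined main loop issuing TMA loads would follow here ...
}
```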
@@ -212,19 +212,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                              scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));

                 // Issue TMA B
-                if constexpr (kNumTMAMulticastOnB == kNumTMAMulticast) {
-                    // NOTES: In grouped contiguous GEMM, different m_block_idx values may correspond to blocks of different groups (b),
+                if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
+                    // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
                     // requiring additional checks before multicast operations.
-                    if (scheduler.is_tma_multicast_valid(m_block_idx))
-                        tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                      smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                    else
-                        tma_copy<1>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                    smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                } else {
+                    DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
                     tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
+                } else {
+                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                             smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                 }
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
             }

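The rewritten branch replaces the compile-time `if constexpr` with a runtime check: the B tile is multicast across the cluster only when multicast on B is enabled (`kNumTMAMulticastOnB > 1`) and the scheduler confirms the paired block maps to the same group; otherwise a plain single-CTA `tma_copy` is issued. A tiny hypothetical helper (illustrative name, not in the repository) spelling out that dispatch rule:

```cuda
// Hypothetical helper, not from this commit: how many CTAs receive the B tile.
#include <cstdint>

template <uint32_t kNumTMAMulticastOnB>
__device__ __forceinline__ uint32_t num_multicast_ctas_on_b(bool multicast_valid) {
    // Mirrors the DG_STATIC_ASSERT in the kernel: at most 2 CTAs are paired here
    static_assert(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
    // Multicast only when enabled at compile time AND valid for this block pair;
    // otherwise the load degenerates to a single-CTA TMA copy.
    return (kNumTMAMulticastOnB > 1 and multicast_valid) ? kNumTMAMulticastOnB : 1u;
}
```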
@@ -47,9 +47,9 @@ struct Scheduler {
         if constexpr (kGemmType == GemmType::Normal) {
             return true;
         } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-            auto expert = __ldg(grouped_layout + m_block_idx * BLOCK_M);
-            auto cluster_partner_expert = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
-            return (expert == cluster_partner_expert);
+            auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+            auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+            return group_idx == peer_group_idx;
         } else if constexpr (kGemmType == GemmType::GroupedMasked) {
             return false;
         }
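The rename makes the intent of the check clearer: in grouped-contiguous GEMM, the two M-blocks paired inside a cluster (`m_block_idx` and `m_block_idx ^ 1`) may share a multicast load only if the first row of each block carries the same group index in `grouped_layout`. A hypothetical host-side mirror of that rule (not part of the commit), e.g. for unit-testing a scheduler:

```cuda
// Hypothetical reference check, not in the repository.
bool is_tma_multicast_valid_ref(const int* grouped_layout, int m_block_idx, int block_m) {
    // Group index of the first row of this block and of its cluster partner
    const int group_idx      = grouped_layout[m_block_idx * block_m];
    const int peer_group_idx = grouped_layout[(m_block_idx ^ 1) * block_m];
    // Multicast is valid only when both blocks belong to the same group
    return group_idx == peer_group_idx;
}
```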