Code polishing x3

Chenggang Zhao 2025-04-21 09:41:13 +08:00
parent 594020e6af
commit c99756778d
3 changed files with 3 additions and 2 deletions


@@ -16,6 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] Shared memory swizzling for output
- [ ] Larger block size on N (up to 256)
- [x] MoE scheduler with TMA multicast compatibility
+ - [ ] Fix TMA multicast compatibility for indivisible shapes
- [ ] Weight gradient kernels for dense models
- [ ] Weight gradient kernels for MoE models
- [ ] Utility kernels for MoE models (as a pre-built CUDA library)


@@ -212,7 +212,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B
- if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
+ if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
// NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
// requiring additional checks before multicast operations.
DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
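The renamed check gates the multicast path at its call site. Below is a minimal sketch of that guard pattern, assuming a 2-CTA cluster; the two `issue_*` helpers are hypothetical stand-ins for the kernel's real TMA calls, and only `is_tma_multicast_b_valid` comes from this diff:

```cuda
#include <cstdint>

// Hypothetical stand-ins for the kernel's actual TMA issue paths
// (assumptions for illustration, not part of this commit).
__device__ void issue_tma_multicast_b(uint32_t k_idx) { /* cluster-multicast bulk copy */ }
__device__ void issue_tma_b(uint32_t k_idx)           { /* plain per-CTA bulk copy */ }

// Guard pattern: multicast B only when the scheduler confirms both cluster
// peers need the same B tile; otherwise fall back to an independent load.
template <uint32_t kNumTMAMulticastOnB, typename Scheduler>
__device__ __forceinline__ void load_b_tile(Scheduler& scheduler,
                                            uint32_t m_block_idx, uint32_t k_idx) {
    if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
        issue_tma_multicast_b(k_idx);  // one load serves both CTAs in the cluster
    } else {
        issue_tma_b(k_idx);  // peers map to different groups (different B): load separately
    }
}
```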


@@ -43,7 +43,7 @@ struct Scheduler {
}
}
- __device__ __forceinline__ bool is_tma_multicast_valid(uint32_t& m_block_idx) {
+ __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
if constexpr (kGemmType == GemmType::Normal) {
return true;
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
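The hunk cuts off before the `GroupedContiguous` branch, but the NOTES comment in the kernel states what it must establish: multicasting B is only safe when both multicast peers map to the same group. A minimal sketch of such a check follows, assuming the peer of an m-block is the adjacent block (2-CTA cluster) and that `grouped_layout[row]` stores each row's group index; the name and signature here are illustrative, not the repository's:

```cuda
#include <cstdint>

// Illustrative check under the assumed layout: two adjacent m-blocks may share
// a multicast B tile only if the rows they cover belong to the same group,
// since each group multiplies against its own B matrix.
__device__ __forceinline__ bool same_group_as_peer(const int* grouped_layout,
                                                   uint32_t m_block_idx,
                                                   uint32_t num_m_blocks,
                                                   uint32_t block_m) {
    // Clamp so the last block (which has no right-hand peer) compares with itself.
    uint32_t peer_idx = min(m_block_idx + 1u, num_m_blocks - 1u);
    int group      = __ldg(grouped_layout + m_block_idx * block_m);
    int peer_group = __ldg(grouped_layout + peer_idx * block_m);
    return group == peer_group;  // false forbids multicast for this block pair
}
```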