Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)

commit c99756778d (parent 594020e6af)
Code polishing x3
@@ -16,6 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
 - [x] MoE scheduler with TMA multicast compatibility
 - [ ] Fix TMA multicast compatibility for indivisible shapes
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)
@@ -212,7 +212,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));

             // Issue TMA B
-            if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
+            if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
                 // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
                 // requiring additional checks before multicast operations.
                 DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
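The guarded branch above is the call-site half of this rename: the kernel only issues a multicast TMA load for B when the scheduler confirms both CTAs in the cluster can share the same tile. Below is a minimal CUDA sketch of that dispatch, not DeepGEMM's actual code; `tma_load_b` and `tma_load_b_multicast` are hypothetical stand-ins for the kernel's real TMA copy helpers.

#include <cstdint>

// Hypothetical stand-ins for the kernel's real TMA copy helpers (assumption).
__device__ void tma_load_b(uint32_t k_idx, uint32_t n_block_idx);
__device__ void tma_load_b_multicast(uint32_t k_idx, uint32_t n_block_idx);

template <uint32_t kNumTMAMulticastOnB, typename Scheduler>
__device__ __forceinline__ void issue_tma_b(Scheduler& scheduler, uint32_t m_block_idx,
                                            uint32_t k_idx, uint32_t n_block_idx) {
    if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
        // Safe to broadcast: one load fills the shared memory of every CTA
        // in the cluster, cutting global-memory traffic for B.
        tma_load_b_multicast(k_idx, n_block_idx);
    } else {
        // The peer block may belong to a different group (a different B
        // matrix), so each CTA issues its own load.
        tma_load_b(k_idx, n_block_idx);
    }
}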
@@ -43,7 +43,7 @@ struct Scheduler {
         }
     }

-    __device__ __forceinline__ bool is_tma_multicast_valid(uint32_t& m_block_idx) {
+    __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
         if constexpr (kGemmType == GemmType::Normal) {
             return true;
         } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
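The hunk cuts off inside the `GroupedContiguous` branch. As a rough illustration of why that branch needs extra checks, here is a hedged sketch of a group-compatibility test; it assumes `grouped_layout` (which does appear in the kernel signature above) maps each row of A to its group index, that group boundaries are aligned to `BLOCK_M`, and that a 2-CTA cluster pairs block `m_block_idx` with `m_block_idx ^ 1`. None of these details are confirmed by this diff.

#include <cstdint>

// Sketch only: a multicast B load is valid only if the two m blocks paired
// in a 2-CTA cluster read the same group's B matrix. `grouped_layout` is
// assumed to hold one group index per row of A (assumption).
template <uint32_t BLOCK_M>
__device__ __forceinline__ bool is_tma_multicast_b_valid_sketch(
        const uint32_t& m_block_idx, const int* grouped_layout) {
    // With at most 2 CTAs per multicast (see the DG_STATIC_ASSERT in the
    // kernel hunk), the peer of block 2k is 2k+1 and vice versa: XOR with 1.
    const int group_self = __ldg(grouped_layout + m_block_idx * BLOCK_M);
    const int group_peer = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
    return group_self == group_peer;
}

Comparing only the first row of each block suffices under the stated assumption that every block's rows belong to a single group; a multicast pairing that straddles a group boundary would instead have to fall back to per-CTA loads, as in the call-site sketch above.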