This commit is contained in:
Chenggang Zhao 2025-05-27 12:00:10 +08:00
parent 81f906ef76
commit 81de208430

View File

@@ -34,7 +34,7 @@ struct Scheduler {
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;
-__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
+__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
int* grouped_layout = nullptr) {
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
if constexpr (kGemmType == GemmType::Normal) {
@@ -76,7 +76,7 @@ struct Scheduler {
}
}
-__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, uint32_t block_idx,
+__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
@@ -111,7 +111,7 @@ struct Scheduler {
}
template <bool kIgnoreGroupedForGroupedContiguous=true>
-__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
+__device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;