mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)

Grouped GEMM skip useless computation for unaligned Ms (#103)

* Grouped GEMM skip useless computation for unaligned Ms
* Update readme.md
* small typo
* Rename variables
* Restore previous indent
* Format
* Refactor tests
* Add `SkipComputation` types
* Bug fixed
* Format
* Fix tests
* Add assertions
* Minor fix

Co-authored-by: yukuai <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>

parent 391755ada0
commit 8dfa329827
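In the grouped contiguous layout each group's row count is rounded up to the M alignment, and in the masked layout each group fills only `masked_m` of its `max_m` rows; the padded rows contribute nothing to the result. This commit lets the kernel detect such rows through the grouped layout and skip their math, as the diff below shows. A minimal host-side sketch of the bookkeeping (the alignment, check granularity, and group sizes are assumptions for illustration, not DeepGEMM's API):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Assumed values for illustration only: alignment of the contiguous grouped
    // layout and the granularity at which validity is checked.
    const uint32_t alignment = 128, check_granularity = 64;
    const std::vector<uint32_t> group_ms = {300, 64, 257};   // hypothetical per-group row counts

    auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };

    // Rows of each group are packed contiguously and padded up to `alignment`;
    // padded rows get group index -1, exactly what the tests build into `m_indices`.
    std::vector<int32_t> m_indices;
    for (uint32_t g = 0; g < group_ms.size(); ++g) {
        const uint32_t aligned = ceil_div(group_ms[g], alignment) * alignment;
        for (uint32_t r = 0; r < aligned; ++r)
            m_indices.push_back(r < group_ms[g] ? static_cast<int32_t>(g) : -1);
    }

    // Any chunk made entirely of padding rows carries no useful work; this is the
    // computation the kernel can now skip.
    uint32_t skippable = 0, total = 0;
    for (uint32_t start = 0; start < m_indices.size(); start += check_granularity, ++total) {
        bool all_padding = true;
        for (uint32_t r = start; r < start + check_granularity && r < m_indices.size(); ++r)
            all_padding &= (m_indices[r] == -1);
        skippable += all_padding;
    }
    std::cout << "rows: " << m_indices.size() << ", chunks: " << total
              << ", skippable chunks: " << skippable << "\n";
}
```

In the kernel itself, the decision is made per math warpgroup via `scheduler.is_computation_valid(...)`, introduced further down in this commit.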
@@ -19,7 +19,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [ ] Larger block size on N (up to 256)
 - [x] MoE scheduler with TMA multicast compatibility
 - [x] Fix TMA multicast compatibility for indivisible shapes
-- [ ] Skip useless computation on M
+- [x] Skip useless computation on M
 - [x] NVRTC as a faster compiler
 - [ ] Stolen JIT cache
 - [ ] Sanitizer for testing
@@ -17,8 +17,8 @@
 namespace deep_gemm {

-template <int kNumFormerIters, int kGap, int kEnd>
-__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
     if (num_former_iters == kNumFormerIters) {
         inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
         return;
@@ -54,7 +54,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");

     // Shared memory
-    static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
+    static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
     static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
@@ -101,7 +101,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,

     // Fill shared memory pointers
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
         smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
         smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
@@ -111,7 +111,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         full_barriers[i] = barrier_start_ptr + i;
         empty_barriers[i] = barrier_start_ptr + kNumStages + i;
     }
@@ -122,7 +122,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
         // even with TMA multicast disabled, we want to make the behavior aligned
         #pragma unroll
-        for (int i = 0; i < kNumStages; ++ i) {
+        for (uint32_t i = 0; i < kNumStages; ++ i) {
             full_barriers[i]->init(1);
             empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
         }
@@ -138,28 +138,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // For pipeline unrolling
     struct DivisibleK {};
     struct NotDivisibleK {};
-    auto launch_k_iterations = [](const auto& func, int num_former_iters) {
+    struct SkipComputation {};
+    struct NotSkipComputation {};
+    auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) {
         constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-        constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
-        constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
+        constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;

         // NOTES: for too-many branches (> 5), we disable this optimization
         // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
-        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
-            if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
-                for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
+        outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
+            if (skip_computation) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
+            } else if (SHAPE_K % kFullKOfAllStages == 0) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             } else {
-                for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
-                func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
+                for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
+                func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             }
         }, func, kShouldOptimize ? num_former_iters : 0);
     };

     // Register reconfigurations
-    constexpr int kNumTMARegisters = 40;
-    constexpr int kNumMathRegisters = 232;
+    constexpr uint32_t kNumTMARegisters = 40;
+    constexpr uint32_t kNumMathRegisters = 232;

     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
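The new `SkipComputation` / `NotSkipComputation` structs reuse the tag-dispatch pattern already used for `DivisibleK` / `NotDivisibleK`: the runtime `skip_computation` flag chooses which empty tag value is passed in, and the receiving lambda branches on the tag's type with `if constexpr`, so each combination compiles into its own specialization and the skipped math never reaches the inner loop. A standalone sketch of the pattern (tag names match the diff, but the stage count and printed output are illustrative only):

```cpp
// Standalone sketch of tag-type dispatch with empty structs and `if constexpr`.
#include <cstdint>
#include <iostream>
#include <type_traits>

struct SkipComputation {};
struct NotSkipComputation {};

// The callee inspects the tag's *type*, so the branch is resolved at compile time
// and the dead side is never instantiated.
template <typename SkipTag>
void run_stage(uint32_t k_iter, SkipTag) {
    constexpr bool kSkip = std::is_same_v<SkipTag, SkipComputation>;
    constexpr uint32_t kNumInnerStages = kSkip ? 0 : 4;   // assumed stage count for illustration
    for (uint32_t s = 0; s < kNumInnerStages; ++s)        // loop disappears when kSkip is true
        std::cout << "  iter " << k_iter << ", stage " << s << "\n";
    if constexpr (kSkip)
        std::cout << "  iter " << k_iter << ": math skipped\n";
}

int main() {
    const bool skip_computation = true;   // in the kernel this comes from the scheduler
    for (uint32_t k_iter = 0; k_iter < 2; ++k_iter) {
        // A runtime bool selects which compile-time specialization to call.
        skip_computation ? run_stage(k_iter, SkipComputation{})
                         : run_stage(k_iter, NotSkipComputation{});
    }
}
```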
@@ -173,10 +178,9 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     if (threadIdx.x == kNumMathThreads) {
         // Persistently schedule over blocks
         while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-            launch_k_iterations([&](int k_iter, auto type, auto _) {
-                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+            launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
+                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+                constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;

                 // Assign TMA multicast number into A and B
                 // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
@@ -194,7 +198,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,

                     // Issue TMA A
                     auto& full_barrier = *full_barriers[s];
-                    int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
+                    uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                     tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                              smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
                              num_tma_multicast_a);
@@ -216,7 +220,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                     empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                     full_barriers[s]->arrive();
                 }
-            }, 0);
+            }, false, 0);
         }

         // To safely deconstruct distributed shared barriers, we need another round of empty waits
@@ -257,12 +261,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         cutlass::arch::NamedBarrier(kNumMathThreads).sync();

         // Accumulation for WGMMA or CUDA promotion
-        constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
+        constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
         DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
         float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};

         // Empty barrier arrival
-        auto empty_barrier_arrive = [&](int s) {
+        auto empty_barrier_arrive = [&](uint32_t s) {
             if constexpr (kNumTMAMulticast == 1) {
                 lane_idx == 0 ? empty_barriers[s]->arrive() : void();
             } else {
@@ -272,13 +276,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         };

         // Launch MMAs
-        launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
-            constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-            DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+        launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
+            constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
+            constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+            constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 :
+                (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K);

             #pragma unroll
-            for (int s = 0; s < kNumInnerStages; ++ s) {
+            for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                 // Read B scales
                 float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                 // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
@@ -300,18 +305,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,

                 // Commit WGMMA instructions
                 #pragma unroll
-                for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                     warpgroup_fence_operand(accum[i]);
                 warpgroup_arrive();
                 #pragma unroll
-                for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                     auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
                     auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                     WGMMA::wgmma(desc_a, desc_b, accum, k);
                 }
                 warpgroup_commit_batch();
                 #pragma unroll
-                for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                     warpgroup_fence_operand(accum[i]);
                 warpgroup_wait<0>();
@@ -328,7 +333,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,

                     auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                     #pragma unroll
-                    for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                    for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                         // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                         bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                         shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
@@ -345,7 +350,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                 full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                 empty_barrier_arrive(s);
             }
-        }, num_former_iters);
+        }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);

         // TMA checks
         constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
@@ -355,7 +360,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
                          "Unaligned TMA store or too many TMA store instructions");
         DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
-        DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
+        DG_STATIC_ASSERT(static_cast<uint32_t>(kSwizzleDMode > 0) + static_cast<uint32_t>(BLOCK_N_PADDING > 0) <= 1,
                          "Swizzling and padding are not compatible");

         // Wait last TMA store to be finished
@@ -375,7 +380,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             uint8_t* smem_ptr = nullptr;
             if constexpr (kSwizzleDMode > 0) {
                 // Calculate the swizzling atom offset and in-atom offset
-                constexpr int kNumBankGroupBytes = 16;
+                constexpr uint32_t kNumBankGroupBytes = 16;
                 auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);

                 // Calculate the index of the bank group to be written in the atom
@@ -436,4 +441,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,

 };  // namespace deep_gemm

-#pragma clang diagnostic pop
+#pragma clang diagnostic pop
@@ -34,7 +34,7 @@ struct Scheduler {
     // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;

-    __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
+    __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
                                                   int* grouped_layout = nullptr) {
         num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
         if constexpr (kGemmType == GemmType::Normal) {
@@ -48,6 +48,17 @@ struct Scheduler {
         }
     }

+    // ReSharper disable once CppNotAllPathsReturnValue
+    __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
+        if constexpr (kGemmType == GemmType::Normal) {
+            return true;
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+            return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
+        }
+    }
+
     __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
         if (num_blocks_in_group == 1)
             return false;
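The added `is_computation_valid` encodes when a row offset inside an M block still maps to real data: always for a normal GEMM, a non-negative group index in the contiguous grouped layout, and an offset below the group's `masked_m` in the masked layout. A host-side sketch of the same rules (plain loads instead of `__ldg`; `BLOCK_M` and the sample data are assumptions for illustration):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr uint32_t BLOCK_M = 128;   // assumed tile size

// Contiguous grouped layout: a row is valid iff its entry in `m_indices` is a
// real (non-negative) group index; -1 marks alignment padding.
bool contiguous_valid(const std::vector<int32_t>& m_indices, uint32_t m_block_idx, uint32_t m_offset) {
    return m_indices[m_offset + m_block_idx * BLOCK_M] >= 0;
}

// Masked grouped layout: a row is valid iff its offset inside the group is
// below that group's `masked_m`.
bool masked_valid(const std::vector<int32_t>& masked_m, uint32_t group_idx, uint32_t m_block_idx, uint32_t m_offset) {
    return m_offset + m_block_idx * BLOCK_M < static_cast<uint32_t>(masked_m[group_idx]);
}

int main() {
    // Contiguous: 128 real rows for group 0 followed by 128 padding rows.
    std::vector<int32_t> m_indices(256, 0);
    std::fill(m_indices.begin() + 128, m_indices.end(), -1);
    std::cout << contiguous_valid(m_indices, 0, 0) << ' '    // 1: real rows
              << contiguous_valid(m_indices, 1, 0) << '\n';  // 0: padding-only block, math can be skipped

    // Masked: group 0 has 150 valid rows; block 1 covers rows 128..255.
    std::vector<int32_t> masked_m = {150};
    std::cout << masked_valid(masked_m, 0, 1, 0) << ' '      // 1: row 128 < 150
              << masked_valid(masked_m, 0, 1, 64) << '\n';   // 0: row 192 >= 150
}
```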
@@ -65,7 +76,7 @@ struct Scheduler {
         }
     }

-    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
+    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
                                                            uint32_t& m_block_idx, uint32_t& n_block_idx) {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
@@ -100,7 +111,7 @@ struct Scheduler {
     }

     template <bool kIgnoreGroupedForGroupedContiguous=true>
-    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
+    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
                                                        const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
         if constexpr (kGemmType == GemmType::Normal) {
             return block_idx * block_size;
@@ -121,7 +121,7 @@ class Compiler:
                 '--ptxas-options=--register-usage-level=10' +
                 (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
                 # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
-                '--diag-suppress=39,161,174,177,940']
+                '--diag-suppress=39,161,174,177,186,940']

     @staticmethod
     def include_dirs() -> List[str]:
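The extra `186` in `--diag-suppress` is presumably related to the switch to unsigned loop counters: diagnostic 186 is nvcc/EDG's "pointless comparison of unsigned integer with zero" warning, which fires when a constexpr loop bound evaluates to 0, as it does in the skip-computation specialization. A minimal CUDA sketch of the kind of code that can trigger it (the diagnostic number and wording reflect common nvcc behavior and are not quoted from this commit):

```cpp
#include <cstdint>
#include <cstdio>

template <bool kSkip>
__host__ __device__ void stages() {
    constexpr uint32_t kNumInnerStages = kSkip ? 0 : 4;
    for (uint32_t s = 0; s < kNumInnerStages; ++ s)   // compares `s < 0u` when kSkip is true
        printf("stage %u\n", s);
}

int main() {
    stages<false>();   // normal path: four stages
    stages<true>();    // skip path: the loop body is never executed
    return 0;
}
```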
@@ -49,13 +49,10 @@ def construct(m: int, k: int, n: int) -> \

 def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
         Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
-    m = 0
-    m_aligned = get_m_alignment_for_contiguous_layout()
-    group_m_list = []
-    for i in range(num_groups):
-        group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned)
-        m += group_m
-        group_m_list.append(group_m)
+    alignment = get_m_alignment_for_contiguous_layout()
+    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
+    m = sum([ceil_div(x, alignment) * alignment for x in group_ms])

     x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
     m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
@@ -63,11 +60,14 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
     ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)

     start = 0
-    for i, group_m in enumerate(group_m_list):
-        end = start + group_m
-        m_indices[start:end] = i
-        ref_out[start:end] = x[start:end] @ y[i].t()
-        start = end
+    for i, group_m in enumerate(group_ms):
+        actual_end = start + group_m
+        aligned_end = start + ceil_div(group_m, alignment) * alignment
+        m_indices[start:actual_end] = i
+        m_indices[actual_end:aligned_end] = -1
+        ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
+        start = aligned_end
+    ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out)

     assert m % 4 == 0, f'TMA alignment error: {m}'
     x_fp8 = per_token_cast_to_fp8(x)
@@ -78,15 +78,15 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
     return m, x_fp8, y_fp8, m_indices, out, ref_out


-def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
-        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-    x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
+def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \
+        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
+    x = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
-    out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
+    out = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
     ref_out = torch.einsum('gmk,gnk->gmn', x, y)

-    assert m % 4 == 0, f'TMA alignment error: {m}'
-    x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
+    assert max_m % 4 == 0, f'TMA alignment error: {max_m}'
+    x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, k // 128), device='cuda', dtype=torch.float))
     y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
     for i in range(num_groups):
         x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
@@ -94,7 +94,13 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \

     # Transpose earlier so that the testing will not trigger transposing kernels
     x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
-    return x_fp8, y_fp8, out, ref_out
+
+    # Construct mask
+    masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
+    for j in range(num_groups):
+        masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
+    assert masked_m.amax().item() <= max_m
+    return x_fp8, y_fp8, masked_m, out, ref_out


 def construct_wgrad(m: int, k: int, n: int) -> \
@@ -170,15 +176,12 @@ def test_gemm() -> None:
         diff = calc_diff(out, ref_out)
         assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'

-        # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
-        x_fp8, y_fp8, out, ref_out = construct(m, k, n)
-
         # noinspection PyShadowingNames
         def test_func():
             deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)

         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
+        print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
               f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
               f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
@@ -187,58 +190,52 @@ def test_gemm() -> None:
 def test_m_grouped_gemm_contiguous() -> None:
     print('Testing grouped contiguous GEMM:')

-    for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
-        # TODO: make a stronger test
+    for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168),
+                                                   (8, 4096, 7168, 4096), (8, 4096, 2048, 7168),
+                                                   (32, 256, 7168, 4096), (32, 256, 2048, 7168)):
+        # NOTES: we should mask the unfilled part before calculating difference
         m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
+        out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
         diff = calc_diff(out, ref_out)
         assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'

-        # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
-        m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
-
         # noinspection PyShadowingNames
         def test_func():
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)

         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-              f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
-              f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+        valid_m = (m_indices != -1).sum().item()
+        print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
+              f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
+              f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()


 def test_m_grouped_gemm_masked() -> None:
     print('Testing grouped masked GEMM:')

-    for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
+    for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)):
         for k, n in ((7168, 4096), (2048, 7168), ):
             # Test correctness
-            masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
             for i in range(10):
-                x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
-                masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
-                for j in range(num_groups):
-                    masked_m[j] = random.choice(masked_m_candidates)
-                expected_m = min(int(masked_m.float().mean()) + 1, m)
-                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
+                x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n)
+                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
                 for j in range(num_groups):
                     diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
-                    assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
-
-            # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
-            x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
-            masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
+                    assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'

             # noinspection PyShadowingNames
             def test_func():
-                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
+                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)

             # Test performance with fixed shapes
+            # noinspection PyUnboundLocalVariable
+            valid_m = masked_m.sum().item()
             t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-            print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-                  f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
-                  f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
+            print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
+                  f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
+                  f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
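The reworked perf prints report throughput in terms of `valid_m`, the number of rows that are neither padding nor masked off: useful FLOPs are `2 * valid_m * n * k`, and the traffic counts FP8 A reads, one pass over each group's FP8 B, and BF16 D writes. A small standalone recomputation with assumed numbers (all inputs below are made up for illustration):

```cpp
// Mirrors the updated formulas in the tests: only `valid_m` rows count toward
// FLOPs and A/D traffic, while B is still read once per group.
#include <cstdint>
#include <iostream>

int main() {
    const double t = 450e-6;                 // assumed kernel time in seconds
    const uint64_t num_groups = 8, n = 4096, k = 7168;
    const uint64_t valid_m = 30000;          // rows that are neither padding nor masked off

    const double tflops = 2.0 * valid_m * n * k / t / 1e12;
    const double gb_per_s = (valid_m * k          // FP8 A reads (1 byte/element)
                           + num_groups * k * n   // FP8 B reads (1 byte/element)
                           + valid_m * n * 2.0)   // BF16 D writes (2 bytes/element)
                          / 1e9 / t;
    std::cout << tflops << " TFLOPS, " << gb_per_s << " GB/s\n";
}
```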
@@ -312,4 +309,4 @@ if __name__ == '__main__':
     test_m_grouped_gemm_masked()

     test_wgrad_gemm()
-    test_k_grouped_wgrad_gemm()
+    test_k_grouped_wgrad_gemm()