mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Grouped GEMM skip useless computation for unaligned Ms
This commit is contained in:
parent
391755ada0
commit
ccca476ac4
@ -271,6 +271,16 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (!scheduler.is_valid_m(math_wg_idx * WGMMA::M, m_block_idx)) {
|
||||||
|
// Skip useless computation for unaligned Ms
|
||||||
|
launch_k_iterations([&](int k_iter, auto type, auto _) {
|
||||||
|
#pragma unroll
|
||||||
|
for (uint32_t s = 0; s < kNumStages; ++ s) {
|
||||||
|
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||||
|
empty_barrier_arrive(s);
|
||||||
|
}
|
||||||
|
}, num_former_iters);
|
||||||
|
} else {
|
||||||
// Launch MMAs
|
// Launch MMAs
|
||||||
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
|
launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
|
||||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||||
@ -288,7 +298,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
|||||||
// Wait TMA arrivals
|
// Wait TMA arrivals
|
||||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||||
|
|
||||||
// TODO: remove some useless computation for unaligned Ms
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||||
@ -346,6 +355,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
|||||||
empty_barrier_arrive(s);
|
empty_barrier_arrive(s);
|
||||||
}
|
}
|
||||||
}, num_former_iters);
|
}, num_former_iters);
|
||||||
|
}
|
||||||
|
|
||||||
// TMA checks
|
// TMA checks
|
||||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||||
|
|||||||
@ -48,6 +48,16 @@ struct Scheduler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ bool is_valid_m(const uint32_t m_offset, const uint32_t& m_block_idx) const {
|
||||||
|
if constexpr (kGemmType == GemmType::Normal) {
|
||||||
|
return true;
|
||||||
|
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
|
||||||
|
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) != -1;
|
||||||
|
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||||
|
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||||
if (num_blocks_in_group == 1)
|
if (num_blocks_in_group == 1)
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -53,8 +53,8 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
|||||||
m_aligned = get_m_alignment_for_contiguous_layout()
|
m_aligned = get_m_alignment_for_contiguous_layout()
|
||||||
group_m_list = []
|
group_m_list = []
|
||||||
for i in range(num_groups):
|
for i in range(num_groups):
|
||||||
group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned)
|
group_m = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
|
||||||
m += group_m
|
m += ceil_div(group_m, m_aligned) * m_aligned
|
||||||
group_m_list.append(group_m)
|
group_m_list.append(group_m)
|
||||||
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||||
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||||
@ -64,10 +64,12 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
|||||||
|
|
||||||
start = 0
|
start = 0
|
||||||
for i, group_m in enumerate(group_m_list):
|
for i, group_m in enumerate(group_m_list):
|
||||||
end = start + group_m
|
actual_end = start + group_m
|
||||||
m_indices[start:end] = i
|
aligned_end = start + ceil_div(group_m, m_aligned) * m_aligned
|
||||||
ref_out[start:end] = x[start:end] @ y[i].t()
|
m_indices[start:actual_end] = i
|
||||||
start = end
|
m_indices[actual_end:aligned_end] = -1
|
||||||
|
ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
|
||||||
|
start = aligned_end
|
||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||||
x_fp8 = per_token_cast_to_fp8(x)
|
x_fp8 = per_token_cast_to_fp8(x)
|
||||||
@ -191,6 +193,8 @@ def test_m_grouped_gemm_contiguous() -> None:
|
|||||||
# TODO: make a stronger test
|
# TODO: make a stronger test
|
||||||
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
|
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||||
|
out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
|
||||||
|
ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out)
|
||||||
diff = calc_diff(out, ref_out)
|
diff = calc_diff(out, ref_out)
|
||||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||||
|
|
||||||
@ -202,24 +206,25 @@ def test_m_grouped_gemm_contiguous() -> None:
|
|||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||||
|
|
||||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||||
|
sum_m = (m_indices != -1).sum().item()
|
||||||
print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||||
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
f'throughput: {2 * sum_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||||
f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
f'{(sum_m * k + num_groups * k * n + sum_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def test_m_grouped_gemm_masked() -> None:
|
def test_m_grouped_gemm_masked() -> None:
|
||||||
print('Testing grouped masked GEMM:')
|
print('Testing grouped masked GEMM:')
|
||||||
|
|
||||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
m = 4096
|
||||||
|
for num_groups, excepted_m in ((1, 1024), (2, 512), (4, 256)):
|
||||||
for k, n in ((7168, 4096), (2048, 7168), ):
|
for k, n in ((7168, 4096), (2048, 7168), ):
|
||||||
# Test correctness
|
# Test correctness
|
||||||
masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
|
x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
|
||||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||||
for j in range(num_groups):
|
for j in range(num_groups):
|
||||||
masked_m[j] = random.choice(masked_m_candidates)
|
masked_m[j] = random.randint(int(excepted_m * 0.7), int(excepted_m * 1.3))
|
||||||
expected_m = min(int(masked_m.float().mean()) + 1, m)
|
expected_m = min(int(masked_m.float().mean()) + 1, m)
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
||||||
for j in range(num_groups):
|
for j in range(num_groups):
|
||||||
@ -228,17 +233,20 @@ def test_m_grouped_gemm_masked() -> None:
|
|||||||
|
|
||||||
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
|
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
|
||||||
x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
|
x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
|
||||||
masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
|
for j in range(num_groups):
|
||||||
|
masked_m[j] = random.randint(int(excepted_m * 0.7), int(excepted_m * 1.3))
|
||||||
|
expected_m = min(int(masked_m.float().mean()) + 1, m)
|
||||||
|
sum_m = masked_m.sum().item()
|
||||||
|
|
||||||
# noinspection PyShadowingNames
|
# noinspection PyShadowingNames
|
||||||
def test_func():
|
def test_func():
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
||||||
|
|
||||||
# Test performance with fixed shapes
|
# Test performance with fixed shapes
|
||||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
f'throughput: {2 * sum_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
f'{(sum_m * k + num_groups * k * n + sum_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user