Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-05-05 15:34:22 +00:00)

Support TMA multicast on B with m_grouped_gemm_contiguous. (#88)

Parent: 83aa960b9b
Commit: 891f35adf5
```diff
@@ -12,10 +12,11 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 
 ## Roadmap
 
-- [ ] More correctness tests for grouped-contiguous layout
+- [x] More correctness tests for grouped-contiguous layout
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
-- [ ] MoE scheduler with TMA multicast compatibility
+- [x] MoE scheduler with TMA multicast compatibility
+- [ ] Fix TMA multicast compatibility for indivisible shapes
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)
```
```diff
@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Prefetch TMA descriptors at very beginning
     if (threadIdx.x == kNumMathThreads) {
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
-        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
+        cute::prefetch_tma_descriptor(&tensor_map_a);
+        cute::prefetch_tma_descriptor(&tensor_map_b);
+        cute::prefetch_tma_descriptor(&tensor_map_scales_a);
 
         // `tensor_map_d` is only used in swizzling mode
         // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
         if constexpr (kSwizzleDMode > 0)
-            cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
+            cute::prefetch_tma_descriptor(&tensor_map_d);
     }
     __syncwarp();
 
```
```diff
@@ -212,8 +212,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                          scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
 
                 // Issue TMA B
+                if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
+                    // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
+                    // requiring additional checks before multicast operations.
+                    DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
                 tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                               smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
+                } else {
+                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                             smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
+                }
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
             }
 
```
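The gist of this hunk: when 2-way multicast on B is enabled and the scheduler confirms the paired m-blocks belong to the same group, one multicast TMA transaction feeds both CTAs' shared memory with the shared B tile; otherwise each CTA falls back to an ordinary copy. A rough back-of-envelope illustration of what the multicast path saves (my own sketch with assumed tile sizes, not values from the kernel configuration):

```python
# Back-of-envelope model of B-tile traffic per pipeline stage.
# The tile shape and element size are assumptions, not read from the kernel config.
BLOCK_N, BLOCK_K = 128, 128   # assumed B tile shape
BYTES_PER_ELEM = 1            # FP8 element

b_tile_bytes = BLOCK_N * BLOCK_K * BYTES_PER_ELEM

# Without multicast: each of the two paired CTAs issues its own copy of the tile.
no_multicast_bytes = 2 * b_tile_bytes
# With 2-way TMA multicast: a single transaction feeds both CTAs' shared memory.
multicast_bytes = 1 * b_tile_bytes

print(f'B bytes per stage: {no_multicast_bytes} -> {multicast_bytes} '
      f'({no_multicast_bytes // multicast_bytes}x fewer bytes fetched)')
```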
```diff
@@ -43,6 +43,18 @@ struct Scheduler {
         }
     }
 
+    __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
+        if constexpr (kGemmType == GemmType::Normal) {
+            return true;
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+            auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+            auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+            return group_idx == peer_group_idx;
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+            return false;
+        }
+    }
+
     __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
```
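The predicate added to the scheduler can be mirrored host-side: with 2-way multicast, m-blocks `i` and `i ^ 1` form a pair, and B may be multicast only if the first rows of both blocks map to the same group in `grouped_layout` (grouped-masked layouts opt out entirely). A minimal Python sketch of that rule, using a hypothetical helper and an assumed `BLOCK_M = 128`, not DeepGEMM's Python API:

```python
# Host-side sketch of the B-multicast validity rule for grouped contiguous GEMM.
# Hypothetical helper for illustration; BLOCK_M = 128 is an assumption.
BLOCK_M = 128

def is_tma_multicast_b_valid(grouped_layout, m_block_idx, block_m=BLOCK_M):
    # `grouped_layout` maps each row of the stacked M dimension to its group index
    # (the same role `m_indices` plays in the tests further down).
    group_idx = grouped_layout[m_block_idx * block_m]
    peer_group_idx = grouped_layout[(m_block_idx ^ 1) * block_m]
    return group_idx == peer_group_idx

# Two groups of 256 rows each -> m-blocks 0..3; every pair stays inside one group.
grouped_layout = [0] * 256 + [1] * 256
print([is_tma_multicast_b_valid(grouped_layout, b) for b in range(4)])   # all True
# With 128-row groups, pair (0, 1) would straddle a group boundary and return False.
```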
```diff
@@ -72,10 +84,10 @@ struct Scheduler {
                                                 const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
         if constexpr (kGemmType == GemmType::Normal) {
             return block_idx * block_size;
-        } else if (kGemmType == GemmType::GroupedContiguous) {
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
             auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
             return offset * shape_dim + block_idx * block_size;
-        } else if (kGemmType == GemmType::GroupedMasked) {
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
             return curr_group_idx * shape_dim + block_idx * block_size;
         }
     }
```
```diff
@@ -146,10 +146,9 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     best_tma_multicast_config = (1, True)
 
     # Try to multicast on the larger block side first
-    is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked)
     is_multicast_legal = {
         'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
-        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm,
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
     }
     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
         if m >= 512 and is_multicast_legal[i]:
```
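To see how the Python-side heuristic now treats the B side, here is a self-contained sketch of the selection policy after this change. `is_tma_multicast_legal` below is an assumed stand-in (a simple divisibility check), and the `(num_multicast, on_a)` return convention is illustrative, not the library's exact code:

```python
# Sketch of the multicast-side selection in get_best_configs after this change.
# The legality helper below is an assumed approximation, not DeepGEMM's implementation.
def is_tma_multicast_legal(shape_dim, block_size, num_multicast, num_sms):
    num_blocks = (shape_dim + block_size - 1) // block_size
    return num_blocks % num_multicast == 0 and num_sms % num_multicast == 0

def pick_multicast(m, n, best_block_m, best_block_n, num_sms, is_grouped_masked):
    best = (1, True)                       # (num_multicast, multicast on A?)
    is_multicast_legal = {
        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
        # B-side multicast is now allowed for dense and grouped-contiguous GEMMs,
        # and still disabled for grouped-masked layouts.
        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
    }
    # Try to multicast on the larger block side first.
    for side in (('A', 'B') if best_block_m > best_block_n else ('B', 'A')):
        if m >= 512 and is_multicast_legal[side]:
            best = (2, side == 'A')
            break
    return best

# Example: a grouped-contiguous shape where B-side multicast is preferred.
print(pick_multicast(m=4096, n=7168, best_block_m=128, best_block_n=128,
                     num_sms=132, is_grouped_masked=False))    # -> (2, False)
```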
```diff
@@ -4,6 +4,7 @@ from typing import Tuple
 
 import deep_gemm
 from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
+from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
 
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```
```diff
@@ -38,7 +39,38 @@ def construct(m: int, k: int, n: int) -> \
     return x_fp8, y_fp8, out, ref_out
 
 
-def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
+def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
+        Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
+    m = 0
+    m_aligned = get_m_alignment_for_contiguous_layout()
+    group_m_list = []
+    for i in range(num_groups):
+        group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned)
+        m += group_m
+        group_m_list.append(group_m)
+    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
+    m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
+    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
+
+    start = 0
+    for i, group_m in enumerate(group_m_list):
+        end = start + group_m
+        m_indices[start:end] = i
+        ref_out[start:end] = x[start:end] @ y[i].t()
+        start = end
+
+    assert m % 4 == 0, f'TMA alignment error: {m}'
+    x_fp8 = per_token_cast_to_fp8(x)
+    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
+    for i in range(num_groups):
+        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
+
+    return m, x_fp8, y_fp8, m_indices, out, ref_out
+
+
+def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
         Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
     x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
```
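The layout this new helper produces is easy to picture with a tiny CPU-only example: all groups' rows are stacked along M, and `m_indices[i]` names the group whose weights row `i` should be multiplied with (toy sizes, no FP8 casting, purely illustrative):

```python
# CPU-only toy illustration of the grouped-contiguous layout built by the helper above.
import torch

num_groups, k, n = 3, 16, 8
group_m_list = [4, 8, 4]                      # toy per-group row counts
m = sum(group_m_list)

x = torch.randn(m, k)                         # stacked activations of all groups
y = torch.randn(num_groups, n, k)             # one weight matrix per group
m_indices = torch.empty(m, dtype=torch.int32)
ref_out = torch.empty(m, n)

start = 0
for i, group_m in enumerate(group_m_list):
    end = start + group_m
    m_indices[start:end] = i                  # every row in this slice belongs to group i
    ref_out[start:end] = x[start:end] @ y[i].t()
    start = end

print(m, m_indices.tolist())
```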
```diff
@@ -52,11 +84,6 @@ def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool)
         x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
         y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
 
-    # For non-masked input, we must merge the group and M dims
-    if not is_masked:
-        x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
-        out, ref_out = out.view(-1, n), ref_out.view(-1, n)
-
     # Transpose earlier so that the testing will not trigger transposing kernels
     x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
     return x_fp8, y_fp8, out, ref_out
```
```diff
@@ -88,28 +115,24 @@ def test_gemm() -> None:
 def test_m_grouped_gemm_contiguous() -> None:
     print('Testing grouped contiguous GEMM:')
 
-    for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
+    for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
         # TODO: make a stronger test
-        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
-        m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
-        m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
+        m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
         diff = calc_diff(out, ref_out)
-        assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
+        assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
 
         # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
-        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
-        m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
-        m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
+        m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
 
         # noinspection PyShadowingNames
         def test_func():
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
 
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-              f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
-              f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
+        print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
+              f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
+              f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
 
```
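The updated metrics follow from `m` now being the total row count across groups: FLOPs are `2 * m * n * k` (each row hits exactly one `n x k` weight), and the bytes estimate counts one FP8 read of the stacked activations (`m * k`), one FP8 read of every group's weights (`num_groups * k * n`), and one BF16 write of the output (`m * n * 2`). A quick check of the arithmetic with a made-up timing:

```python
# Worked example of the reported metrics for one test shape. The 1.5 ms timing is
# made up purely to exercise the formulas; it is not a measured number.
num_groups, m, n, k = 4, 4 * 8192, 4096, 7168   # m ~ num_groups * expected_m_per_group
t = 1.5e-3                                      # hypothetical kernel time in seconds

tflops = 2 * m * n * k / t / 1e12
gb_per_s = (m * k + num_groups * k * n + m * n * 2) / 1e9 / t
print(f'{tflops:.0f} TFLOPS, {gb_per_s:.0f} GB/s')
```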
```diff
@@ -121,7 +144,7 @@ def test_m_grouped_gemm_masked() -> None:
         # Test correctness
         masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
         for i in range(10):
-            x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
+            x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
             masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
             for j in range(num_groups):
                 masked_m[j] = random.choice(masked_m_candidates)
```
```diff
@@ -132,7 +155,7 @@ def test_m_grouped_gemm_masked() -> None:
                 assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
 
         # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
-        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
+        x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
         masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
 
         # noinspection PyShadowingNames
```