mirror of
				https://github.com/deepseek-ai/DeepGEMM
				synced 2025-06-26 23:15:49 +00:00 
			
		
		
		
	Support TMA multicast on B with m_grouped_gemm_contiguous. (#88)
This commit is contained in:
		
							parent
							
								
									83aa960b9b
								
							
						
					
					
						commit
						891f35adf5
					
				| @ -12,10 +12,11 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert | |||||||
| 
 | 
 | ||||||
| ## Roadmap | ## Roadmap | ||||||
| 
 | 
 | ||||||
| - [ ] More correctness tests for grouped-contiguous layout | - [x] More correctness tests for grouped-contiguous layout | ||||||
| - [x] Shared memory swizzling for output | - [x] Shared memory swizzling for output | ||||||
| - [ ] Larger block size on N (up to 256) | - [ ] Larger block size on N (up to 256) | ||||||
| - [ ] MoE scheduler with TMA multicast compatibility | - [x] MoE scheduler with TMA multicast compatibility | ||||||
|  | - [ ] Fix TMA multicast compatibility for indivisible shapes | ||||||
| - [ ] Weight gradient kernels for dense models | - [ ] Weight gradient kernels for dense models | ||||||
| - [ ] Weight gradient kernels for MoE models | - [ ] Weight gradient kernels for MoE models | ||||||
| - [ ] Utility kernels for MoE models (as a pre-built CUDA library) | - [ ] Utility kernels for MoE models (as a pre-built CUDA library) | ||||||
|  | |||||||
| @ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, | |||||||
| 
 | 
 | ||||||
|     // Prefetch TMA descriptors at very beginning |     // Prefetch TMA descriptors at very beginning | ||||||
|     if (threadIdx.x == kNumMathThreads) { |     if (threadIdx.x == kNumMathThreads) { | ||||||
|         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a)); |         cute::prefetch_tma_descriptor(&tensor_map_a); | ||||||
|         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b)); |         cute::prefetch_tma_descriptor(&tensor_map_b); | ||||||
|         cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a)); |         cute::prefetch_tma_descriptor(&tensor_map_scales_a); | ||||||
| 
 | 
 | ||||||
|         // `tensor_map_d` is only used in swizzling mode |         // `tensor_map_d` is only used in swizzling mode | ||||||
|         // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode |         // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode | ||||||
|         if constexpr (kSwizzleDMode > 0) |         if constexpr (kSwizzleDMode > 0) | ||||||
|             cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d)); |             cute::prefetch_tma_descriptor(&tensor_map_d); | ||||||
|     } |     } | ||||||
|     __syncwarp(); |     __syncwarp(); | ||||||
| 
 | 
 | ||||||
| @ -212,8 +212,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, | |||||||
|                                                       scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); |                                                       scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); | ||||||
| 
 | 
 | ||||||
|                         // Issue TMA B |                         // Issue TMA B | ||||||
|  |                         if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) { | ||||||
|  |                             // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B), | ||||||
|  |                             // requiring additional checks before multicast operations. | ||||||
|  |                             DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast"); | ||||||
|                             tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), |                             tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), | ||||||
|                                                           smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); |                                                           smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); | ||||||
|  |                         } else { | ||||||
|  |                             tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), | ||||||
|  |                                      smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); | ||||||
|  |                         } | ||||||
|                         full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); |                         full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); | ||||||
|                     } |                     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -43,6 +43,18 @@ struct Scheduler { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) { | ||||||
|  |         if constexpr (kGemmType == GemmType::Normal) { | ||||||
|  |             return true; | ||||||
|  |         } else if constexpr (kGemmType == GemmType::GroupedContiguous) { | ||||||
|  |             auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); | ||||||
|  |             auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); | ||||||
|  |             return group_idx == peer_group_idx; | ||||||
|  |         } else if constexpr (kGemmType == GemmType::GroupedMasked) { | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { |     __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { | ||||||
|         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); |         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); | ||||||
| 
 | 
 | ||||||
| @ -72,10 +84,10 @@ struct Scheduler { | |||||||
|                                                        const uint32_t& block_idx, const uint32_t& m_block_idx=0) { |                                                        const uint32_t& block_idx, const uint32_t& m_block_idx=0) { | ||||||
|         if constexpr (kGemmType == GemmType::Normal) { |         if constexpr (kGemmType == GemmType::Normal) { | ||||||
|             return block_idx * block_size; |             return block_idx * block_size; | ||||||
|         } else if (kGemmType == GemmType::GroupedContiguous) { |         } else if constexpr (kGemmType == GemmType::GroupedContiguous) { | ||||||
|             auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); |             auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); | ||||||
|             return offset * shape_dim + block_idx * block_size; |             return offset * shape_dim + block_idx * block_size; | ||||||
|         } else if (kGemmType == GemmType::GroupedMasked) { |         } else if constexpr (kGemmType == GemmType::GroupedMasked) { | ||||||
|             return curr_group_idx * shape_dim + block_idx * block_size; |             return curr_group_idx * shape_dim + block_idx * block_size; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -146,10 +146,9 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, | |||||||
|     best_tma_multicast_config = (1, True) |     best_tma_multicast_config = (1, True) | ||||||
| 
 | 
 | ||||||
|     # Try to multicast on the larger block side first |     # Try to multicast on the larger block side first | ||||||
|     is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked) |  | ||||||
|     is_multicast_legal = { |     is_multicast_legal = { | ||||||
|         'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms), |         'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms), | ||||||
|         'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm, |         'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked), | ||||||
|     } |     } | ||||||
|     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): |     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): | ||||||
|         if m >= 512 and is_multicast_legal[i]: |         if m >= 512 and is_multicast_legal[i]: | ||||||
|  | |||||||
| @ -4,6 +4,7 @@ from typing import Tuple | |||||||
| 
 | 
 | ||||||
| import deep_gemm | import deep_gemm | ||||||
| from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor | from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor | ||||||
|  | from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||
| @ -38,7 +39,38 @@ def construct(m: int, k: int, n: int) -> \ | |||||||
|     return x_fp8, y_fp8, out, ref_out |     return x_fp8, y_fp8, out, ref_out | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \ | def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ | ||||||
|  |         Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||
|  |     m = 0 | ||||||
|  |     m_aligned = get_m_alignment_for_contiguous_layout() | ||||||
|  |     group_m_list = [] | ||||||
|  |     for i in range(num_groups): | ||||||
|  |         group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned) | ||||||
|  |         m += group_m | ||||||
|  |         group_m_list.append(group_m) | ||||||
|  |     x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) | ||||||
|  |     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)     | ||||||
|  |     m_indices = torch.empty(m, device='cuda', dtype=torch.int32) | ||||||
|  |     out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) | ||||||
|  |     ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) | ||||||
|  | 
 | ||||||
|  |     start = 0 | ||||||
|  |     for i, group_m in enumerate(group_m_list): | ||||||
|  |         end = start + group_m | ||||||
|  |         m_indices[start:end] = i | ||||||
|  |         ref_out[start:end] = x[start:end] @ y[i].t() | ||||||
|  |         start = end | ||||||
|  | 
 | ||||||
|  |     assert m % 4 == 0, f'TMA alignment error: {m}' | ||||||
|  |     x_fp8 = per_token_cast_to_fp8(x) | ||||||
|  |     y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) | ||||||
|  |     for i in range(num_groups): | ||||||
|  |         y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) | ||||||
|  | 
 | ||||||
|  |     return m, x_fp8, y_fp8, m_indices, out, ref_out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \ | ||||||
|         Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: |         Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: | ||||||
|     x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) |     x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) | ||||||
|     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) |     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) | ||||||
| @ -52,11 +84,6 @@ def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) | |||||||
|         x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) |         x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) | ||||||
|         y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) |         y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) | ||||||
| 
 | 
 | ||||||
|     # For non-masked input, we must merge the group and M dims |  | ||||||
|     if not is_masked: |  | ||||||
|         x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) |  | ||||||
|         out, ref_out = out.view(-1, n), ref_out.view(-1, n) |  | ||||||
| 
 |  | ||||||
|     # Transpose earlier so that the testing will not trigger transposing kernels |     # Transpose earlier so that the testing will not trigger transposing kernels | ||||||
|     x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) |     x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) | ||||||
|     return x_fp8, y_fp8, out, ref_out |     return x_fp8, y_fp8, out, ref_out | ||||||
| @ -88,28 +115,24 @@ def test_gemm() -> None: | |||||||
| def test_m_grouped_gemm_contiguous() -> None: | def test_m_grouped_gemm_contiguous() -> None: | ||||||
|     print('Testing grouped contiguous GEMM:') |     print('Testing grouped contiguous GEMM:') | ||||||
| 
 | 
 | ||||||
|     for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): |     for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): | ||||||
|         # TODO: make a stronger test |         # TODO: make a stronger test | ||||||
|         x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) |         m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) | ||||||
|         m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) |  | ||||||
|         m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) |  | ||||||
|         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) |         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) | ||||||
|         diff = calc_diff(out, ref_out) |         diff = calc_diff(out, ref_out) | ||||||
|         assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}' |         assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' | ||||||
| 
 | 
 | ||||||
|         # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) |         # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) | ||||||
|         x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) |         m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) | ||||||
|         m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) |  | ||||||
|         m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) |  | ||||||
| 
 | 
 | ||||||
|         # noinspection PyShadowingNames |         # noinspection PyShadowingNames | ||||||
|         def test_func(): |         def test_func(): | ||||||
|             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) |             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) | ||||||
| 
 | 
 | ||||||
|         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) |         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) | ||||||
|         print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' |         print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' | ||||||
|               f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' |               f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' | ||||||
|               f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') |               f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') | ||||||
|     print() |     print() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -121,7 +144,7 @@ def test_m_grouped_gemm_masked() -> None: | |||||||
|             # Test correctness |             # Test correctness | ||||||
|             masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) |             masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) | ||||||
|             for i in range(10): |             for i in range(10): | ||||||
|                 x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) |                 x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n) | ||||||
|                 masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) |                 masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) | ||||||
|                 for j in range(num_groups): |                 for j in range(num_groups): | ||||||
|                     masked_m[j] = random.choice(masked_m_candidates) |                     masked_m[j] = random.choice(masked_m_candidates) | ||||||
| @ -132,7 +155,7 @@ def test_m_grouped_gemm_masked() -> None: | |||||||
|                     assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' |                     assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' | ||||||
| 
 | 
 | ||||||
|             # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) |             # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) | ||||||
|             x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) |             x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n) | ||||||
|             masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m |             masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m | ||||||
| 
 | 
 | ||||||
|             # noinspection PyShadowingNames |             # noinspection PyShadowingNames | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user