Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-17 01:31:15 +00:00

Commit 7ffb118e54: Support multicasting on B
Parent: 742fb1c8a5

This commit threads a new kIsTMAMulticastOnA flag through the kernel, the block scheduler, and the JIT tuner, so TMA multicast can broadcast either operand A (the previous behavior) or operand B.
@@ -31,7 +31,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
-          uint32_t kNumTMAMulticast,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
 __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
 fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
@@ -146,7 +146,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
-    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
+    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
 
     if (threadIdx.x >= kNumMathThreads) {
         // TMA warp-group for loading data
@@ -161,6 +161,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                 DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
+                // Assign TMA multicast number into A and B
+                constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
+                constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+
                 // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
                 // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
                 #pragma unroll
@@ -168,18 +172,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     // Wait consumer release
                     empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
 
-                    // Issue TMA A with broadcasting
+                    // Issue TMA A
                     auto& full_barrier = *full_barriers[s];
                     int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-                    tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                    tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                    tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                    tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_scales_a[s], m_block_idx * BLOCK_M,
                                                   scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
 
-                    // Issue TMA B without broadcasting
-                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                    // Issue TMA B
+                    tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                              smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                     full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
                 }
 
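Note (not part of the diff): the following Python sketch models how the new kIsTMAMulticastOnA flag routes the single multicast width either to the A-side loads (tensor A and its scales) or to the B-side load; split_tma_multicast is a hypothetical name used only for illustration.

# Mirrors the constexpr kNumTMAMulticastOnA / kNumTMAMulticastOnB split above:
# exactly one operand is broadcast across the cluster, the other is loaded per CTA.
def split_tma_multicast(num_tma_multicast: int, is_multicast_on_a: bool):
    num_on_a = num_tma_multicast if is_multicast_on_a else 1
    num_on_b = 1 if is_multicast_on_a else num_tma_multicast
    return num_on_a, num_on_b

assert split_tma_multicast(2, True) == (2, 1)   # previous behavior: broadcast A (and its scales)
assert split_tma_multicast(2, False) == (1, 2)  # new option: broadcast B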
@@ -347,7 +351,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t kNumGroups, uint32_t kNumStages,
-          uint32_t kNumTMAMulticast,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
 class Gemm {
 private:
@@ -369,7 +373,7 @@ public:
         constexpr uint32_t kNumMathThreadsPerGroup = 128;
         auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                                       kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
-                                      kNumTMAMulticast, kGemmType>;
+                                      kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
         DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
 
         // Cluster launch
|
@@ -12,9 +12,10 @@ enum class GemmType {
 #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
 template <GemmType kGemmType,
           uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
-          uint32_t kNumGroups, uint32_t kNumTMAMulticast,
+          uint32_t kNumGroups,
+          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
-          uint32_t kNumNBlocksPerGroup = 16>
+          uint32_t kNum1DBlocksPerGroup = 16>
 struct Scheduler {
     int current_iter = -1;
     uint32_t num_aligned_m_blocks;
@@ -43,16 +44,27 @@ struct Scheduler {
     }
 
     __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
-        DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
+        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
         // Swizzle for better L2 usages
-        auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
-        auto group_idx = block_idx / num_blocks_per_group;
-        auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
-        auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-        auto in_group_idx = block_idx % num_blocks_per_group;
-        m_block_idx = in_group_idx / num_n_blocks_in_group;
-        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+        // TODO: unify these 2 branches
+        if constexpr (kIsTMAMulticastOnA) {
+            auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
+            auto group_idx = block_idx / num_blocks_per_group;
+            auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
+            auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
+            auto in_group_idx = block_idx % num_blocks_per_group;
+            m_block_idx = in_group_idx / num_n_blocks_in_group;
+            n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+        } else {
+            auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
+            auto group_idx = block_idx / num_blocks_per_group;
+            auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
+            auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
+            auto in_group_idx = block_idx % num_blocks_per_group;
+            m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
+            n_block_idx = in_group_idx / num_m_blocks_in_group;
+        }
     }
 
     template <bool kIgnoreGroupedForGroupedContiguous=true>
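Note (not part of the diff): a host-side Python model of the two swizzling branches above, written as an illustrative standalone function rather than the repository's API. Consecutive scheduled blocks share m_block_idx when multicasting on A and share n_block_idx when multicasting on B, so the CTAs of one TMA cluster can receive a single broadcast copy of the shared operand.

# Python model of Scheduler::get_swizzled_block_idx (kNum1DBlocksPerGroup defaults to 16)
def get_swizzled_block_idx(num_m_blocks: int, num_n_blocks: int, block_idx: int,
                           is_multicast_on_a: bool, num_1d_blocks_per_group: int = 16):
    if is_multicast_on_a:
        # Walk N in groups; inside a group, N varies fastest, so neighbors share the M block
        num_blocks_per_group = num_m_blocks * num_1d_blocks_per_group
        group_idx = block_idx // num_blocks_per_group
        first_n = group_idx * num_1d_blocks_per_group
        num_n_in_group = min(num_1d_blocks_per_group, num_n_blocks - first_n)
        in_group_idx = block_idx % num_blocks_per_group
        return in_group_idx // num_n_in_group, first_n + in_group_idx % num_n_in_group
    else:
        # Mirrored: walk M in groups; inside a group, M varies fastest, so neighbors share the N block
        num_blocks_per_group = num_n_blocks * num_1d_blocks_per_group
        group_idx = block_idx // num_blocks_per_group
        first_m = group_idx * num_1d_blocks_per_group
        num_m_in_group = min(num_1d_blocks_per_group, num_m_blocks - first_m)
        in_group_idx = block_idx % num_blocks_per_group
        return first_m + in_group_idx % num_m_in_group, in_group_idx // num_m_in_group

# Made-up example: 4 M-blocks x 32 N-blocks, linear block index 70
print(get_swizzled_block_idx(4, 32, 70, True))   # (0, 22): same M block as its neighbor block 71
print(get_swizzled_block_idx(4, 32, 70, False))  # (2, 17): same N block as its neighbor block 71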
@@ -67,7 +67,10 @@ def cpp_format(template: str, keys: Dict[str, Any]) -> str:
     # We don't use `str.format` because it's not safe for C++ {} braces
     new_template = copy.deepcopy(template)
     for key, value in keys.items():
-        new_template = new_template.replace(f'{{{key}}}', f'{value}')
+        value_str = str(value)
+        if isinstance(value, bool):
+            value_str = value_str.lower()
+        new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
     return new_template
 
 
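Note (not part of the diff): the patched cpp_format reproduced as a standalone snippet with a made-up usage, to show why the bool special case is needed: Python's str(True) is 'True', which is not a valid C++ boolean literal, so bools are lowercased to true/false before substitution.

import copy
from typing import Any, Dict

def cpp_format(template: str, keys: Dict[str, Any]) -> str:
    # We don't use `str.format` because it's not safe for C++ {} braces
    new_template = copy.deepcopy(template)
    for key, value in keys.items():
        value_str = str(value)
        if isinstance(value, bool):
            value_str = value_str.lower()
        new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
    return new_template

print(cpp_format('constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};',
                 {'IS_TMA_MULTICAST_ON_A': True}))
# prints: constexpr auto kIsTMAMulticastOnA = true;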
@@ -15,9 +15,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
 
 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -31,10 +32,10 @@ GemmType::run(out, rhs_scales, nullptr,
 """
 
 
-def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
+def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
     if num_tma_multicast == 1:
         return True
-    return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
+    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
 
 
 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
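Note (not part of the diff): with the renamed parameters, the same predicate now checks either dimension: N against BLOCK_N when multicasting A, or M against BLOCK_M when multicasting B. The values below are made up for illustration.

def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
    # Same predicate as above: a cluster of size 1 is always legal; otherwise the
    # checked dimension and the SM count must both split evenly across the cluster.
    if num_tma_multicast == 1:
        return True
    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0

print(is_tma_multicast_legal(4096, 128, 2, 132))  # A-side check along N: True
print(is_tma_multicast_legal(192, 128, 2, 132))   # B-side check along M: False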
@@ -56,7 +57,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
 
 
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
-                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int, int]:
+                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
         # TODO: for some cases, smaller M block is better, add them into tuning space
         block_ms = (64 if m <= 64 else 128, )
@@ -96,20 +97,27 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
             break
     assert best_num_stages is not None
 
-    # Decide the number of TMA multicast
-    best_num_tma_multicast = 1
-    # When using large block tiling, broadcasting B is required to achieve maximum performance gains.
-    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
-        best_num_tma_multicast = 2
+    # Decide the number of TMA multicast and whether broadcast on A
+    best_tma_multicast_config = (1, True)
+
+    # Try to multicast on the larger block side first
+    is_multicast_legal = {
+        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms),
+    }
+    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
+        if m >= 1024 and is_multicast_legal[i] and num_groups == 1:
+            best_tma_multicast_config = (2, i == 'A')
+            break
 
     # Recompute the minimal number of SMs required
     # NOTES: less L2 cache usage and less GPU frequency drop
     num_waves = get_num_waves(best_block_m, best_block_n)
     num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
-    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
-    assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
+    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0]
+    assert num_min_sms <= num_sms
 
-    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
+    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
 
 
 def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
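Note (not part of the diff): a standalone sketch of the new selection heuristic, using a hypothetical function name and made-up shapes. It prefers broadcasting the operand with the larger block tile, falls back to the other side, and leaves multicast disabled unless m >= 1024, the split is legal, and there is a single group.

def choose_tma_multicast(m, n, best_block_m, best_block_n, num_groups, num_sms):
    # Returns (num_tma_multicast, is_multicast_on_a), mirroring get_best_configs above
    def legal(shape_dim, block_dim, mc=2):
        return shape_dim % (block_dim * mc) == 0 and num_sms % mc == 0

    config = (1, True)
    is_multicast_legal = {'A': legal(n, best_block_n), 'B': legal(m, best_block_m)}
    for side in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        if m >= 1024 and is_multicast_legal[side] and num_groups == 1:
            config = (2, side == 'A')
            break
    return config

print(choose_tma_multicast(4096, 7168, 128, 128, 1, 132))  # (2, False): multicast on B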
@@ -159,12 +167,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms)
     args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),