support group_gemm_offset, group_gemm_offset_swapAB

This commit is contained in:
Wangzheee
2025-06-19 14:51:38 +00:00
parent 0c88cd0139
commit d29b20cd16
10 changed files with 1649 additions and 252 deletions

View File

@@ -1,7 +1,8 @@
from .gemm import gemm_fp8_fp8_bf16_nt
from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
m_grouped_gemm_fp8_fp8_bf16_nt_offset
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,

View File

@@ -34,42 +34,71 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
is_fp32_out: bool = False, is_wgrad: bool = False, is_swap_ab: bool = False) -> Tuple[int, int, int]:
assert block_k == 128
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
smem_barrier = num_stages * 8 * 2
if not is_swap_ab:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
smem_size = 0
smem_size += smem_d
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
smem_barrier = num_stages * 8 * 2
# Swizzle and padding are not compatible
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
smem_size = 0
smem_size += smem_d
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
return smem_size, swizzle_mode, block_n_padding
# Swizzle and padding are not compatible
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
return smem_size, swizzle_mode, block_n_padding
else:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales
smem_b_per_stage = block_n * block_k
smem_scales_b_per_stage = 0 # swap_ab not support wgrad
smem_scales_b = ceil_div(block_n * 4, 128) * 128 # swap_ab not support wgrad
smem_barrier = num_stages * 8 * 2
smem_size = 0
smem_size += smem_d
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_b
smem_size += num_stages * smem_b_per_stage
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_a_per_stage * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
return smem_size, swizzle_mode, block_n_padding
@lru_cache(maxsize=None)
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
is_fp32_out: bool = False, is_wgrad: bool = False, is_swap_ab: bool = False) -> \
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
if not is_grouped_contiguous:
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
@@ -119,7 +148,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad, is_swap_ab = is_swap_ab)
if best_smem_config[0] <= sm90_capacity:
best_num_stages = num_stages
break
@@ -131,21 +160,39 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Try to multicast on the larger block side first
# NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if m >= 512 and is_multicast_legal[i]:
best_tma_multicast_config = (2, i == 'A')
break
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms
if not is_swap_ab:
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if m >= 512 and is_multicast_legal[i]:
best_tma_multicast_config = (2, i == 'A')
break
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms
else:
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_m, 2, num_sms),
'B': is_tma_multicast_legal(m, best_block_n, 2, num_sms),
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if n >= 512 and is_multicast_legal[i]:
best_tma_multicast_config = (2, i == 'B')
break
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_n, best_block_m)
num_min_sms = ceil_div(ceil_div(n, best_block_m) * ceil_div(m, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config

View File

@@ -4,10 +4,12 @@ from typing import Tuple
from ..jit import build
from .gemm import get_best_configs
from .runtime import (
FP8GemmRuntime, GemmType,
FP8GemmRuntime, FP8GemmOffsetRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
make_2d_tma_d_desc, make_2d_tma_scales_desc,
make_2d_tma_scales_a_offset_desc,
make_2d_tma_a_offset_desc_swapAB, make_2d_tma_b_offset_desc_swapAB, make_2d_tma_d_offset_desc_swapAB, make_2d_tma_scales_b_offset_desc_swapAB)
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms, compute_padded_offset
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -203,3 +205,163 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
code = FP8GemmRuntime.generate(kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(**kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor],
                                          rhs: Tuple[torch.Tensor, torch.Tensor],
                                          offsets: torch.Tensor,
                                          out: torch.Tensor, expected_m: int) -> None:
    """
    Launch the `GroupedWithOffset` grouped FP8 GEMM (scheme taken from TensorRT-LLM).

    Two kernels are available; the choice depends on `expected_m`:
      - `expected_m >= threshold`: the regular `fp8_gemm_offset_kernel`;
      - otherwise: `fp8_gemm_offset_kernel_swapAB`, where `rhs` is fed as operand A
        and `lhs` as operand B.
    The threshold is 64 when `get_num_sms()` reports 78 SMs (H20) and 32 otherwise (H100).

    Arguments:
        lhs: pair `(tensor, scales)`; the tensor is `torch.float8_e4m3fn`, contiguous,
             of shape `[m, k]`; the scales are `torch.float32`.
        rhs: pair `(tensor, scales)`; the tensor is `torch.float8_e4m3fn`, contiguous,
             of shape `[num_groups, n, k]`; the scales are `torch.float32`, contiguous,
             of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
        offsets: tensor forwarded to the kernel as `PROBLEM_OFFSETS`
                 (presumably the per-group start rows inside `lhs` - confirm against the kernel).
        out: contiguous `torch.bfloat16` tensor of shape `[m, n]` receiving the result.
        expected_m: positive hint used for config selection and for picking the kernel
                    (presumably the expected number of rows per group - confirm with callers).

    Set the environment variable `DG_JIT_DEBUG` to a non-zero integer to print the
    input shapes and, for the swap-AB path, the selected configuration.
    """
    # Function-scope import: `os` is only needed for the debug switch
    import os

    is_debug = bool(int(os.getenv('DG_JIT_DEBUG', 0)))

    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    num_groups, n, k_ = rhs.shape
    m_, n_ = out.shape

    if is_debug:
        print("expected_m: ",expected_m)
        print("A shape: ",lhs.shape)
        print("A scale shape: ",lhs_scales.shape)
        print("B shape: ",rhs.shape)
        print("B scale shape: ",rhs_scales.shape)
        print("out shape: ",out.shape)

    # Type and shape checks
    assert m == m_ and n == n_ and k == k_
    max_shape_m_4_align = ceil_div(m, 4) * 4  # align 4
    max_shape_m_32_align_padded = compute_padded_offset(m, num_groups)
    assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0
    # NOTE(review): the shape of `lhs_scales` is not validated here; its row count presumably
    # follows `compute_padded_offset` rather than `m` - confirm and add an explicit check
    assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert lhs.is_contiguous() and rhs.is_contiguous()
    assert out.is_contiguous()

    # LHS scales must be transposed for TMA load, but not for RHS scales
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Auto-tuning with compilation
    num_sms = get_num_sms()
    m_per_expert_threshold = 64 if num_sms == 78 else 32  # 78 SMs: H20, otherwise: H100

    # Launch constants shared by both kernels
    block_k = 128
    num_tma_threads = 128
    num_math_threads_per_group = 128

    if expected_m >= m_per_expert_threshold:
        num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
            expected_m, n, k, num_groups, num_sms, is_grouped_contiguous=True, is_swap_ab=False)

        # Extra checks for TMA store
        if num_groups > 1 and m > block_m:
            assert m % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'

        kernel_name = 'fp8_gemm_offset_kernel'
        scheduler_type = 'SchedulerSelector'
        input_type = 'GroupedWithOffsetSchedulerInput'
        tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, m, k, k, block_m, block_k, num_groups)
        tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups)
        # No swizzle for the output and the scales
        tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, m, n, n, block_m, block_n, num_groups, 0)
        tensor_map_scales = make_2d_tma_scales_a_offset_desc(
            GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k)
    else:
        # Swap-AB path: the problem is tuned as `(n, expected_m)`, so `block_m` tiles `n` here
        num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
            n, expected_m, k, num_groups, num_sms, is_grouped_contiguous=True, is_swap_ab=True)

        # Extra checks for TMA store
        if num_groups > 1 and n > block_m:
            assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM with swapped A/B, shape N should be multiple of the block M (current block M: {block_m})'

        if is_debug:
            print("is_swap_ab=True =========")
            print("num_sms: ",num_sms)
            print("block_m: ",block_m)
            print("block_n: ",block_n)
            print("num_stages: ",num_stages)
            print("tma_multicast_config: ",tma_multicast_config)
            print("smem_config: ",smem_config)

        kernel_name = 'fp8_gemm_offset_kernel_swapAB'
        scheduler_type = 'SchedulerSelectorSwapAB'
        input_type = 'GroupedWithOffsetSchedulerInputSwapAB'
        tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups)
        tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, m, k, k, block_n, block_k, num_groups)
        # No swizzle for the output and the scales
        tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, n, m, m, block_m, block_n, num_groups, 0)
        tensor_map_scales = make_2d_tma_scales_b_offset_desc_swapAB(
            GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k)

    kwargs = {
        # Templated arguments
        'KERNEL_NAME': kernel_name,
        'SCHEDULER_TYPE': scheduler_type,
        'INPUT_TYPE': input_type,
        'PROBLEM_OFFSETS': offsets,
        'NUM_TMA_THREADS': num_tma_threads,
        'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
        'M': m, 'N': n, 'K': k,
        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
        'NUM_GROUPS': num_groups,
        'NUM_STAGES': num_stages,
        'NUM_TMA_MULTICAST': tma_multicast_config[0],
        'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
        'GEMM_TYPE': GemmType.GroupedWithOffset,
        # Runtime arguments
        'SCALES': rhs_scales,
        'NUM_SMS': num_sms,
        'SMEM_SIZE': smem_config[0],
        'TENSOR_MAP_A': tensor_map_a,
        'TENSOR_MAP_B': tensor_map_b,
        'TENSOR_MAP_SCALES': tensor_map_scales,
        'TENSOR_MAP_D': tensor_map_d,
        'STREAM': torch.cuda.current_stream().cuda_stream,
        'DEVICE_INDEX': out.device.index,
        'OUT': out
    }

    # Generate, build and run the kernel
    code = FP8GemmOffsetRuntime.generate(kwargs)
    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt_offset', code, FP8GemmOffsetRuntime, kwargs)
    runtime(**kwargs)

View File

@@ -5,7 +5,7 @@ import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict, Tuple
from .utils import get_tma_aligned_size
from .utils import get_tma_aligned_size, ceil_div
from ..jit.runtime import Runtime
@@ -13,12 +13,15 @@ class GemmType(enum.Enum):
Normal = 0
GroupedContiguous = 1
GroupedMasked = 2
GroupedWithOffset = 3
def __str__(self) -> str:
    """Return the label for this member's integer value, looked up in a fixed table."""
    labels_by_value = {
        0: 'Normal',
        1: 'GroupedContiguous',
        2: 'GroupedMasked',
        3: 'GroupedWithOffset',
    }
    return labels_by_value[self.value]
@@ -133,6 +136,58 @@ def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def make_2d_tma_scales_a_offset_desc(gemm_type: GemmType, t: torch.Tensor,
                                     max_m_padded_total: int, shape_k: int,
                                     block_m: int, block_k: int,
                                     global_stride_in_bytes: int = 0) -> cbd.CUtensorMap:
    """
    Build the non-swizzled 2D TMA descriptor for the A-side scales of the offset GEMM.

    `max_m_padded_total` is forwarded to `make_2d_tma_desc` as both the second and the
    fourth positional argument, `⌈shape_k / block_k⌉` as the third, and the box is
    `block_m` by 1. `gemm_type` and `global_stride_in_bytes` are accepted but not read.
    NOTE(review): the meaning of each positional slot is defined by `make_2d_tma_desc`,
    which is not visible here - confirm before changing the order.
    """
    num_k_blocks = ceil_div(shape_k, block_k)
    no_swizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE
    return make_2d_tma_desc(t,
                            max_m_padded_total, num_k_blocks, max_m_padded_total,
                            block_m, 1,
                            no_swizzle)
def make_2d_tma_a_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
                                     shape_m: int, shape_k: int, m_stride: int,
                                     block_m: int, block_k: int,
                                     num_groups: int) -> cbd.CUtensorMap:
    """
    Build the 2D TMA descriptor for operand A of the swap-AB offset kernel.

    The row extent passed on is `shape_m`, multiplied by `num_groups` for every GEMM
    type except `GemmType.Normal`; the box is `block_k` by `block_m`. The swizzle
    argument of `make_2d_tma_desc` is left at its default.
    """
    # Every non-normal GEMM type stacks all groups along the row dimension
    group_factor = 1 if gemm_type == GemmType.Normal else num_groups
    total_rows = shape_m * group_factor
    return make_2d_tma_desc(t,
                            shape_k, total_rows, m_stride,
                            block_k, block_m)
def make_2d_tma_b_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
                                     shape_n: int, shape_k: int, n_stride: int,
                                     block_n: int, block_k: int,
                                     num_groups: int) -> cbd.CUtensorMap:
    """
    Build the 2D TMA descriptor for operand B of the swap-AB offset kernel.

    The row extent passed on is `shape_n`, multiplied by `num_groups` only for
    `GemmType.GroupedMasked`; the box is `block_k` by `block_n`. The swizzle argument
    of `make_2d_tma_desc` is left at its default.
    """
    # Only the masked layout stacks the groups along the row dimension
    group_factor = num_groups if gemm_type == GemmType.GroupedMasked else 1
    total_rows = shape_n * group_factor
    return make_2d_tma_desc(t,
                            shape_k, total_rows, n_stride,
                            block_k, block_n)
def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
                                     shape_m: int, shape_n: int, m_stride: int,
                                     block_m: int, block_n: int,
                                     num_groups: int,
                                     swizzle_mode: int) -> cbd.CUtensorMap:
    """
    Build the non-swizzled 2D TMA descriptor for the output of the swap-AB offset kernel.

    The row extent passed on is `shape_m`, multiplied by `num_groups` for every GEMM
    type except `GemmType.Normal`. Both box dimensions are clamped to the matching
    tensor extent. `swizzle_mode` is accepted but not read: the descriptor always
    uses `CU_TENSOR_MAP_SWIZZLE_NONE`.
    """
    group_factor = 1 if gemm_type == GemmType.Normal else num_groups
    total_rows = shape_m * group_factor
    # The box must not exceed the tensor extent in either dimension
    box_inner = min(block_n, shape_n)
    box_outer = min(block_m, shape_m)
    return make_2d_tma_desc(t,
                            shape_n, total_rows, m_stride,
                            box_inner, box_outer,
                            cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def make_2d_tma_scales_b_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
                                            max_n_padded_total: int, shape_k: int,
                                            block_n: int, block_k: int,
                                            global_stride_in_bytes: int = 0) -> cbd.CUtensorMap:
    """
    Build the non-swizzled 2D TMA descriptor for the B-side scales of the swap-AB offset kernel.

    `max_n_padded_total` is forwarded to `make_2d_tma_desc` as both the second and the
    fourth positional argument, `⌈shape_k / block_k⌉` as the third, and the box is
    `block_n` by 1. `gemm_type` and `global_stride_in_bytes` are accepted but not read.
    NOTE(review): the meaning of each positional slot is defined by `make_2d_tma_desc`,
    which is not visible here - confirm before changing the order.
    """
    num_k_blocks = ceil_div(shape_k, block_k)
    no_swizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE
    return make_2d_tma_desc(t,
                            max_n_padded_total, num_k_blocks, max_n_padded_total,
                            block_n, 1,
                            no_swizzle)
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path)
@@ -316,3 +371,101 @@ static void __instantiate_kernel() {{
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
class FP8GemmOffsetRuntime(Runtime):
    """
    JIT runtime for the `GroupedWithOffset` FP8 GEMM kernels.

    `generate` emits the translation unit that instantiates the kernel template, and
    `launch` configures and enqueues the compiled kernel. Both read the same `kwargs`
    dictionary.
    """

    def __init__(self, path: str) -> None:
        super().__init__(path)

    @staticmethod
    def generate(kwargs: Dict[str, Any]) -> str:
        """
        Return the source that instantiates `kwargs['KERNEL_NAME']`.

        The scheduler is `kwargs['SCHEDULER_TYPE']` specialized on
        `GemmType::GroupedWithOffset`. The remaining template arguments come from the
        `N`, `K`, `BLOCK_*`, `NUM_*` and `INPUT_TYPE` entries. The code is also printed
        when the environment variable `DG_JIT_DEBUG` is a non-zero integer.
        """
        code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
using SchedulerType =
typename {kwargs['SCHEDULER_TYPE']} <GemmType::GroupedWithOffset, {kwargs['N']},
{kwargs['K']}, {kwargs['BLOCK_M']}, {kwargs['BLOCK_N']},
{kwargs['BLOCK_K']}, {kwargs['NUM_GROUPS']}, {kwargs['NUM_TMA_MULTICAST']}>::type;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&{kwargs['KERNEL_NAME']}<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
SchedulerType,
{kwargs['INPUT_TYPE']}
>);
}};
'''
        if int(os.getenv('DG_JIT_DEBUG', 0)):
            print(f'Generated FP8 GEMM code:\n{code}')
        return code

    # noinspection PyMethodOverriding
    @staticmethod
    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
        """
        Enqueue `kernel` on `kwargs['STREAM']` and return the driver result code.

        The grid is `kwargs['NUM_SMS']` blocks in X with a cluster of
        `kwargs['NUM_TMA_MULTICAST']` blocks. The dynamic shared memory size is
        `kwargs['SMEM_SIZE']`; it is first raised via `cuKernelSetAttribute`, and
        failure to raise it triggers an assertion.
        """
        # Use the same thread counts the kernel template was instantiated with in
        # `generate`; fall back to 128, the value that used to be hard-coded here
        num_tma_threads = kwargs.get('NUM_TMA_THREADS', 128)
        num_math_threads_per_group = kwargs.get('NUM_MATH_THREADS_PER_GROUP', 128)

        result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                                          kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
        assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'

        # Cluster dimension is the only launch attribute
        attr_val = cbd.CUlaunchAttributeValue()
        attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
        attr_val.clusterDim.y = 1
        attr_val.clusterDim.z = 1
        attr = cbd.CUlaunchAttribute()
        attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
        attr.value = attr_val

        config = cbd.CUlaunchConfig()
        config.numAttrs = 1
        config.attrs = [attr]
        config.gridDimX = kwargs['NUM_SMS']
        config.gridDimY = 1
        config.gridDimZ = 1
        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
        config.blockDimY = 1
        config.blockDimZ = 1
        config.sharedMemBytes = kwargs['SMEM_SIZE']
        config.hStream = kwargs['STREAM']

        # Three raw device pointers followed by the four tensor maps; the order of
        # `arg_types` must stay in step with `arg_values`
        arg_values = (
            kwargs['OUT'].data_ptr(),
            kwargs['SCALES'].data_ptr(),
            kwargs['PROBLEM_OFFSETS'].data_ptr(),
            kwargs['TENSOR_MAP_A'],
            kwargs['TENSOR_MAP_B'],
            kwargs['TENSOR_MAP_SCALES'],
            kwargs['TENSOR_MAP_D'],
        )
        arg_types = (
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            None,
            None,
            None,
            None,
        )
        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)

View File

@@ -107,3 +107,6 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
def compute_padded_offset(offset, idx_problem, alignment=32):
    """
    Add `idx_problem * (alignment - 1)` to `offset` and round the sum down to a multiple of `alignment`.

    NOTE(review): presumably the start row of problem `idx_problem` once every preceding
    problem has been padded up to `alignment` rows - confirm against the kernel-side layout.
    """
    worst_case_padding = idx_problem * (alignment - 1)
    aligned_units = (offset + worst_case_padding) // alignment
    return aligned_units * alignment