Some lints and refactor

This commit is contained in:
Chenggang Zhao
2025-05-06 17:23:35 +08:00
parent 8aff6309d4
commit 981cc58932
18 changed files with 421 additions and 449 deletions

View File

@@ -1,3 +1,2 @@
from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
from .template import generate
from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler
from .runtime import Runtime

View File

@@ -1,4 +1,3 @@
import abc
import functools
import hashlib
import os
@@ -14,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache
from .runtime import Runtime, RuntimeCache
runtime_cache = RuntimeCache()
@@ -32,11 +31,11 @@ def get_jit_include_dir() -> str:
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
md5 = hashlib.md5()
# Update include directories
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(os.path.join(include_dir, filename), 'rb') as f:
md5.update(f.read())
@@ -98,24 +97,20 @@ def make_tmp_dir():
def put(path, data):
is_binary = isinstance(data, bytes)
# Write and do POSIX atomic replace
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
f.write(data)
os.replace(tmp_file_path, path)
class Compiler(abc.ABC):
class Compiler:
@staticmethod
@abc.abstractmethod
def __version__() -> Tuple[int, int]:
pass
@classmethod
@abc.abstractmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str) -> None:
pass
@staticmethod
@@ -132,13 +127,12 @@ class Compiler(abc.ABC):
return [get_jit_include_dir()]
@classmethod
def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
# Compiler flags
flags = cls.flags()
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = os.path.join(get_cache_dir(), name)
@@ -147,7 +141,7 @@ class Compiler(abc.ABC):
global runtime_cache
cached_runtime = runtime_cache.get(path, runtime_cls)
if cached_runtime is not None:
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT runtime {name} during build')
return cached_runtime
@@ -160,9 +154,8 @@ class Compiler(abc.ABC):
cls.compile(name, code, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
@@ -177,12 +170,12 @@ class Compiler(abc.ABC):
return runtime
class NvccCompiler(Compiler):
class NVCCCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
return (major, minor)
return major, minor
@classmethod
def flags(cls) -> List[str]:
@@ -197,7 +190,7 @@ class NvccCompiler(Compiler):
f'--compiler-options={",".join(cxx_flags)}']
@classmethod
def compile(cls, name: str, code: str, target_path: str):
def compile(cls, name: str, code: str, target_path: str) -> None:
# Write the code
path = os.path.join(get_cache_dir(), name)
src_path = os.path.join(path, 'kernel.cu')
@@ -205,26 +198,23 @@ class NvccCompiler(Compiler):
command = [get_nvcc_compiler()[0],
src_path, '-o', target_path,
*cls.flags()]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
print(f'Compiling JIT runtime {name} with command {command}')
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
if os.getenv('DG_JIT_DEBUG', None):
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f'Failed to compile {src_path}'
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
assert False, f'Failed to compile {src_path}'
class NvrtcCompiler(Compiler):
class NVRTCCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
res, major, minor = nvrtc.nvrtcVersion()
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Failed to get actual NVRTC version, use bindings version instead
# Failed to get the actual NVRTC version, use cuda-bindings version instead
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return (major, minor)
return major, minor
@staticmethod
def include_dirs() -> List[str]:
@@ -238,54 +228,51 @@ class NvrtcCompiler(Compiler):
'--gpu-architecture=sm_90a', '-default-device']
if cls.__version__() >= (12, 8):
base_flags += ['--pch']
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
base_flags += ['--pch-verbose=true']
return base_flags
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> str:
def compile(cls, name: str, code: str, target_path: str) -> None:
# Create program
code_bytes = bytes(code, 'utf-8')
res, program = nvrtc.nvrtcCreateProgram(
result, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to create program: {res}')
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'
# Compile
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
compile_res = nvrtc.nvrtcCompileProgram(
program, len(options), options)[0]
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
print(f'Compiling JIT runtime {name} with options: {options}')
compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]
# Print compiler log
if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
if os.getenv('DG_JIT_DEBUG', None):
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get program log size: {res}')
log_bytes = bytes(log_size)
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get program log: {res}')
log_str = log_bytes.decode('utf-8')
print(log_str)
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
print(f'Compiler log: {log_bytes.decode("utf-8")}')
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to compile program: {compile_res}')
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get CUBIN size: {res}')
# Exit if failed
assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'
# Create CUBIN
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
cubin_bytes = bytes(cubin_size)
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to get CUBIN: {res}')
result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'
# Write into the file system
put(target_path, cubin_bytes)
res = nvrtc.nvrtcDestroyProgram(program)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f'Failed to destroy program: {res}')
# Destroy handler
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls)
else:
return NvccCompiler.build(name, code, runtime_cls=runtime_cls)
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
return compiler_cls.build(name, code, runtime_cls=runtime_cls)

View File

@@ -37,7 +37,7 @@ def extract_ffma(sass):
collected.append((f'{arch_name}::{func_name}', current))
current = []
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
print(f'Found {len(collected)} FFMA segments')
return collected
@@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines):
dst_reg_set.add(dst_reg)
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
last_reused, last_dst_reg = reused, dst_reg
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
# Find the offset
@@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines):
def process(path):
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
print(f'Processing {path}')
output = run_cuobjdump(path)
segments = extract_ffma(output)

View File

@@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cuda
from .utils import run_gemm
class Runtime:
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
def __init__(self, path: str, kernel_name: str = None,
caller: Callable[..., cuda.CUresult] = None,
args: List[str] = None) -> None:
self.path = path
self.lib = None
self.kernel = None
@@ -27,7 +27,7 @@ class Runtime:
files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
def __call__(self, **kwargs) -> cuda.CUresult:
# Load CUBIN
if self.kernel is None:
start_time = time.time_ns()
@@ -59,9 +59,8 @@ class Runtime:
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return self.caller(
self.kernel,
@@ -75,25 +74,6 @@ class Runtime:
raise Exception(f'Failed to unload library {self.path}: {res}')
class Fp8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'fp8_gemm', run_gemm, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
class RuntimeCache:
def __init__(self) -> None:
self.cache = {}
@@ -101,14 +81,14 @@ class RuntimeCache:
def __setitem__(self, path, runtime) -> None:
self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
if not int(os.getenv('DG_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
runtime = runtime_cls(path)
self.cache[path] = runtime
return runtime
return None
return None

View File

@@ -1,51 +0,0 @@
import os
from typing import Any, Dict
def generate(**kwargs: Dict[str, Any]) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <string>
#include <cuda.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
'''
# Debug print
if os.getenv('DG_JIT_DEBUG', None):
print(f'Generated code:\n{code}')
return code

View File

@@ -1,164 +0,0 @@
import ctypes
from enum import Enum
from typing import Any, Dict, Tuple
import cuda.bindings.driver as cuda
import torch
class Layout(Enum):
RowMajor = 0
ColMajor = 1
class GemmType(Enum):
Normal = 0
GroupedContiguous = 1
GroupedMasked = 2
def __str__(self) -> str:
return {
0: 'Normal',
1: 'GroupedContiguous',
2: 'GroupedMasked',
}[self.value]
typename_map: Dict[Any, str] = {
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
swizzle_map = {
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
}
def get_num_math_warpgroups(block_m: int) -> int:
return 1 if block_m == 64 else 2
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
tensor_dtype = typename_map[global_address.dtype]
res, tensor_map = cuda.cuTensorMapEncodeTiled(
tensor_dtype,
2, # tensor rank
global_address.data_ptr(),
gmem_dim,
(stride_in_bytes,), # global strides
smem_dim,
(cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type,
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to encode tensor map: {res}')
return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
# bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Make TMA aligned to 16 bytes
kAlignment = 16 / global_address.element_size()
shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
attr_val = cuda.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cuda.CUlaunchAttribute()
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
config = cuda.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
kernelValues = (
gmem_d.data_ptr(),
scales_b.data_ptr(),
grouped_layout.data_ptr(),
shape_m,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_d,
)
kernelTypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
None,
None,
None,
None,
)
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)