diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 012f4da..50d8d42 100644
--- a/deep_gemm/jit/runtime.py
+++ b/deep_gemm/jit/runtime.py
@@ -1,19 +1,18 @@
 import os
 import time
-from typing import Any, Callable, Dict, List, Optional, Type
+import cuda.bindings.driver as cbd
+from typing import List, Optional, Type
 
 import cuda.bindings.driver as cuda
 
 
 class Runtime:
     def __init__(self, path: str, kernel_name: str = None,
-                 caller: Callable[..., cuda.CUresult] = None,
                  args: List[str] = None) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
-        self.caller = caller
         self.args = args
 
         assert self.is_path_valid(self.path)
@@ -27,6 +26,14 @@ class Runtime:
         files = ['kernel.cubin']
         return all(os.path.exists(os.path.join(path, file)) for file in files)
 
+    @staticmethod
+    def generate(**kwargs) -> str:
+        raise NotImplementedError
+
+    @staticmethod
+    def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
+        raise NotImplementedError
+
     def __call__(self, **kwargs) -> cuda.CUresult:
         # Load CUBIN
         if self.kernel is None:
@@ -62,7 +69,8 @@ class Runtime:
         if int(os.getenv('DG_JIT_DEBUG', 0)):
             print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
 
-        return self.caller(
+        # noinspection PyArgumentList
+        return self.launch(
             self.kernel,
             *[kwargs[arg] for arg in self.args]
         )
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 825f313..fa15418 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -3,8 +3,10 @@ import torch
 from functools import lru_cache
 from typing import Tuple
 
-from .runtime import FP8GemmRuntime, generate
-from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
+from .runtime import (
+    FP8GemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
 from .tuner import jit_tuner
 from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
 
@@ -238,7 +240,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
 
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 3d12ff6..24a2183 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -2,8 +2,10 @@ import torch
 from typing import Tuple
 
 from .gemm import get_best_configs
-from .runtime import FP8GemmRuntime, generate
-from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
+from .runtime import (
+    FP8GemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
 from .tuner import jit_tuner
 from .utils import get_col_major_tma_aligned_tensor, get_num_sms
 
@@ -103,7 +105,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
               'GEMM_TYPE': GemmType.GroupedContiguous},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
 
@@ -209,7 +210,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
               'GEMM_TYPE': GemmType.GroupedMasked},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
 
diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py
index 225b27a..8860847 100644
--- a/deep_gemm/jit_kernels/runtime.py
+++ b/deep_gemm/jit_kernels/runtime.py
@@ -8,66 +8,6 @@ from typing import Any, Dict, Tuple
 from ..jit.runtime import Runtime
 
 
-def generate(**kwargs: Dict[str, Any]) -> str:
-    code = f'''
-#ifdef __CUDACC_RTC__
-#include
-#else
-#include
-#include
-#endif
-
-#include
-#include
-
-#include
-
-using namespace deep_gemm;
-
-__global__ void dummy_kernel() {{
-    void *ptr = (void *)&fp8_gemm_kernel<
-        {kwargs['N']},
-        {kwargs['K']},
-        {kwargs['BLOCK_M']},
-        {kwargs['BLOCK_N']},
-        {kwargs['BLOCK_K']},
-        {kwargs['BLOCK_N_PADDING']},
-        {kwargs['SWIZZLE_D_MODE']},
-        {kwargs['NUM_GROUPS']},
-        {kwargs['NUM_STAGES']},
-        {kwargs['NUM_TMA_THREADS']},
-        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
-        {kwargs['NUM_TMA_MULTICAST']},
-        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
-        GemmType::{kwargs['GEMM_TYPE']}
-        >;
-}}
-'''
-
-    if int(os.getenv('DG_JIT_DEBUG', 0)):
-        print(f'Generated code:\n{code}')
-    return code
-
-
-class FP8GemmRuntime(Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path, 'fp8_gemm', launch, [
-            'NUM_TMA_MULTICAST',
-            'M',
-            'BLOCK_M',
-            'GMEM_D',
-            'SCALES_B',
-            'GROUPED_LAYOUT',
-            'NUM_SMS',
-            'SMEM_SIZE',
-            'TENSOR_MAP_A',
-            'TENSOR_MAP_B',
-            'TENSOR_MAP_SCALES_A',
-            'TENSOR_MAP_D',
-            'STREAM',
-        ])
-
-
 class Layout(enum.Enum):
     RowMajor = 0
     ColMajor = 1
@@ -200,57 +140,117 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
                             block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
 
 
-def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
-           block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
-           grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
-           tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
-           tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
-           stream: cbd.CUstream) -> cbd.CUresult:
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
+class FP8GemmRuntime(Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'fp8_gemm', [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
 
-    res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
-    if res != cbd.CUresult.CUDA_SUCCESS:
-        raise Exception(f'Failed to set max dynamic shared memory size: {res}')
+    @staticmethod
+    def generate(**kwargs) -> str:
+        code = f'''
+#ifdef __CUDACC_RTC__
+#include
+#else
+#include
+#include
+#endif
 
-    attr_val = cbd.CUlaunchAttributeValue()
-    attr_val.clusterDim.x = num_tma_multicast
-    attr_val.clusterDim.y = 1
-    attr_val.clusterDim.z = 1
-    attr = cbd.CUlaunchAttribute()
-    attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
-    attr.value = attr_val
+#include
+#include
 
-    config = cbd.CUlaunchConfig()
-    config.numAttrs = 1
-    config.attrs = [attr]
-    config.gridDimX = num_sms
-    config.gridDimY = 1
-    config.gridDimZ = 1
-    config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
-    config.blockDimY = 1
-    config.blockDimZ = 1
-    config.sharedMemBytes = smem_size
-    config.hStream = stream
+#include
 
-    arg_values = (
-        gmem_d.data_ptr(),
-        scales_b.data_ptr(),
-        grouped_layout.data_ptr(),
-        shape_m,
-        tensor_map_a,
-        tensor_map_b,
-        tensor_map_scales_a,
-        tensor_map_d,
-    )
-    arg_types = (
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_uint32,
-        None,
-        None,
-        None,
-        None,
-    )
-    return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
+using namespace deep_gemm;
+
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&fp8_gemm_kernel<
+        {kwargs['N']},
+        {kwargs['K']},
+        {kwargs['BLOCK_M']},
+        {kwargs['BLOCK_N']},
+        {kwargs['BLOCK_K']},
+        {kwargs['BLOCK_N_PADDING']},
+        {kwargs['SWIZZLE_D_MODE']},
+        {kwargs['NUM_GROUPS']},
+        {kwargs['NUM_STAGES']},
+        {kwargs['NUM_TMA_THREADS']},
+        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
+        {kwargs['NUM_TMA_MULTICAST']},
+        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
+        GemmType::{kwargs['GEMM_TYPE']}
+        >;
+}}
+'''
+        if int(os.getenv('DG_JIT_DEBUG', 0)):
+            print(f'Generated FP8 GEMM code:\n{code}')
+        return code
+
+    # noinspection PyMethodOverriding
+    @staticmethod
+    def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
+               block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
+               grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
+               tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
+               tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
+               stream: cbd.CUstream) -> cbd.CUresult:
+        num_tma_threads = 128
+        num_math_threads_per_group = 128
+
+        res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
+        if res != cbd.CUresult.CUDA_SUCCESS:
+            raise Exception(f'Failed to set max dynamic shared memory size: {res}')
+
+        attr_val = cbd.CUlaunchAttributeValue()
+        attr_val.clusterDim.x = num_tma_multicast
+        attr_val.clusterDim.y = 1
+        attr_val.clusterDim.z = 1
+        attr = cbd.CUlaunchAttribute()
+        attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
+        attr.value = attr_val
+
+        config = cbd.CUlaunchConfig()
+        config.numAttrs = 1
+        config.attrs = [attr]
+        config.gridDimX = num_sms
+        config.gridDimY = 1
+        config.gridDimZ = 1
+        config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
+        config.blockDimY = 1
+        config.blockDimZ = 1
+        config.sharedMemBytes = smem_size
+        config.hStream = stream
+
+        arg_values = (
+            gmem_d.data_ptr(),
+            scales_b.data_ptr(),
+            grouped_layout.data_ptr(),
+            shape_m,
+            tensor_map_a,
+            tensor_map_b,
+            tensor_map_scales_a,
+            tensor_map_d,
+        )
+        arg_types = (
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_uint32,
+            None,
+            None,
+            None,
+            None,
+        )
+        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py
index 18245f8..9a8b6f2 100644
--- a/deep_gemm/jit_kernels/tuner.py
+++ b/deep_gemm/jit_kernels/tuner.py
@@ -12,7 +12,7 @@ class JITTuner:
         self.tuned = {}
 
     def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
-                         kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
+                         kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
         # NOTES: we always assume the space, template and GPU devices will not change
         # NOTES: the function must have no accumulated side effects
         keys = {k: keys[k] for k in sorted(keys.keys())}
@@ -34,7 +34,7 @@ class JITTuner:
             assert isinstance(tuned_keys, dict)
             full_keys = copy.deepcopy(keys)
             full_keys.update(tuned_keys)
-            code = generator(**kwargs, **full_keys)
+            code = runtime_cls.generate(**kwargs, **full_keys)
             kernels.append((build(name, code, runtime_cls), full_keys))
 
         # TODO: fix tuning with space > 1
diff --git a/tests/test_jit.py b/tests/test_jit.py
index 66c4fcf..e6bad9f 100644
--- a/tests/test_jit.py
+++ b/tests/test_jit.py
@@ -1,7 +1,7 @@
 import ctypes
 import os
 import torch
-import cuda.bindings.driver as cuda
+import cuda.bindings.driver as cbd
 
 from deep_gemm import jit
 
@@ -10,43 +10,18 @@ os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
 os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
-# noinspection PyShadowingNames
-def launch_vector_add(kernel: cuda.CUkernel,
-                      a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
-                      stream: cuda.CUstream) -> cuda.CUresult:
-    assert a.shape == b.shape == c.shape
-    assert a.device == b.device == c.device
-    assert a.dim() == 1
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'vector_add', [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])
 
-    n = a.numel()
-
-    config = cuda.CUlaunchConfig()
-    config.gridDimX = (n + 127) // 128
-    config.gridDimY = 1
-    config.gridDimZ = 1
-    config.blockDimX = 128
-    config.blockDimY = 1
-    config.blockDimZ = 1
-    config.hStream = stream
-
-    arg_values = (
-        a.data_ptr(),
-        b.data_ptr(),
-        c.data_ptr(),
-        n,
-    )
-    arg_types = (
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_uint32,
-    )
-
-    return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
-
-
-def generate_vector_add(**kwargs) -> str:
-    return f"""
+    @staticmethod
+    def generate(**kwargs) -> str:
+        return f"""
 #ifdef __CUDACC_RTC__
 #include
 #else
 #include
 #endif
@@ -69,20 +44,43 @@ __global__ void dummy_kernel() {{
 }}
 """
 
+    # noinspection PyShadowingNames,PyMethodOverriding
+    @staticmethod
+    def launch(kernel: cbd.CUkernel,
+               a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
+               stream: cbd.CUstream) -> cbd.CUresult:
+        assert a.shape == b.shape == c.shape
+        assert a.device == b.device == c.device
+        assert a.dim() == 1
 
-class VectorAddRuntime(jit.Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path, 'vector_add', launch_vector_add, [
-            'A',
-            'B',
-            'C',
-            'STREAM',
-        ])
+        config = cbd.CUlaunchConfig()
+        config.gridDimX = (a.numel() + 127) // 128
+        config.gridDimY = 1
+        config.gridDimZ = 1
+        config.blockDimX = 128
+        config.blockDimY = 1
+        config.blockDimZ = 1
+        config.hStream = stream
+
+        arg_values = (
+            a.data_ptr(),
+            b.data_ptr(),
+            c.data_ptr(),
+            a.numel(),
+        )
+        arg_types = (
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_uint32,
+        )
+
+        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
 
 
 if __name__ == '__main__':
     print('Generated code:')
-    code = generate_vector_add(T='float')
+    code = VectorAddRuntime.generate(T='float')
     print(code)
     print()
 
@@ -100,6 +98,6 @@ if __name__ == '__main__':
     b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
     c = torch.empty_like(a)
     ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
-    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    assert ret == cbd.CUresult.CUDA_SUCCESS, ret
     torch.testing.assert_close(c, a + b)
     print(f'JIT test for {compiler_name} passed\n')
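
Reviewer note (not part of the patch): after this change, a JIT'd kernel is described entirely by a Runtime subclass. The static `generate` replaces the free generator function previously passed to the tuner via `generator=`, and the static `launch` replaces the `caller` callable previously passed to `Runtime.__init__`. The sketch below illustrates the new shape of such a subclass; the class name `MyKernelRuntime`, the kernel name `my_kernel`, and its argument list are hypothetical, and the `launch` body simply mirrors `VectorAddRuntime.launch` from `tests/test_jit.py`.

import ctypes

import torch
import cuda.bindings.driver as cbd

from deep_gemm import jit


class MyKernelRuntime(jit.Runtime):  # hypothetical subclass, for illustration only
    def __init__(self, path: str) -> None:
        # `caller` is gone from __init__; `args` names the keywords that
        # Runtime.__call__ forwards, in order, to the static `launch` below.
        super().__init__(path, 'my_kernel', ['X', 'STREAM'])

    @staticmethod
    def generate(**kwargs) -> str:
        # Formerly a free generator function handed to the tuner as `generator=`;
        # the tuner now calls `runtime_cls.generate(**kwargs, **full_keys)`.
        return '/* CUDA source derived from kwargs would go here */'

    # noinspection PyMethodOverriding
    @staticmethod
    def launch(kernel: cbd.CUkernel, x: torch.Tensor, stream: cbd.CUstream) -> cbd.CUresult:
        # Formerly the `caller` callable; now an overridden static method.
        config = cbd.CUlaunchConfig()
        config.gridDimX = (x.numel() + 127) // 128
        config.gridDimY = 1
        config.gridDimZ = 1
        config.blockDimX = 128
        config.blockDimY = 1
        config.blockDimZ = 1
        config.hStream = stream
        arg_values = (x.data_ptr(), x.numel())
        arg_types = (ctypes.c_void_p, ctypes.c_uint32)
        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]

With this shape, `JITTuner.compile_and_tune(..., runtime_cls=MyKernelRuntime)` obtains the source via `runtime_cls.generate(...)`, so the extra `generator=` argument is no longer needed at any call site.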