diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
index b917dec..c983817 100644
--- a/deep_gemm/jit/template.py
+++ b/deep_gemm/jit/template.py
@@ -1,11 +1,10 @@
 import copy
 import ctypes
 import os
+from typing import Any, Dict, Iterable, Tuple
+
 import torch
 
-from typing import Any, Iterable, Dict, Tuple
-
-
 # Name map for Python `eval`
 typename_map: Dict[Any, str] = {
     **{t: t.__name__ for t in (bool, int, float)},
@@ -37,12 +36,29 @@ genc_map = {
 
 
 def map_ctype(value: Any) -> Any:
-    ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
-    if isinstance(value, torch.Tensor):
-        return ctype(value.data_ptr())
-    if isinstance(value, torch.cuda.Stream):
-        return ctype(value.cuda_stream)
-    return ctype(value)
+    """Wrap a Python value in the matching `ctypes` object.
+
+    Objects exposing `data_ptr()` (e.g. `torch.Tensor`) or `cuda_stream` (e.g.
+    `torch.cuda.Stream`) become a `c_void_p` holding that address/handle.
+    `bool`, `int` and `float` map to `c_bool`, `c_int` and `c_float`; any other
+    type is looked up in `ctype_map`.
+    """
+    # All tensor dtypes share the same representation here: a raw pointer
+    if hasattr(value, 'data_ptr'):
+        return ctypes.c_void_p(value.data_ptr())
+
+    if hasattr(value, 'cuda_stream'):
+        return ctypes.c_void_p(value.cuda_stream)
+
+    # `bool` is a subclass of `int`, so it has to be tested first
+    if isinstance(value, bool):
+        return ctypes.c_bool(value)
+    if isinstance(value, int):
+        return ctypes.c_int(value)
+    if isinstance(value, float):
+        return ctypes.c_float(value)
+
+    return ctype_map[type(value)](value)
 
 
 def cpp_format(template: str, keys: Dict[str, Any]) -> str: