mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-29 16:50:33 +00:00
fix cuda_graph rng check error
This commit is contained in:
parent
6c5da03ba9
commit
723a00338e
@ -1,11 +1,10 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Any, Iterable, Dict, Tuple
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
@ -37,12 +36,31 @@ genc_map = {
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
||||
if isinstance(value, torch.Tensor):
|
||||
return ctype(value.data_ptr())
|
||||
if isinstance(value, torch.cuda.Stream):
|
||||
return ctype(value.cuda_stream)
|
||||
return ctype(value)
|
||||
if hasattr(value, 'data_ptr'):
|
||||
if value.dtype == torch.int:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.float:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.bfloat16:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.float16:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.float8_e4m3fn:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
else:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
|
||||
if hasattr(value, 'cuda_stream'):
|
||||
return ctypes.c_void_p(value.cuda_stream)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return ctypes.c_bool(value)
|
||||
elif isinstance(value, int):
|
||||
return ctypes.c_int(value)
|
||||
elif isinstance(value, float):
|
||||
return ctypes.c_float(value)
|
||||
|
||||
return ctype_map[type(value)](value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user