mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-30 04:43:21 +00:00
fix cuda_graph rng check error
This commit is contained in:
parent
6c5da03ba9
commit
723a00338e
@ -1,11 +1,10 @@
|
|||||||
import copy
|
import copy
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict, Iterable, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Any, Iterable, Dict, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
# Name map for Python `eval`
|
# Name map for Python `eval`
|
||||||
typename_map: Dict[Any, str] = {
|
typename_map: Dict[Any, str] = {
|
||||||
**{t: t.__name__ for t in (bool, int, float)},
|
**{t: t.__name__ for t in (bool, int, float)},
|
||||||
@ -37,12 +36,31 @@ genc_map = {
|
|||||||
|
|
||||||
|
|
||||||
def map_ctype(value: Any) -> Any:
|
def map_ctype(value: Any) -> Any:
|
||||||
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
if hasattr(value, 'data_ptr'):
|
||||||
if isinstance(value, torch.Tensor):
|
if value.dtype == torch.int:
|
||||||
return ctype(value.data_ptr())
|
return ctypes.c_void_p(value.data_ptr())
|
||||||
if isinstance(value, torch.cuda.Stream):
|
elif value.dtype == torch.float:
|
||||||
return ctype(value.cuda_stream)
|
return ctypes.c_void_p(value.data_ptr())
|
||||||
return ctype(value)
|
elif value.dtype == torch.bfloat16:
|
||||||
|
return ctypes.c_void_p(value.data_ptr())
|
||||||
|
elif value.dtype == torch.float16:
|
||||||
|
return ctypes.c_void_p(value.data_ptr())
|
||||||
|
elif value.dtype == torch.float8_e4m3fn:
|
||||||
|
return ctypes.c_void_p(value.data_ptr())
|
||||||
|
else:
|
||||||
|
return ctypes.c_void_p(value.data_ptr())
|
||||||
|
|
||||||
|
if hasattr(value, 'cuda_stream'):
|
||||||
|
return ctypes.c_void_p(value.cuda_stream)
|
||||||
|
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return ctypes.c_bool(value)
|
||||||
|
elif isinstance(value, int):
|
||||||
|
return ctypes.c_int(value)
|
||||||
|
elif isinstance(value, float):
|
||||||
|
return ctypes.c_float(value)
|
||||||
|
|
||||||
|
return ctype_map[type(value)](value)
|
||||||
|
|
||||||
|
|
||||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user