fix cuda_graph rng check error

This commit is contained in:
sleepcoo 2025-03-11 12:40:42 +08:00
parent 6c5da03ba9
commit 723a00338e

View File

@ -1,11 +1,10 @@
import copy
import ctypes
import os
from typing import Any, Dict, Iterable, Tuple
import torch
from typing import Any, Iterable, Dict, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
@ -37,12 +36,31 @@ genc_map = {
def map_ctype(value: Any) -> Any:
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
if isinstance(value, torch.Tensor):
return ctype(value.data_ptr())
if isinstance(value, torch.cuda.Stream):
return ctype(value.cuda_stream)
return ctype(value)
if hasattr(value, 'data_ptr'):
if value.dtype == torch.int:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.bfloat16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
if hasattr(value, 'cuda_stream'):
return ctypes.c_void_p(value.cuda_stream)
if isinstance(value, bool):
return ctypes.c_bool(value)
elif isinstance(value, int):
return ctypes.c_int(value)
elif isinstance(value, float):
return ctypes.c_float(value)
return ctype_map[type(value)](value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str: