mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 15:44:22 +00:00
feat: fix win compat
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
parent
a3210ac850
commit
40c09fb883
@ -26,22 +26,22 @@ def hash_to_hex(s: str) -> str:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
# Update include directories
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
|
||||
assert os.path.exists(
|
||||
include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
with open(os.path.join(include_dir, filename), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
|
||||
# Update `interleave_ffma.py`
|
||||
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
|
||||
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
@ -51,7 +51,7 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
||||
paths.append(f'{CUDA_HOME}/bin/nvcc')
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
@ -73,17 +73,17 @@ def get_default_user_dir():
|
||||
path = os.getenv('DG_CACHE_DIR')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.expanduser('~') + '/.deep_gemm'
|
||||
return os.path.join(os.path.expanduser('~'), '.deep_gemm')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
return f'{get_default_user_dir()}/tmp'
|
||||
return os.path.join(get_default_user_dir(), 'tmp')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
return f'{get_default_user_dir()}/cache'
|
||||
return os.path.join(get_default_user_dir(), 'cache')
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
@ -96,7 +96,7 @@ def put(path, data):
|
||||
is_binary = isinstance(data, bytes)
|
||||
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
@ -137,7 +137,7 @@ class Compiler(abc.ABC):
|
||||
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
@ -148,14 +148,14 @@ class Compiler(abc.ABC):
|
||||
|
||||
# Compile into a temporary CU file
|
||||
os.makedirs(path, exist_ok=True)
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
cubin_path = os.path.join(path, 'kernel.cubin')
|
||||
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
|
||||
|
||||
start_time = time.time()
|
||||
kernel_name = cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None) or True:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
@ -194,8 +194,8 @@ class NvccCompiler(Compiler):
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
# Write the code
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
put(src_path, code)
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
|
@ -10,7 +10,9 @@ from .utils import run_gemm
|
||||
|
||||
|
||||
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', file_path]
|
||||
if CUDA_HOME is None:
|
||||
raise Exception("CUDA_HOME is not set")
|
||||
command = [os.path.join(CUDA_HOME, 'bin', 'cuobjdump'), '-symbols', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
|
Loading…
Reference in New Issue
Block a user