refactor: compile to .cubin and add NVRTC option

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 10:17:52 +00:00
parent 27cd276e19
commit c14cad0c06
5 changed files with 237 additions and 284 deletions

View File

@@ -3,8 +3,6 @@ import time
from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
import torch
from .utils import run_gemm
@@ -58,8 +56,9 @@ class Runtime:
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm(
self.kernel,