mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Fix JIT tests
This commit is contained in:
parent
78d8362e7a
commit
391755ada0
@ -166,7 +166,7 @@ class Compiler:
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
|
||||
runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True)
|
||||
assert runtime is not None
|
||||
return runtime
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ class Runtime:
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
raise NotImplemented
|
||||
|
||||
def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
def __call__(self, **kwargs) -> cbd.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
start_time = time.time_ns()
|
||||
@ -81,17 +81,19 @@ class RuntimeCache:
|
||||
self.cache[path] = runtime
|
||||
|
||||
def get(self, path: str, runtime_cls: Type[Runtime],
|
||||
name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]:
|
||||
name: str = '', kwargs: Dict[str, Any] = None,
|
||||
force_enable_cache: bool = False) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0))
|
||||
if use_cache and os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
# Print heuristic for the first time
|
||||
if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
|
||||
simplified_kwargs = dict()
|
||||
for key, value in kwargs.items():
|
||||
for key, value in kwargs.items() if kwargs is not None else dict().items():
|
||||
value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
|
||||
value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
|
||||
simplified_kwargs[key] = value
|
||||
|
||||
@ -239,4 +239,4 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
# Generate, build and run the kernel
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||
runtime(kwargs)
|
||||
runtime(**kwargs)
|
||||
|
||||
@ -103,7 +103,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
# Generate, build and run the kernel
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||
runtime(kwargs)
|
||||
runtime(**kwargs)
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@ -202,4 +202,4 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
# Generate, build and run the kernel
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||
runtime(kwargs)
|
||||
runtime(**kwargs)
|
||||
|
||||
@ -110,7 +110,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
# Generate, build and run the kernel
|
||||
code = FP8WGradGemmRuntime.generate(kwargs)
|
||||
runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
|
||||
runtime(kwargs)
|
||||
runtime(**kwargs)
|
||||
|
||||
|
||||
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
@ -74,7 +74,8 @@ static void __instantiate_kernel() {{
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Generated code:')
|
||||
code = VectorAddRuntime.generate(T='float')
|
||||
kwargs = {'T': 'float'}
|
||||
code = VectorAddRuntime.generate(kwargs)
|
||||
print(code)
|
||||
print()
|
||||
|
||||
@ -85,7 +86,7 @@ if __name__ == '__main__':
|
||||
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = compiler_cls.build('test_func', code, VectorAddRuntime)
|
||||
func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs)
|
||||
|
||||
# Run and check
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user