Fix JIT tests

This commit is contained in:
Chenggang Zhao
2025-05-16 14:39:58 +08:00
parent 78d8362e7a
commit 391755ada0
6 changed files with 14 additions and 11 deletions

View File

@@ -239,4 +239,4 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(kwargs)
runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(kwargs)
runtime(**kwargs)

View File

@@ -103,7 +103,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(kwargs)
runtime(**kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -202,4 +202,4 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(kwargs)
runtime(**kwargs)

View File

@@ -110,7 +110,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Generate, build and run the kernel
code = FP8WGradGemmRuntime.generate(kwargs)
runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
runtime(kwargs)
runtime(**kwargs)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],