Unify kwargs usages

This commit is contained in:
Chenggang Zhao
2025-05-15 16:53:52 +08:00
parent 350989eef3
commit 3b412f458a
6 changed files with 14 additions and 15 deletions

View File

@@ -101,9 +101,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
}
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(**kwargs)
code = FP8GemmRuntime.generate(kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(**kwargs)
runtime(kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -200,6 +200,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
}
# Generate, build and run the kernel
code = FP8GemmRuntime.generate(**kwargs)
code = FP8GemmRuntime.generate(kwargs)
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(**kwargs)
runtime(kwargs)