Init weight gradient kernels.

This commit is contained in:
Zhean Xu
2025-05-06 17:16:27 +08:00
parent d374456787
commit d5470d3b4e
9 changed files with 841 additions and 35 deletions

View File

@@ -3,6 +3,10 @@ from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
)
from .utils import (
ceil_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,