DeepGEMM/deep_gemm/__init__.py
Zhean Xu 04278f6dee
Weight gradient kernels for dense and MoE models (#95)
* Init weight gradient kernels.

* Support unaligned n,k and gmem stride

* Update docs

* Several cleanups

* Remove restrictions on N

* Add stride(0) assertions

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2025-05-14 14:47:58 +08:00

import torch
from . import jit
from .jit_kernels import (
    gemm_fp8_fp8_bf16_nt,
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
    m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    wgrad_gemm_fp8_fp8_fp32_nt,
    k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
    ceil_div,
    set_num_sms, get_num_sms,
    get_col_major_tma_aligned_tensor,
    get_m_alignment_for_contiguous_layout
)
from .utils import bench, bench_kineto, calc_diff