Weight gradient kernels for dense and MoE models (#95)

* Init weight gradient kernels.

* Support unaligned n,k and gmem stride

* Update docs

* Several cleanups

* Remove restrictions on N

* Add stride(0) assertions

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
Zhean Xu
2025-05-14 14:47:58 +08:00
committed by GitHub
parent d75b218b7b
commit 04278f6dee
12 changed files with 911 additions and 72 deletions

View File

@@ -78,7 +78,8 @@ class suppress_stdout_stderr:
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
with_multiple_kernels: bool = False):
# Conflict with Nsight Systems
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
@@ -119,8 +120,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
if not with_multiple_kernels:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
@@ -130,14 +132,19 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
total_time = 0
total_num = 0
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
num_str = line.split()[-1]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
total_num += int(num_str)
break
break
kernel_times.append(total_time / total_num)
return tuple(kernel_times) if is_tupled else kernel_times[0]