mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Weight gradient kernels for dense and MoE models (#95)
* Init weight gradient kernels. * Support unaligned n,k and gmem stride * Update docs * Several cleanups * Remove restrictions on N * Add stride(0) assertions --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -78,7 +78,8 @@ class suppress_stdout_stderr:
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
|
||||
with_multiple_kernels: bool = False):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
|
||||
|
||||
@@ -119,8 +120,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
@@ -130,14 +132,19 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
units = {'ms': 1e3, 'us': 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user