import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional


def init_dist(local_rank: int, num_local_ranks: int):
    # NOTES: you may rewrite this function with your own cluster settings
    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
    port = int(os.getenv('MASTER_PORT', '8361'))
    num_nodes = int(os.getenv('WORLD_SIZE', 1))
    node_rank = int(os.getenv('RANK', 0))
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend='nccl',
        init_method=f'tcp://{ip}:{port}',
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device('cuda')
    torch.cuda.set_device(local_rank)

    return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))

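# Example (hypothetical, not part of the original file): `init_dist` is
# typically driven with one process per local GPU, e.g. via
# `torch.multiprocessing.spawn`. The worker below and the 2-GPU single-node
# setup are illustrative assumptions.
#
#   import torch.multiprocessing as mp
#
#   def worker(local_rank: int, num_local_ranks: int):
#       rank, world_size, group = init_dist(local_rank, num_local_ranks)
#       print(f'rank {rank} / {world_size} ready')
#
#   if __name__ == '__main__':
#       num_local_ranks = 2  # single node with 2 GPUs (valid since 2 < 8)
#       mp.spawn(worker, args=(num_local_ranks,), nprocs=num_local_ranks)
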
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    # A symmetric relative-difference metric: 1 minus a cosine-like similarity
    # computed in double precision, with a +1 shift to avoid all-zero inputs
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()

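# Example (illustrative): `calc_diff` is near 0 for matching tensors and grows
# as they diverge; a quick sanity check, assuming a CUDA default device as set
# by `init_dist`:
#
#   x = torch.randn(4, 128)
#   assert calc_diff(x, x) < 1e-12
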
def per_token_cast_to_fp8(x: torch.Tensor):
    # Cast to FP8 (e4m3) per 128-element group, scaling each group so its max
    # absolute value maps to 448 (the largest finite e4m3 value)
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

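# Example (illustrative): the second return value holds one float32 scale per
# 128-element group, so a (16, 256) input yields scales of shape (16, 2).
#
#   x = torch.randn(16, 256, dtype=torch.bfloat16, device='cuda')
#   x_fp8, scales = per_token_cast_to_fp8(x)
#   assert x_fp8.shape == (16, 256) and scales.shape == (16, 2)
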
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)

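# Example (illustrative round-trip): quantize, dequantize, and compare with
# `calc_diff`; the tolerance below is a loose, assumed bound for e4m3 precision.
#
#   x = torch.randn(16, 256, dtype=torch.bfloat16, device='cuda')
#   x_back = per_token_cast_back(*per_token_cast_to_fp8(x))
#   assert calc_diff(x, x_back) < 1e-2
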
def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    # Route negative (invalid) entries into an extra overflow bin
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    # Keep each slot index at most once, then pack the survivors to the front
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]

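# Example (illustrative): per-row deduplication in place; unique non-negative
# entries are kept in descending order and the rest padded with -1.
#
#   x = torch.tensor([[3, 1, 3, -1, 1]], device='cuda')
#   inplace_unique(x, num_slots=8)
#   # x is now [[3, 1, -1, -1, -1]]
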
def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)

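# Example (illustrative): zero out the scores of unselected expert groups. With
# 4 experts in 2 groups and group 1 chosen for the single token:
#
#   scores = torch.tensor([[0.1, 0.2, 0.3, 0.4]], device='cuda')
#   group_idx = torch.tensor([[1]], device='cuda')
#   create_grouped_scores(scores, group_idx, num_groups=2)
#   # -> [[0.0, 0.0, 0.3, 0.4]]
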
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    # Convert milliseconds to seconds and drop the first timed iteration
    times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
    return np.average(times), np.min(times), np.max(times)

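# Example (hypothetical usage): time a matmul; the shape and dtype below are
# illustrative. Returned times are in seconds.
#
#   a = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')
#   avg_t, min_t, max_t = bench(lambda: a @ a)
#   print(f'avg {avg_t * 1e6:.1f} us, min {min_t * 1e6:.1f} us, max {max_t * 1e6:.1f} us')
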
class empty_suppress:
    # A no-op context manager, used by `bench_kineto` when suppression is disabled
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass

class suppress_stdout_stderr:
    # Redirect stdout/stderr to /dev/null at the file-descriptor level, so that
    # output from C++/CUDA code is silenced as well as Python-level prints
    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()

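# Example (illustrative): because the redirection happens on the underlying
# file descriptors, even native-library output inside the block is silenced.
#
#   with suppress_stdout_stderr():
#       print('hidden')           # suppressed
#       some_noisy_cuda_init()    # hypothetical native call, also suppressed
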
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
                 trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, (str, tuple))
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert sum([name in line for line in prof_lines]) == 1, f'Kernel {name} must appear exactly once in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, '')) / scale)
                        break
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]

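# Example (hypothetical usage): look a kernel up by a unique substring of its
# name in the Kineto table; the substring below is an illustrative assumption.
#
#   a = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')
#   t = bench_kineto(lambda: a @ a, kernel_names='gemm', suppress_kineto_output=True)
#   print(f'kernel time: {t * 1e6:.1f} us')
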
def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()

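# Example (illustrative): a cheap checksum for cross-rank comparisons. Note
# that `view(torch.int64)` requires a contiguous tensor whose byte size is a
# multiple of 8, e.g. a bfloat16 tensor with a last dimension divisible by 4.
#
#   t = torch.randn(8, 128, dtype=torch.bfloat16, device='cuda')
#   assert hash_tensor(t) == hash_tensor(t.clone())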