import os import sys import numpy as np import torch import torch.distributed as dist from typing import Optional def init_dist(local_rank: int, num_local_ranks: int): # NOTES: you may rewrite this function with your own cluster settings ip = os.getenv('MASTER_ADDR', '127.0.0.1') port = int(os.getenv('MASTER_PORT', '8361')) num_nodes = int(os.getenv('WORLD_SIZE', 1)) node_rank = int(os.getenv('RANK', 0)) assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 dist.init_process_group( backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=num_nodes * num_local_ranks, rank=node_rank * num_local_ranks + local_rank ) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank) return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) def calc_diff(x: torch.Tensor, y: torch.Tensor): x, y = x.double() + 1, y.double() + 1 denominator = (x * x + y * y).sum() sim = 2 * (x * y).sum() / denominator return (1 - sim).item() def per_token_cast_to_fp8(x: torch.Tensor): assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) x_scales = x_scales.view(x_fp8.size(0), -1, 1) return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) def inplace_unique(x: torch.Tensor, num_slots: int): assert x.dim() == 2 mask = x < 0 x_padded = x.masked_fill(mask, num_slots) bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device) bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded)) bin_count = bin_count[:, :num_slots] sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True) sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1) sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values x[:, :].fill_(-1) valid_len = min(num_slots, x.size(1)) x[:, :valid_len] = sorted_bin_idx[:, :valid_len] def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int): num_tokens, num_experts = scores.shape scores = scores.view(num_tokens, num_groups, -1) mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device) mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores) return (scores * mask).view(num_tokens, num_experts) def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') # Warmup for _ in range(num_warmups): fn() # Flush L2 cache.zero_() # Testing start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] for i in range(num_tests): # Record start_events[i].record() fn() end_events[i].record() if post_fn is not None: post_fn() torch.cuda.synchronize() times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:] return np.average(times), np.min(times), np.max(times) class empty_suppress: def __enter__(self): return self def __exit__(self, *_): pass class suppress_stdout_stderr: def __enter__(self): self.outnull_file = open(os.devnull, 'w') self.errnull_file = open(os.devnull, 'w') self.old_stdout_fileno_undup = sys.stdout.fileno() self.old_stderr_fileno_undup = sys.stderr.fileno() self.old_stdout_fileno = os.dup(sys.stdout.fileno()) self.old_stderr_fileno = os.dup(sys.stderr.fileno()) self.old_stdout = sys.stdout self.old_stderr = sys.stderr os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) sys.stdout = self.outnull_file sys.stderr = self.errnull_file return self def __exit__(self, *_): sys.stdout = self.old_stdout sys.stderr = self.old_stderr os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) os.close(self.old_stdout_fileno) os.close(self.old_stderr_fileno) self.outnull_file.close() self.errnull_file.close() def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: Optional[str] = None, barrier_comm_profiling: bool = False): # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof: for i in range(2): # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead if barrier_comm_profiling: lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') lhs @ rhs dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): fn() prof.step() # Parse the profiling table assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tupled = isinstance(kernel_names, tuple) prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) for name in kernel_names: assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' # Save chrome traces if trace_path is not None: prof.export_chrome_trace(trace_path) # Return average kernel times units = {'ms': 1e3, 'us': 1e6} kernel_times = [] for name in kernel_names: for line in prof_lines: if name in line: time_str = line.split()[-2] for unit, scale in units.items(): if unit in time_str: kernel_times.append(float(time_str.replace(unit, '')) / scale) break break return tuple(kernel_times) if is_tupled else kernel_times[0] def hash_tensor(t: torch.Tensor): return t.view(torch.int64).sum().item()