diff --git a/README.md b/README.md index a9d7d1f..55ca9dd 100644 --- a/README.md +++ b/README.md @@ -105,16 +105,23 @@ The library provides some utility functions besides the above kernels: The library also provides some environment variables, which may be useful: -- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default -- `DG_DISABLE_CACHE`: 0 or 1, disable the use of cache directory, 0 by default -- `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default -- `DG_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler -- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization -- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output -- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details -- `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print NVCC compilation command -- `DG_JIT_DEBUG`: 0 or 1, print more debugging information -- `DG_JIT_USE_NVRTC`: 0 or 1, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, 0 by default +- General + - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default +- JIT cache related +- `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default +- `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default +- NVCC/NVRTC selections + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default + - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default +- Compiler options + - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default + - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default +- Post optimization + - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default +- Testing + - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. @@ -141,9 +148,9 @@ The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide - Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction - [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups -- Larger block sizes -- Less bank conflicts via 3D TMA 🐳 -- Overlapping as much as possible, e.g. overlapping TMA store and non-TMA RHS scaling factor load 🐳 +- Less bank conflicts via 3D TMA or swizzling +- Larger block sizes (up to 256x128 🐳) +- Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳 #### A unified and optimized block scheduler diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index cf07889..80910b4 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -49,8 +49,8 @@ def get_deep_gemm_version() -> str: @functools.lru_cache(maxsize=None) def get_nvcc_compiler() -> Tuple[str, str]: paths = [] - if os.getenv('DG_NVCC_COMPILER'): - paths.append(os.getenv('DG_NVCC_COMPILER')) + if os.getenv('DG_JIT_NVCC_COMPILER'): + paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc' paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin)) @@ -73,8 +73,8 @@ def get_nvcc_compiler() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def get_default_user_dir(): - if 'DG_CACHE_DIR' in os.environ: - path = os.getenv('DG_CACHE_DIR') + if 'DG_JIT_CACHE_DIR' in os.environ: + path = os.getenv('DG_JIT_CACHE_DIR') os.makedirs(path, exist_ok=True) return path return os.path.join(os.path.expanduser('~'), '.deep_gemm') @@ -119,10 +119,10 @@ class Compiler: @staticmethod def flags() -> List[str]: - cpp_standard = int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20)) + cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20)) return [f'-std=c++{cpp_standard}', '--ptxas-options=--register-usage-level=10' + - (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), + (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''), # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases '--diag-suppress=39,161,174,177,940'] @@ -136,7 +136,7 @@ class Compiler: flags = cls.flags() # Build signature - enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) + enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0)) signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' name = f'kernel.{name}.{hash_to_hex(signature)}' path = os.path.join(get_cache_dir(), name) diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py index 12baa0d..7899a22 100644 --- a/deep_gemm/jit/interleave_ffma.py +++ b/deep_gemm/jit/interleave_ffma.py @@ -37,7 +37,7 @@ def extract_ffma(sass): collected.append((f'{arch_name}::{func_name}', current)) current = [] - if int(os.getenv('DG_PRINT_REG_REUSE', 0)): + if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): print(f'Found {len(collected)} FFMA segments') return collected @@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines): dst_reg_set.add(dst_reg) new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) last_reused, last_dst_reg = reused, dst_reg - if int(os.getenv('DG_PRINT_REG_REUSE', 0)): + if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') # Find the offset @@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines): def process(path): - if int(os.getenv('DG_PRINT_REG_REUSE', 0)): + if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): print(f'Processing {path}') output = run_cuobjdump(path) segments = extract_ffma(output) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 50d8d42..79ba6b8 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -95,7 +95,7 @@ class RuntimeCache: return self.cache[path] # Already compiled - if not int(os.getenv('DG_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): + if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): runtime = runtime_cls(path) self.cache[path] = runtime return runtime diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py index d5cdd01..f99ecd4 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/utils.py @@ -80,25 +80,10 @@ class suppress_stdout_stderr: def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True): # Conflict with Nsight Systems - using_nsys = os.environ.get('DG_NSYS_PROFILING', False) + using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle - # this avoid thermal throttling while keeping DVFS at max clocks (slight gain vs sleep / more consistent on GH200) - sleep_between_tests = 0.0 flush_l2_size = int(8e9 // 4) - if os.environ.get('DG_BENCH_DISABLE_L2_FLUSH', False): - flush_l2 = False - if os.environ.get('DG_BENCH_POWER_LIMITED', False): - # if we want to be thermally limited, we need to run many iterations non-stop for a fairly long time - # and spend as little time as possible doing memset and other setup work (80MiB should be enough to flush L2) - num_tests = 2000 - flush_l2_size = int(80e6 // 4) - sleep_val = os.environ.get('DG_BENCH_SLEEP_BETWEEN_TESTS', False) - if sleep_val: - try: - sleep_between_tests = float(sleep_val) - except ValueError: - pass # Keep default # For some auto-tuning kernels with prints fn() @@ -117,8 +102,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: lhs @ rhs dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): - if sleep_between_tests > 0.0: - time.sleep(sleep_between_tests) if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() fn() diff --git a/tests/test_jit.py b/tests/test_jit.py index e6bad9f..1ba0d16 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -7,7 +7,7 @@ from deep_gemm import jit # Essential debugging staffs os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1') -os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1') +os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') class VectorAddRuntime(jit.Runtime):