Merge pull request #100 from deepseek-ai/remove-tuner

Refactor some launch-related structures
This commit is contained in:
Chenggang Zhao
2025-05-15 17:05:42 +08:00
committed by GitHub
11 changed files with 254 additions and 442 deletions

View File

@@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful:
- Post optimization - Post optimization
- `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
- Heuristic selection - Heuristic selection
- `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- Testing - Testing
- `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default

View File

@@ -18,7 +18,7 @@ namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N, template <uint32_t SHAPE_M, uint32_t SHAPE_N,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages, uint32_t kLastStages, uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup, uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA> uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1) __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
@@ -127,7 +127,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
struct DivisibleK {}; struct DivisibleK {};
struct NotDivisibleK {}; struct NotDivisibleK {};
auto launch_k_iterations = [&](const auto& func) { auto launch_k_iterations = [&](const auto& func) {
if constexpr (kLastStages == 0) { if constexpr (kNumLastStages == 0) {
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter) for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}); func(k_iter, DivisibleK{});
} else { } else {
@@ -155,7 +155,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
while (scheduler.get_next_block(m_block_idx, n_block_idx)) { while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) { launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>; constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// Assign TMA multicast number into A and B // Assign TMA multicast number into A and B
@@ -244,7 +244,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
// Launch MMAs // Launch MMAs
launch_k_iterations([&](int k_iter, auto type) { launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>; constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll #pragma unroll

View File

@@ -5,7 +5,7 @@ import re
import subprocess import subprocess
import time import time
import uuid import uuid
from typing import List, Tuple, Type from typing import Any, Dict, List, Tuple, Type
import cuda.bindings import cuda.bindings
import cuda.bindings.nvrtc as nvrtc import cuda.bindings.nvrtc as nvrtc
@@ -128,7 +128,7 @@ class Compiler:
return [get_jit_include_dir()] return [get_jit_include_dir()]
@classmethod @classmethod
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
# Compiler flags # Compiler flags
flags = cls.flags() flags = cls.flags()
@@ -140,7 +140,7 @@ class Compiler:
# Check runtime cache or file system hit # Check runtime cache or file system hit
global runtime_cache global runtime_cache
cached_runtime = runtime_cache.get(path, runtime_cls) cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
if cached_runtime is not None: if cached_runtime is not None:
if int(os.getenv('DG_JIT_DEBUG', 0)): if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT runtime {name} during build') print(f'Using cached JIT runtime {name} during build')
@@ -166,8 +166,8 @@ class Compiler:
os.replace(tmp_cubin_path, cubin_path) os.replace(tmp_cubin_path, cubin_path)
# Put cache and return # Put cache and return
runtime = runtime_cls(path) runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
runtime_cache[path] = runtime assert runtime is not None
return runtime return runtime
@@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler):
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
return compiler_cls.build(name, code, runtime_cls=runtime_cls) return compiler_cls.build(name, code, runtime_cls, kwargs)

View File

@@ -1,18 +1,18 @@
import os import os
import subprocess import subprocess
import time import time
import torch
import cuda.bindings.driver as cbd import cuda.bindings.driver as cbd
from typing import List, Optional, Type from typing import Any, Dict, Optional, Type
from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CUDA_HOME
class Runtime: class Runtime:
def __init__(self, path: str, args: List[str] = None) -> None: def __init__(self, path: str) -> None:
self.path = path self.path = path
self.lib = None self.lib = None
self.kernel = None self.kernel = None
self.args = args
assert self.is_path_valid(self.path) assert self.is_path_valid(self.path)
@staticmethod @staticmethod
@@ -26,14 +26,14 @@ class Runtime:
return all(os.path.exists(os.path.join(path, file)) for file in files) return all(os.path.exists(os.path.join(path, file)) for file in files)
@staticmethod @staticmethod
def generate(**kwargs) -> str: def generate(kwargs: Dict[str, Any]) -> str:
raise NotImplemented raise NotImplemented
@staticmethod @staticmethod
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult: def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
raise NotImplemented raise NotImplemented
def __call__(self, **kwargs) -> cbd.CUresult: def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult:
# Load CUBIN # Load CUBIN
if self.kernel is None: if self.kernel is None:
start_time = time.time_ns() start_time = time.time_ns()
@@ -48,8 +48,10 @@ class Runtime:
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0 assert result.returncode == 0
illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail']
check_illegal = lambda line: any([name in line for name in illegal_names])
kernel_names = [line.split()[-1] for line in result.stdout.splitlines() kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line] if line.startswith('STT_FUNC') and not check_illegal(line)]
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
# Load kernel from the library # Load kernel from the library
@@ -62,7 +64,7 @@ class Runtime:
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
# noinspection PyArgumentList # noinspection PyArgumentList
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args]) return self.launch(self.kernel, kwargs)
def __del__(self) -> None: def __del__(self) -> None:
if self.lib is not None: if self.lib is not None:
@@ -78,13 +80,23 @@ class RuntimeCache:
def __setitem__(self, path: str, runtime: Runtime) -> None: def __setitem__(self, path: str, runtime: Runtime) -> None:
self.cache[path] = runtime self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]: def get(self, path: str, runtime_cls: Type[Runtime],
name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]:
# In Python runtime # In Python runtime
if path in self.cache: if path in self.cache:
return self.cache[path] return self.cache[path]
# Already compiled # Already compiled
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
# Print heuristic for the first time
if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
simplified_kwargs = dict()
for key, value in kwargs.items():
value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
simplified_kwargs[key] = value
print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')
runtime = runtime_cls(path) runtime = runtime_cls(path)
self.cache[path] = runtime self.cache[path] = runtime
return runtime return runtime

View File

@@ -3,11 +3,11 @@ import torch
from functools import lru_cache from functools import lru_cache
from typing import Tuple from typing import Tuple
from ..jit import build
from .runtime import ( from .runtime import (
FP8GemmRuntime, GemmType, FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc) make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -18,7 +18,6 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
def get_swizzle_mode(block_n: int) -> int: def get_swizzle_mode(block_n: int) -> int:
# TODO: remove some candidates if slow
elem_size = 2 elem_size = 2
for mode_bytes in (128, 64, 32): for mode_bytes in (128, 64, 32):
if (block_n * elem_size) % mode_bytes == 0: if (block_n * elem_size) % mode_bytes == 0:
@@ -180,22 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Type and shape checks # Type and shape checks
assert m == m_ and n == n_ and k == k_ assert m == m_ and n == n_ and k == k_
assert n > 0 and k > 0 assert n > 0 and k > 0
assert lhs_scales.shape == (m, (k + 127) // 128) assert lhs_scales.shape == (m, ceil_div(k, 128))
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16 assert out.dtype == torch.bfloat16
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
# LHS scales must be transposed for TMA loads, but not for RHS scales # LHS scales must be transposed for TMA loads, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous() assert rhs_scales.is_contiguous()
@@ -204,33 +196,34 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
return return
# K must be aligned to 128 # K must be aligned to 128
aligned_k = (k + 127) // 128 * 128 aligned_k = ceil_div(k, 128) * 128
# Auto-tuning with compilation # Auto-tuning with compilation
num_sms = get_num_sms() num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
m, n, k, 1, num_sms)
block_k = 128 block_k = 128
num_tma_threads = 128 num_tma_threads = 128
num_math_threads_per_group = 128 num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc( tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride) tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
tensor_map_b = make_2d_tma_b_desc( tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride) tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
kwargs = { kwargs = {
# Templated arguments
'GEMM_TYPE': GemmType.Normal, 'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads, 'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'M': m, 'N': n, 'K': aligned_k,
'NUM_GROUPS': 1, 'NUM_GROUPS': 1,
'BLOCK_K': block_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'GMEM_D': out, 'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
# Runtime arguments
'SCALES_B': rhs_scales, 'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
'NUM_SMS': num_sms, 'NUM_SMS': num_sms,
@@ -240,21 +233,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d, 'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream, 'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
} }
runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': aligned_k,
'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel # Generate, build and run the kernel
runtime(**best_keys, **kwargs) code = FP8GemmRuntime.generate(kwargs)
runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
runtime(kwargs)

View File

@@ -1,13 +1,13 @@
import torch import torch
from typing import Tuple from typing import Tuple
from ..jit import build
from .gemm import get_best_configs from .gemm import get_best_configs
from .runtime import ( from .runtime import (
FP8GemmRuntime, GemmType, FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc) make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .tuner import jit_tuner from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
# Type and shape checks # Type and shape checks
assert m == m_ == m__ and k == k_ and n == n_ assert m == m_ == m__ and k == k_ and n == n_
assert lhs_scales.shape == (m, (k + 127) // 128) assert lhs_scales.shape == (m, ceil_div(k, 128))
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16 assert out.dtype == torch.bfloat16
@@ -69,21 +69,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
num_tma_threads = 128 num_tma_threads = 128
num_math_threads_per_group = 128 num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc( tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups)
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups) tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc( tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups) tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = { kwargs = {
# Templated arguments
'NUM_TMA_THREADS': num_tma_threads, 'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'M': m, 'N': n, 'K': k,
'BLOCK_K': block_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'GMEM_D': out, 'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedContiguous,
# Runtime arguments
'SCALES_B': rhs_scales, 'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': m_indices, 'GROUPED_LAYOUT': m_indices,
'NUM_SMS': num_sms, 'NUM_SMS': num_sms,
@@ -93,25 +97,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d, 'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream, 'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
} }
runtime, best_keys = jit_tuner.compile_and_tune( # Generate, build and run the kernel
name='m_grouped_gemm_fp8_fp8_bf16_nt', code = FP8GemmRuntime.generate(kwargs)
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
'SWIZZLE_D_MODE': smem_config[1], runtime(kwargs)
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -150,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
assert num_groups == num_groups_ == num_groups__ == num_groups___ assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_ assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.bfloat16 assert out.dtype == torch.bfloat16
@@ -176,21 +168,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_tma_threads = 128 num_tma_threads = 128
num_math_threads_per_group = 128 num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc( tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups)
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups) tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc( tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups) tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = { kwargs = {
# Templated arguments
'NUM_TMA_THREADS': num_tma_threads, 'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'M': m, 'N': n, 'K': k,
'BLOCK_K': block_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'GMEM_D': out, 'SWIZZLE_D_MODE': smem_config[1],
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedMasked,
# Runtime arguments
'SCALES_B': rhs_scales, 'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': masked_m, 'GROUPED_LAYOUT': masked_m,
'NUM_SMS': num_sms, 'NUM_SMS': num_sms,
@@ -200,22 +196,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'TENSOR_MAP_SCALES_A': tensor_map_scales_a, 'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d, 'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream, 'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
} }
runtime, best_keys = jit_tuner.compile_and_tune( # Generate, build and run the kernel
name='m_grouped_gemm_fp8_fp8_bf16_nt', code = FP8GemmRuntime.generate(kwargs)
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
'SWIZZLE_D_MODE': smem_config[1], runtime(kwargs)
'BLOCK_N_PADDING': smem_config[2],
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)

View File

@@ -5,14 +5,10 @@ import torch
import cuda.bindings.driver as cbd import cuda.bindings.driver as cbd
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple
from .utils import get_tma_aligned_size
from ..jit.runtime import Runtime from ..jit.runtime import Runtime
class Layout(enum.Enum):
RowMajor = 0
ColMajor = 1
class GemmType(enum.Enum): class GemmType(enum.Enum):
Normal = 0 Normal = 0
GroupedContiguous = 1 GroupedContiguous = 1
@@ -61,19 +57,18 @@ def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor, def make_2d_tma_copy_desc(t: torch.Tensor,
gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
stride_in_bytes: cbd.cuuint64_t, smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
tensor_dtype = tmap_type_map[global_address.dtype] tensor_dtype = tmap_type_map[t.dtype]
res, tensor_map = cbd.cuTensorMapEncodeTiled( res, tensor_map = cbd.cuTensorMapEncodeTiled(
tensor_dtype, tensor_dtype,
2, 2,
global_address.data_ptr(), t.data_ptr(),
gmem_dim, gmem_dims,
(stride_in_bytes, ), (gmem_outer_stride,),
smem_dim, smem_dims,
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)), (cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type, swizzle_type,
@@ -86,93 +81,64 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
return tensor_map return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, def make_2d_tma_desc(t: torch.Tensor,
gmem_rows: int, gmem_cols: int, gmem_stride: int, gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
smem_rows: int, smem_cols: int, smem_inner_dim: int, smem_outer_dim: int,
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
if layout == Layout.RowMajor: gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows)) smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows)) return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type)
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
shape_m: int, shape_k: int, shape_m: int, shape_k: int, m_stride: int,
block_m: int, block_k: int, block_m: int, block_k: int,
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap: num_groups: int) -> cbd.CUtensorMap:
a_stride = shape_k if a_stride == 0 else a_stride return make_2d_tma_desc(t,
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride, block_k, block_m)
block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
shape_k: int, shape_n: int, shape_n: int, shape_k: int, n_stride: int,
block_k: int, block_n: int, block_n: int, block_k: int,
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap: num_groups: int) -> cbd.CUtensorMap:
b_stride = shape_k if b_stride == 0 else b_stride return make_2d_tma_desc(t,
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride,
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
block_k, block_n) block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor, def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
shape_m: int, shape_n: int, shape_m: int, shape_n: int, m_stride: int,
block_m: int, block_n: int, block_m: int, block_n: int,
num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap: num_groups: int,
swizzle_mode: int) -> cbd.CUtensorMap:
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
d_stride = shape_n if d_stride == 0 else d_stride return make_2d_tma_desc(t,
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride, block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m,
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
swizzle_type_map[swizzle_mode]) swizzle_type_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap: def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
shape_mn: int, shape_k: int,
block_mn: int, block_k: int,
num_groups: int) -> cbd.CUtensorMap:
# Make TMA aligned to 16 bytes # Make TMA aligned to 16 bytes
tma_alignment = 16 / global_address.element_size() shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment return make_2d_tma_desc(t,
shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn,
return make_2d_tma_desc(global_address, Layout.ColMajor, block_mn, 1,
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
# Make TMA aligned to 16 bytes
tma_alignment = 16 / global_address.element_size()
shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
return make_2d_tma_desc(global_address, Layout.ColMajor,
shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
class FP8GemmRuntime(Runtime): class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__(path, [ super().__init__(path)
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
@staticmethod @staticmethod
def generate(**kwargs) -> str: def generate(kwargs: Dict[str, Any]) -> str:
code = f''' code = f'''
#ifdef __CUDACC_RTC__ #ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh> #include <deep_gemm/nvrtc_std.cuh>
@@ -213,21 +179,16 @@ static void __instantiate_kernel() {{
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
@staticmethod @staticmethod
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int, def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
stream: cbd.CUstream) -> cbd.CUresult:
num_tma_threads = 128 num_tma_threads = 128
num_math_threads_per_group = 128 num_math_threads_per_group = 128
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
if res != cbd.CUresult.CUDA_SUCCESS: kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
raise Exception(f'Failed to set max dynamic shared memory size: {res}') assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
attr_val = cbd.CUlaunchAttributeValue() attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
attr_val.clusterDim.y = 1 attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1 attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute() attr = cbd.CUlaunchAttribute()
@@ -237,23 +198,23 @@ static void __instantiate_kernel() {{
config = cbd.CUlaunchConfig() config = cbd.CUlaunchConfig()
config.numAttrs = 1 config.numAttrs = 1
config.attrs = [attr] config.attrs = [attr]
config.gridDimX = num_sms config.gridDimX = kwargs['NUM_SMS']
config.gridDimY = 1 config.gridDimY = 1
config.gridDimZ = 1 config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
config.blockDimY = 1 config.blockDimY = 1
config.blockDimZ = 1 config.blockDimZ = 1
config.sharedMemBytes = smem_size config.sharedMemBytes = kwargs['SMEM_SIZE']
config.hStream = stream config.hStream = kwargs['STREAM']
arg_values = ( arg_values = (
scales_b.data_ptr(), kwargs['SCALES_B'].data_ptr(),
grouped_layout.data_ptr(), kwargs['GROUPED_LAYOUT'].data_ptr(),
shape_m, kwargs['M'],
tensor_map_a, kwargs['TENSOR_MAP_A'],
tensor_map_b, kwargs['TENSOR_MAP_B'],
tensor_map_scales_a, kwargs['TENSOR_MAP_SCALES_A'],
tensor_map_d, kwargs['TENSOR_MAP_D'],
) )
arg_types = ( arg_types = (
ctypes.c_void_p, ctypes.c_void_p,
@@ -269,23 +230,10 @@ static void __instantiate_kernel() {{
class FP8WGradGemmRuntime(Runtime): class FP8WGradGemmRuntime(Runtime):
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__(path, [ super().__init__(path)
'NUM_TMA_MULTICAST',
'K',
'BLOCK_M',
'GMEM_D',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_SCALES_B',
'TENSOR_MAP_D',
'STREAM',
])
@staticmethod @staticmethod
def generate(**kwargs) -> str: def generate(kwargs: Dict[str, Any]) -> str:
code = f''' code = f'''
#ifdef __CUDACC_RTC__ #ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh> #include <deep_gemm/nvrtc_std.cuh>
@@ -309,7 +257,7 @@ static void __instantiate_kernel() {{
{kwargs['BLOCK_N']}, {kwargs['BLOCK_N']},
{kwargs['BLOCK_K']}, {kwargs['BLOCK_K']},
{kwargs['NUM_STAGES']}, {kwargs['NUM_STAGES']},
{kwargs['LAST_STAGES']}, {kwargs['NUM_LAST_STAGES']},
{kwargs['NUM_TMA_THREADS']}, {kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']}, {kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']}, {kwargs['NUM_TMA_MULTICAST']},
@@ -323,21 +271,16 @@ static void __instantiate_kernel() {{
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
@staticmethod @staticmethod
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int, def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
tensor_map_d: cbd.CUtensorMap,
stream: cbd.CUstream) -> cbd.CUresult:
num_tma_threads = 128 num_tma_threads = 128
num_math_threads_per_group = 128 num_math_threads_per_group = 128
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0] result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
if res != cbd.CUresult.CUDA_SUCCESS: kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
raise Exception(f'Failed to set max dynamic shared memory size: {res}') assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
attr_val = cbd.CUlaunchAttributeValue() attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
attr_val.clusterDim.y = 1 attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1 attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute() attr = cbd.CUlaunchAttribute()
@@ -347,22 +290,22 @@ static void __instantiate_kernel() {{
config = cbd.CUlaunchConfig() config = cbd.CUlaunchConfig()
config.numAttrs = 1 config.numAttrs = 1
config.attrs = [attr] config.attrs = [attr]
config.gridDimX = num_sms config.gridDimX = kwargs['NUM_SMS']
config.gridDimY = 1 config.gridDimY = 1
config.gridDimZ = 1 config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m) config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
config.blockDimY = 1 config.blockDimY = 1
config.blockDimZ = 1 config.blockDimZ = 1
config.sharedMemBytes = smem_size config.sharedMemBytes = kwargs['SMEM_SIZE']
config.hStream = stream config.hStream = kwargs['STREAM']
arg_values = ( arg_values = (
shape_k, kwargs['K'],
tensor_map_a, kwargs['TENSOR_MAP_A'],
tensor_map_b, kwargs['TENSOR_MAP_B'],
tensor_map_scales_a, kwargs['TENSOR_MAP_SCALES_A'],
tensor_map_scales_b, kwargs['TENSOR_MAP_SCALES_B'],
tensor_map_d, kwargs['TENSOR_MAP_D'],
) )
arg_types = ( arg_types = (
ctypes.c_uint32, ctypes.c_uint32,

View File

@@ -1,82 +0,0 @@
import copy
import os
import torch
import cuda.bindings.driver as cbd
from typing import Any, Callable, Dict, Type, Tuple
from ..jit import build, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
# NOTES: we always assume the space, template and GPU devices will not change
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
signature = (name, f'{keys}')
if signature in self.tuned:
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT kernel {name} with keys {keys}')
return self.tuned[signature]
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
assert kwargs is not None
space = (dict(), ) if len(space) == 0 else space
kernels = []
for tuned_keys in space:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = runtime_cls.generate(**kwargs, **full_keys)
kernels.append((build(name, code, runtime_cls), full_keys))
# TODO: fix tuning with space > 1
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cbd.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g., insufficient shared memory capacity
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.empty(int(256e6 // 4), dtype=torch.int,
device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
else:
elapsed_time = 0
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)
return best_runtime, best_keys
jit_tuner = JITTuner()

View File

@@ -1,18 +1,18 @@
import torch import torch
from typing import List, Tuple from typing import List, Tuple
from ..jit import build
from .runtime import ( from .runtime import (
FP8WGradGemmRuntime, GemmType, FP8WGradGemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc) make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .gemm import get_best_configs from .gemm import get_best_configs
from .tuner import jit_tuner from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor],
out: Tuple[torch.Tensor, torch.Tensor]): out: torch.Tensor):
""" """
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
Results will be accumulated into the output tensor. Results will be accumulated into the output tensor.
@@ -21,8 +21,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4. The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
RHS and RHS scaling factors are required to be transposed. RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement, The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format.
this function will do a transposing with a set of slow PyTorch operations. If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations.
Arguments: Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
@@ -40,71 +40,62 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
# Type and shape checks # Type and shape checks
assert m == m_ and n == n_ and k == k_ assert m == m_ and n == n_ and k == k_
assert n > 0 and m > 0 assert n > 0 and m > 0
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m) assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n) assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.float assert out.dtype == torch.float
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
lhs_stride = lhs.stride(0)
rhs_stride = rhs.stride(0)
out_stride = out.stride(0)
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
# LHS and RHS scales must be transposed for TMA load # LHS and RHS scales must be transposed for TMA load
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
if lhs_scales.shape == ((k + 127) // 128, m): def get_valid_scales(scales: torch.Tensor, mn: int):
lhs_scales = lhs_scales.permute(1, 0) if scales.shape == (ceil_div(k, 128), mn):
assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m # For k-grouped GEMMs
else: scales = scales.permute(1, 0)
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
assert lhs_scales.stride(0) == 1 else:
scales = get_col_major_tma_aligned_tensor(scales)
if rhs_scales.shape == ((k + 127) // 128, n): return scales
rhs_scales = rhs_scales.permute(1, 0)
assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n lhs_scales = get_valid_scales(lhs_scales, m)
else: rhs_scales = get_valid_scales(rhs_scales, n)
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
assert rhs_scales.stride(0) == 1
# Do nothing if `k` is zero # Do nothing if `k` is zero
if k == 0: if k == 0:
return return
# K must be aligned to 128 # K must be aligned to 128
aligned_k = (k + 127) // 128 * 128 aligned_k = ceil_div(k, 128) * 128
# Auto-tuning with compilation # Auto-tuning with compilation
num_sms = get_num_sms() num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
last_stages = (k + 127) // 128 % num_stages num_last_stages = ceil_div(k, 128) % num_stages
block_k = 128 block_k = 128
num_tma_threads = 128 num_tma_threads = 128
num_math_threads_per_group = 128 num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc( tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride) tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
tensor_map_b = make_2d_tma_b_desc( tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride) tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
tensor_map_d = make_2d_tma_d_desc( tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1)
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
tensor_map_scales_b = make_2d_tma_scales_b_desc(
GemmType.Normal, rhs_scales, n, k, block_n, block_k)
kwargs = { kwargs = {
# Templated arguments
'GEMM_TYPE': GemmType.Normal, 'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads, 'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'K': aligned_k, 'M': m, 'N': n, 'K': aligned_k,
'NUM_GROUPS': 1, 'NUM_GROUPS': 1,
'BLOCK_K': block_k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'GMEM_D': out, 'NUM_STAGES': num_stages,
'NUM_LAST_STAGES': num_last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
# Runtime arguments
'NUM_SMS': num_sms, 'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0], 'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a, 'TENSOR_MAP_A': tensor_map_a,
@@ -113,23 +104,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'TENSOR_MAP_SCALES_B': tensor_map_scales_b, 'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
'TENSOR_MAP_D': tensor_map_d, 'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream, 'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
} }
runtime, best_keys = jit_tuner.compile_and_tune( # Generate, build and run the kernel
name='wgrad_gemm_fp8_fp8_fp32_nt', code = FP8WGradGemmRuntime.generate(kwargs)
keys={'M': m, 'N': n, runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
'BLOCK_M': block_m, 'BLOCK_N': block_n, runtime(kwargs)
'NUM_STAGES': num_stages,
'LAST_STAGES': last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
runtime_cls=FP8WGradGemmRuntime,
)
# Run the kernel
runtime(**best_keys, **kwargs)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -144,16 +125,16 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
This function handles multiple batches with varying k-dimensions, processing each batch sequentially. This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
Each batch's LHS, RHS, and output tensors must be contiguous. Each batch's LHS, RHS, and output tensors must be contiguous.
The RHS and RHS scaling factors are required to be transposed. The RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format. The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.
Arguments: Arguments:
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows. and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
the second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`, The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
representing the per-128-channel scaling factors. representing the per-128-channel scaling factors.
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows. and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
the second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`, The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
representing the per-128-channel scaling factors. representing the per-128-channel scaling factors.
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated. out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
batch_sizes: A list of integers specifying the k-dimension for each batch. batch_sizes: A list of integers specifying the k-dimension for each batch.
@@ -164,16 +145,14 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
lhs_offset, rhs_offset, scales_offset = 0, 0, 0 lhs_offset, rhs_offset, scales_offset = 0, 0, 0
for idx in range(num_batches): for i in range(num_batches):
k = batch_sizes[idx] k = batch_sizes[i]
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k) lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k) rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128] lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128] rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
D = out[idx] wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
lhs_offset += m * k lhs_offset += m * k
rhs_offset += n * k rhs_offset += n * k
scales_offset += (k + 127) // 128 scales_offset += ceil_div(k, 128)

View File

@@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
assert m % 4 == 0, f'TMA alignment error: {m}' assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = per_token_cast_to_fp8(x) x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups): for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
@@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
assert m % 4 == 0, f'TMA alignment error: {m}' assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups): for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
@@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
total_scale_factors = sum((k + 127) // 128 for k in k_sizes) total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
@@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
num_scales = (k + 127) // 128 num_scales = ceil_div(k, 128)
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)

View File

@@ -2,6 +2,7 @@ import ctypes
import os import os
import torch import torch
import cuda.bindings.driver as cbd import cuda.bindings.driver as cbd
from typing import Any, Dict
from deep_gemm import jit from deep_gemm import jit
@@ -12,15 +13,10 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
class VectorAddRuntime(jit.Runtime): class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__(path, [ super().__init__(path)
'A',
'B',
'C',
'STREAM',
])
@staticmethod @staticmethod
def generate(**kwargs) -> str: def generate(kwargs: Dict[str, Any]) -> str:
return f""" return f"""
#ifdef __CUDACC_RTC__ #ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh> #include <deep_gemm/nvrtc_std.cuh>
@@ -46,27 +42,25 @@ static void __instantiate_kernel() {{
# noinspection PyShadowingNames,PyMethodOverriding # noinspection PyShadowingNames,PyMethodOverriding
@staticmethod @staticmethod
def launch(kernel: cbd.CUkernel, def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
stream: cbd.CUstream) -> cbd.CUresult: assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
assert a.shape == b.shape == c.shape assert kwargs['A'].dim() == 1
assert a.device == b.device == c.device
assert a.dim() == 1
config = cbd.CUlaunchConfig() config = cbd.CUlaunchConfig()
config.gridDimX = (a.numel() + 127) // 128 config.gridDimX = (kwargs['A'].numel() + 127) // 128
config.gridDimY = 1 config.gridDimY = 1
config.gridDimZ = 1 config.gridDimZ = 1
config.blockDimX = 128 config.blockDimX = 128
config.blockDimY = 1 config.blockDimY = 1
config.blockDimZ = 1 config.blockDimZ = 1
config.hStream = stream config.hStream = kwargs['STREAM']
arg_values = ( arg_values = (
a.data_ptr(), kwargs['A'].data_ptr(),
b.data_ptr(), kwargs['B'].data_ptr(),
c.data_ptr(), kwargs['C'].data_ptr(),
a.numel(), kwargs['A'].numel(),
) )
arg_types = ( arg_types = (
ctypes.c_void_p, ctypes.c_void_p,