mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Merge pull request #100 from deepseek-ai/remove-tuner
Refactor some launch-related structures
This commit is contained in:
@@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful:
|
|||||||
- Post optimization
|
- Post optimization
|
||||||
- `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
|
- `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
|
||||||
- Heuristic selection
|
- Heuristic selection
|
||||||
- `DG_PRINT_AUTOTUNE`: `0` or `1`, print selected configs for each shape, `0` by default
|
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
|
||||||
- Testing
|
- Testing
|
||||||
- `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default
|
- `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ namespace deep_gemm {
|
|||||||
|
|
||||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
|
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
|
||||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||||
uint32_t kNumStages, uint32_t kLastStages,
|
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
|
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
|
||||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||||
@@ -127,7 +127,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
|||||||
struct DivisibleK {};
|
struct DivisibleK {};
|
||||||
struct NotDivisibleK {};
|
struct NotDivisibleK {};
|
||||||
auto launch_k_iterations = [&](const auto& func) {
|
auto launch_k_iterations = [&](const auto& func) {
|
||||||
if constexpr (kLastStages == 0) {
|
if constexpr (kNumLastStages == 0) {
|
||||||
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||||
func(k_iter, DivisibleK{});
|
func(k_iter, DivisibleK{});
|
||||||
} else {
|
} else {
|
||||||
@@ -155,7 +155,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
|||||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||||
launch_k_iterations([&](int k_iter, auto type) {
|
launch_k_iterations([&](int k_iter, auto type) {
|
||||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
|
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||||
|
|
||||||
// Assign TMA multicast number into A and B
|
// Assign TMA multicast number into A and B
|
||||||
@@ -244,7 +244,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
|||||||
// Launch MMAs
|
// Launch MMAs
|
||||||
launch_k_iterations([&](int k_iter, auto type) {
|
launch_k_iterations([&](int k_iter, auto type) {
|
||||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kLastStages;
|
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import re
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Tuple, Type
|
from typing import Any, Dict, List, Tuple, Type
|
||||||
|
|
||||||
import cuda.bindings
|
import cuda.bindings
|
||||||
import cuda.bindings.nvrtc as nvrtc
|
import cuda.bindings.nvrtc as nvrtc
|
||||||
@@ -128,7 +128,7 @@ class Compiler:
|
|||||||
return [get_jit_include_dir()]
|
return [get_jit_include_dir()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
|
||||||
# Compiler flags
|
# Compiler flags
|
||||||
flags = cls.flags()
|
flags = cls.flags()
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ class Compiler:
|
|||||||
|
|
||||||
# Check runtime cache or file system hit
|
# Check runtime cache or file system hit
|
||||||
global runtime_cache
|
global runtime_cache
|
||||||
cached_runtime = runtime_cache.get(path, runtime_cls)
|
cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
|
||||||
if cached_runtime is not None:
|
if cached_runtime is not None:
|
||||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||||
print(f'Using cached JIT runtime {name} during build')
|
print(f'Using cached JIT runtime {name} during build')
|
||||||
@@ -166,8 +166,8 @@ class Compiler:
|
|||||||
os.replace(tmp_cubin_path, cubin_path)
|
os.replace(tmp_cubin_path, cubin_path)
|
||||||
|
|
||||||
# Put cache and return
|
# Put cache and return
|
||||||
runtime = runtime_cls(path)
|
runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
|
||||||
runtime_cache[path] = runtime
|
assert runtime is not None
|
||||||
return runtime
|
return runtime
|
||||||
|
|
||||||
|
|
||||||
@@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler):
|
|||||||
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
|
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
|
||||||
|
|
||||||
|
|
||||||
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
|
||||||
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
|
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
|
||||||
return compiler_cls.build(name, code, runtime_cls=runtime_cls)
|
return compiler_cls.build(name, code, runtime_cls, kwargs)
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
import torch
|
||||||
import cuda.bindings.driver as cbd
|
import cuda.bindings.driver as cbd
|
||||||
|
|
||||||
from typing import List, Optional, Type
|
from typing import Any, Dict, Optional, Type
|
||||||
from torch.utils.cpp_extension import CUDA_HOME
|
from torch.utils.cpp_extension import CUDA_HOME
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(self, path: str, args: List[str] = None) -> None:
|
def __init__(self, path: str) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.lib = None
|
self.lib = None
|
||||||
self.kernel = None
|
self.kernel = None
|
||||||
self.args = args
|
|
||||||
assert self.is_path_valid(self.path)
|
assert self.is_path_valid(self.path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -26,14 +26,14 @@ class Runtime:
|
|||||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(**kwargs) -> str:
|
def generate(kwargs: Dict[str, Any]) -> str:
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
|
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
|
||||||
def __call__(self, **kwargs) -> cbd.CUresult:
|
def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||||
# Load CUBIN
|
# Load CUBIN
|
||||||
if self.kernel is None:
|
if self.kernel is None:
|
||||||
start_time = time.time_ns()
|
start_time = time.time_ns()
|
||||||
@@ -48,8 +48,10 @@ class Runtime:
|
|||||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
assert result.returncode == 0
|
assert result.returncode == 0
|
||||||
|
illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail']
|
||||||
|
check_illegal = lambda line: any([name in line for name in illegal_names])
|
||||||
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
|
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
|
||||||
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
|
if line.startswith('STT_FUNC') and not check_illegal(line)]
|
||||||
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
||||||
|
|
||||||
# Load kernel from the library
|
# Load kernel from the library
|
||||||
@@ -62,7 +64,7 @@ class Runtime:
|
|||||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
|
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
|
||||||
|
|
||||||
# noinspection PyArgumentList
|
# noinspection PyArgumentList
|
||||||
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
|
return self.launch(self.kernel, kwargs)
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
if self.lib is not None:
|
if self.lib is not None:
|
||||||
@@ -78,13 +80,23 @@ class RuntimeCache:
|
|||||||
def __setitem__(self, path: str, runtime: Runtime) -> None:
|
def __setitem__(self, path: str, runtime: Runtime) -> None:
|
||||||
self.cache[path] = runtime
|
self.cache[path] = runtime
|
||||||
|
|
||||||
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
|
def get(self, path: str, runtime_cls: Type[Runtime],
|
||||||
|
name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]:
|
||||||
# In Python runtime
|
# In Python runtime
|
||||||
if path in self.cache:
|
if path in self.cache:
|
||||||
return self.cache[path]
|
return self.cache[path]
|
||||||
|
|
||||||
# Already compiled
|
# Already compiled
|
||||||
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
|
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
|
||||||
|
# Print heuristic for the first time
|
||||||
|
if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
|
||||||
|
simplified_kwargs = dict()
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
|
||||||
|
value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
|
||||||
|
simplified_kwargs[key] = value
|
||||||
|
print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')
|
||||||
|
|
||||||
runtime = runtime_cls(path)
|
runtime = runtime_cls(path)
|
||||||
self.cache[path] = runtime
|
self.cache[path] = runtime
|
||||||
return runtime
|
return runtime
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ import torch
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
from ..jit import build
|
||||||
from .runtime import (
|
from .runtime import (
|
||||||
FP8GemmRuntime, GemmType,
|
FP8GemmRuntime, GemmType,
|
||||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
|
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||||
from .tuner import jit_tuner
|
|
||||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||||
|
|
||||||
|
|
||||||
@@ -18,7 +18,6 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
|
|||||||
|
|
||||||
|
|
||||||
def get_swizzle_mode(block_n: int) -> int:
|
def get_swizzle_mode(block_n: int) -> int:
|
||||||
# TODO: remove some candidates if slow
|
|
||||||
elem_size = 2
|
elem_size = 2
|
||||||
for mode_bytes in (128, 64, 32):
|
for mode_bytes in (128, 64, 32):
|
||||||
if (block_n * elem_size) % mode_bytes == 0:
|
if (block_n * elem_size) % mode_bytes == 0:
|
||||||
@@ -180,22 +179,15 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
# Type and shape checks
|
# Type and shape checks
|
||||||
assert m == m_ and n == n_ and k == k_
|
assert m == m_ and n == n_ and k == k_
|
||||||
assert n > 0 and k > 0
|
assert n > 0 and k > 0
|
||||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||||
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.bfloat16
|
assert out.dtype == torch.bfloat16
|
||||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||||
|
|
||||||
lhs_stride = lhs.stride(0)
|
|
||||||
rhs_stride = rhs.stride(0)
|
|
||||||
out_stride = out.stride(0)
|
|
||||||
|
|
||||||
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
|
|
||||||
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 8 == 0
|
|
||||||
|
|
||||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||||
assert rhs_scales.is_contiguous()
|
assert rhs_scales.is_contiguous()
|
||||||
|
|
||||||
@@ -204,33 +196,34 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
return
|
return
|
||||||
|
|
||||||
# K must be aligned to 128
|
# K must be aligned to 128
|
||||||
aligned_k = (k + 127) // 128 * 128
|
aligned_k = ceil_div(k, 128) * 128
|
||||||
|
|
||||||
# Auto-tuning with compilation
|
# Auto-tuning with compilation
|
||||||
num_sms = get_num_sms()
|
num_sms = get_num_sms()
|
||||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
|
||||||
m, n, k, 1, num_sms)
|
|
||||||
block_k = 128
|
block_k = 128
|
||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
tensor_map_a = make_2d_tma_a_desc(
|
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
|
||||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
|
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
|
||||||
tensor_map_b = make_2d_tma_b_desc(
|
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
|
||||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
|
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
|
||||||
tensor_map_d = make_2d_tma_d_desc(
|
|
||||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
|
|
||||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
|
||||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
# Templated arguments
|
||||||
'GEMM_TYPE': GemmType.Normal,
|
'GEMM_TYPE': GemmType.Normal,
|
||||||
'NUM_TMA_THREADS': num_tma_threads,
|
'NUM_TMA_THREADS': num_tma_threads,
|
||||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||||
'M': m,
|
'M': m, 'N': n, 'K': aligned_k,
|
||||||
'NUM_GROUPS': 1,
|
'NUM_GROUPS': 1,
|
||||||
'BLOCK_K': block_k,
|
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||||
'GMEM_D': out,
|
'SWIZZLE_D_MODE': smem_config[1],
|
||||||
|
'BLOCK_N_PADDING': smem_config[2],
|
||||||
|
'NUM_STAGES': num_stages,
|
||||||
|
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||||
|
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||||
|
# Runtime arguments
|
||||||
'SCALES_B': rhs_scales,
|
'SCALES_B': rhs_scales,
|
||||||
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
|
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
|
||||||
'NUM_SMS': num_sms,
|
'NUM_SMS': num_sms,
|
||||||
@@ -240,21 +233,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||||
'TENSOR_MAP_D': tensor_map_d,
|
'TENSOR_MAP_D': tensor_map_d,
|
||||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||||
|
'DEVICE_INDEX': out.device.index
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
|
||||||
name='gemm_fp8_fp8_bf16_nt',
|
|
||||||
keys={'N': n, 'K': aligned_k,
|
|
||||||
'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
|
||||||
'SWIZZLE_D_MODE': smem_config[1],
|
|
||||||
'BLOCK_N_PADDING': smem_config[2],
|
|
||||||
'NUM_STAGES': num_stages,
|
|
||||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
|
||||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
|
||||||
space=(),
|
|
||||||
kwargs=kwargs,
|
|
||||||
runtime_cls=FP8GemmRuntime,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the kernel
|
# Generate, build and run the kernel
|
||||||
runtime(**best_keys, **kwargs)
|
code = FP8GemmRuntime.generate(kwargs)
|
||||||
|
runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||||
|
runtime(kwargs)
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
from ..jit import build
|
||||||
from .gemm import get_best_configs
|
from .gemm import get_best_configs
|
||||||
from .runtime import (
|
from .runtime import (
|
||||||
FP8GemmRuntime, GemmType,
|
FP8GemmRuntime, GemmType,
|
||||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
|
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||||
from .tuner import jit_tuner
|
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
|
||||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
|
||||||
|
|
||||||
|
|
||||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -44,8 +44,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
|||||||
|
|
||||||
# Type and shape checks
|
# Type and shape checks
|
||||||
assert m == m_ == m__ and k == k_ and n == n_
|
assert m == m_ == m__ and k == k_ and n == n_
|
||||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.bfloat16
|
assert out.dtype == torch.bfloat16
|
||||||
@@ -69,21 +69,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
|||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
tensor_map_a = make_2d_tma_a_desc(
|
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups)
|
||||||
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
|
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups)
|
||||||
tensor_map_b = make_2d_tma_b_desc(
|
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
|
||||||
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
|
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||||
tensor_map_d = make_2d_tma_d_desc(
|
|
||||||
GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1])
|
|
||||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
|
||||||
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
# Templated arguments
|
||||||
'NUM_TMA_THREADS': num_tma_threads,
|
'NUM_TMA_THREADS': num_tma_threads,
|
||||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||||
'M': m,
|
'M': m, 'N': n, 'K': k,
|
||||||
'BLOCK_K': block_k,
|
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||||
'GMEM_D': out,
|
'SWIZZLE_D_MODE': smem_config[1],
|
||||||
|
'BLOCK_N_PADDING': smem_config[2],
|
||||||
|
'NUM_GROUPS': num_groups,
|
||||||
|
'NUM_STAGES': num_stages,
|
||||||
|
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||||
|
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||||
|
'GEMM_TYPE': GemmType.GroupedContiguous,
|
||||||
|
# Runtime arguments
|
||||||
'SCALES_B': rhs_scales,
|
'SCALES_B': rhs_scales,
|
||||||
'GROUPED_LAYOUT': m_indices,
|
'GROUPED_LAYOUT': m_indices,
|
||||||
'NUM_SMS': num_sms,
|
'NUM_SMS': num_sms,
|
||||||
@@ -93,25 +97,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
|||||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||||
'TENSOR_MAP_D': tensor_map_d,
|
'TENSOR_MAP_D': tensor_map_d,
|
||||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||||
|
'DEVICE_INDEX': out.device.index
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
# Generate, build and run the kernel
|
||||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
code = FP8GemmRuntime.generate(kwargs)
|
||||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||||
'SWIZZLE_D_MODE': smem_config[1],
|
runtime(kwargs)
|
||||||
'BLOCK_N_PADDING': smem_config[2],
|
|
||||||
'NUM_GROUPS': num_groups,
|
|
||||||
'NUM_STAGES': num_stages,
|
|
||||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
|
||||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
|
||||||
'GEMM_TYPE': GemmType.GroupedContiguous},
|
|
||||||
space=(),
|
|
||||||
kwargs=kwargs,
|
|
||||||
runtime_cls=FP8GemmRuntime,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the kernel
|
|
||||||
runtime(**best_keys, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -150,8 +142,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
|||||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||||
assert m == m_ and n == n_ and k == k_
|
assert m == m_ and n == n_ and k == k_
|
||||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||||
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
|
||||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.bfloat16
|
assert out.dtype == torch.bfloat16
|
||||||
@@ -176,21 +168,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
|||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
tensor_map_a = make_2d_tma_a_desc(
|
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups)
|
||||||
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
|
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups)
|
||||||
tensor_map_b = make_2d_tma_b_desc(
|
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
|
||||||
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
|
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||||
tensor_map_d = make_2d_tma_d_desc(
|
|
||||||
GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1])
|
|
||||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
|
||||||
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
# Templated arguments
|
||||||
'NUM_TMA_THREADS': num_tma_threads,
|
'NUM_TMA_THREADS': num_tma_threads,
|
||||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||||
'M': m,
|
'M': m, 'N': n, 'K': k,
|
||||||
'BLOCK_K': block_k,
|
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||||
'GMEM_D': out,
|
'SWIZZLE_D_MODE': smem_config[1],
|
||||||
|
'BLOCK_N_PADDING': smem_config[2],
|
||||||
|
'NUM_GROUPS': num_groups,
|
||||||
|
'NUM_STAGES': num_stages,
|
||||||
|
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||||
|
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||||
|
'GEMM_TYPE': GemmType.GroupedMasked,
|
||||||
|
# Runtime arguments
|
||||||
'SCALES_B': rhs_scales,
|
'SCALES_B': rhs_scales,
|
||||||
'GROUPED_LAYOUT': masked_m,
|
'GROUPED_LAYOUT': masked_m,
|
||||||
'NUM_SMS': num_sms,
|
'NUM_SMS': num_sms,
|
||||||
@@ -200,22 +196,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
|||||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||||
'TENSOR_MAP_D': tensor_map_d,
|
'TENSOR_MAP_D': tensor_map_d,
|
||||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||||
|
'DEVICE_INDEX': out.device.index
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
# Generate, build and run the kernel
|
||||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
code = FP8GemmRuntime.generate(kwargs)
|
||||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||||
'SWIZZLE_D_MODE': smem_config[1],
|
runtime(kwargs)
|
||||||
'BLOCK_N_PADDING': smem_config[2],
|
|
||||||
'NUM_GROUPS': num_groups,
|
|
||||||
'NUM_STAGES': num_stages,
|
|
||||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
|
||||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
|
||||||
'GEMM_TYPE': GemmType.GroupedMasked},
|
|
||||||
space=(),
|
|
||||||
kwargs=kwargs,
|
|
||||||
runtime_cls=FP8GemmRuntime,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the kernel
|
|
||||||
runtime(**best_keys, **kwargs)
|
|
||||||
|
|||||||
@@ -5,14 +5,10 @@ import torch
|
|||||||
import cuda.bindings.driver as cbd
|
import cuda.bindings.driver as cbd
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
from .utils import get_tma_aligned_size
|
||||||
from ..jit.runtime import Runtime
|
from ..jit.runtime import Runtime
|
||||||
|
|
||||||
|
|
||||||
class Layout(enum.Enum):
|
|
||||||
RowMajor = 0
|
|
||||||
ColMajor = 1
|
|
||||||
|
|
||||||
|
|
||||||
class GemmType(enum.Enum):
|
class GemmType(enum.Enum):
|
||||||
Normal = 0
|
Normal = 0
|
||||||
GroupedContiguous = 1
|
GroupedContiguous = 1
|
||||||
@@ -61,19 +57,18 @@ def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int
|
|||||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_copy_desc(global_address: torch.Tensor,
|
def make_2d_tma_copy_desc(t: torch.Tensor,
|
||||||
gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t],
|
gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
|
||||||
stride_in_bytes: cbd.cuuint64_t,
|
smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
|
||||||
smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
|
|
||||||
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
|
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
|
||||||
tensor_dtype = tmap_type_map[global_address.dtype]
|
tensor_dtype = tmap_type_map[t.dtype]
|
||||||
res, tensor_map = cbd.cuTensorMapEncodeTiled(
|
res, tensor_map = cbd.cuTensorMapEncodeTiled(
|
||||||
tensor_dtype,
|
tensor_dtype,
|
||||||
2,
|
2,
|
||||||
global_address.data_ptr(),
|
t.data_ptr(),
|
||||||
gmem_dim,
|
gmem_dims,
|
||||||
(stride_in_bytes, ),
|
(gmem_outer_stride,),
|
||||||
smem_dim,
|
smem_dims,
|
||||||
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
|
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
|
||||||
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||||
swizzle_type,
|
swizzle_type,
|
||||||
@@ -86,93 +81,64 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
|
|||||||
return tensor_map
|
return tensor_map
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
|
def make_2d_tma_desc(t: torch.Tensor,
|
||||||
gmem_rows: int, gmem_cols: int, gmem_stride: int,
|
gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
|
||||||
smem_rows: int, smem_cols: int,
|
smem_inner_dim: int, smem_outer_dim: int,
|
||||||
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
|
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
|
||||||
if layout == Layout.RowMajor:
|
gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
|
||||||
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
|
smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
|
||||||
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
|
return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type)
|
||||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
|
|
||||||
else:
|
|
||||||
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
|
|
||||||
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
|
|
||||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
|
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||||
shape_m: int, shape_k: int,
|
shape_m: int, shape_k: int, m_stride: int,
|
||||||
block_m: int, block_k: int,
|
block_m: int, block_k: int,
|
||||||
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
|
num_groups: int) -> cbd.CUtensorMap:
|
||||||
a_stride = shape_k if a_stride == 0 else a_stride
|
return make_2d_tma_desc(t,
|
||||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
|
block_k, block_m)
|
||||||
block_m, block_k)
|
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||||
shape_k: int, shape_n: int,
|
shape_n: int, shape_k: int, n_stride: int,
|
||||||
block_k: int, block_n: int,
|
block_n: int, block_k: int,
|
||||||
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
|
num_groups: int) -> cbd.CUtensorMap:
|
||||||
b_stride = shape_k if b_stride == 0 else b_stride
|
return make_2d_tma_desc(t,
|
||||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride,
|
||||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
|
|
||||||
block_k, block_n)
|
block_k, block_n)
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||||
shape_m: int, shape_n: int,
|
shape_m: int, shape_n: int, m_stride: int,
|
||||||
block_m: int, block_n: int,
|
block_m: int, block_n: int,
|
||||||
num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
|
num_groups: int,
|
||||||
|
swizzle_mode: int) -> cbd.CUtensorMap:
|
||||||
# Swizzling requires the inner box dim to be less than or equal to `kSwizzleDMode`
|
# Swizzling requires the inner box dim to be less than or equal to `kSwizzleDMode`
|
||||||
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||||
d_stride = shape_n if d_stride == 0 else d_stride
|
return make_2d_tma_desc(t,
|
||||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
|
block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m,
|
||||||
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
|
|
||||||
swizzle_type_map[swizzle_mode])
|
swizzle_type_map[swizzle_mode])
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
|
def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||||
|
shape_mn: int, shape_k: int,
|
||||||
|
block_mn: int, block_k: int,
|
||||||
|
num_groups: int) -> cbd.CUtensorMap:
|
||||||
# Make TMA aligned to 16 bytes
|
# Make TMA aligned to 16 bytes
|
||||||
tma_alignment = 16 / global_address.element_size()
|
shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
|
||||||
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
|
return make_2d_tma_desc(t,
|
||||||
|
shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn,
|
||||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
block_mn, 1,
|
||||||
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
|
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||||
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
|
||||||
|
|
||||||
|
|
||||||
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
|
|
||||||
# Make TMA aligned to 16 bytes
|
|
||||||
tma_alignment = 16 / global_address.element_size()
|
|
||||||
shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
|
|
||||||
|
|
||||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
|
||||||
shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
|
|
||||||
block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
|
||||||
|
|
||||||
|
|
||||||
class FP8GemmRuntime(Runtime):
|
class FP8GemmRuntime(Runtime):
|
||||||
def __init__(self, path: str) -> None:
|
def __init__(self, path: str) -> None:
|
||||||
super().__init__(path, [
|
super().__init__(path)
|
||||||
'NUM_TMA_MULTICAST',
|
|
||||||
'M',
|
|
||||||
'BLOCK_M',
|
|
||||||
'GMEM_D',
|
|
||||||
'SCALES_B',
|
|
||||||
'GROUPED_LAYOUT',
|
|
||||||
'NUM_SMS',
|
|
||||||
'SMEM_SIZE',
|
|
||||||
'TENSOR_MAP_A',
|
|
||||||
'TENSOR_MAP_B',
|
|
||||||
'TENSOR_MAP_SCALES_A',
|
|
||||||
'TENSOR_MAP_D',
|
|
||||||
'STREAM',
|
|
||||||
])
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(**kwargs) -> str:
|
def generate(kwargs: Dict[str, Any]) -> str:
|
||||||
code = f'''
|
code = f'''
|
||||||
#ifdef __CUDACC_RTC__
|
#ifdef __CUDACC_RTC__
|
||||||
#include <deep_gemm/nvrtc_std.cuh>
|
#include <deep_gemm/nvrtc_std.cuh>
|
||||||
@@ -213,21 +179,16 @@ static void __instantiate_kernel() {{
|
|||||||
|
|
||||||
# noinspection PyMethodOverriding
|
# noinspection PyMethodOverriding
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
|
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||||
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
|
|
||||||
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
|
|
||||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
|
||||||
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
|
|
||||||
stream: cbd.CUstream) -> cbd.CUresult:
|
|
||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
|
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
|
||||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
|
||||||
|
|
||||||
attr_val = cbd.CUlaunchAttributeValue()
|
attr_val = cbd.CUlaunchAttributeValue()
|
||||||
attr_val.clusterDim.x = num_tma_multicast
|
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
|
||||||
attr_val.clusterDim.y = 1
|
attr_val.clusterDim.y = 1
|
||||||
attr_val.clusterDim.z = 1
|
attr_val.clusterDim.z = 1
|
||||||
attr = cbd.CUlaunchAttribute()
|
attr = cbd.CUlaunchAttribute()
|
||||||
@@ -237,23 +198,23 @@ static void __instantiate_kernel() {{
|
|||||||
config = cbd.CUlaunchConfig()
|
config = cbd.CUlaunchConfig()
|
||||||
config.numAttrs = 1
|
config.numAttrs = 1
|
||||||
config.attrs = [attr]
|
config.attrs = [attr]
|
||||||
config.gridDimX = num_sms
|
config.gridDimX = kwargs['NUM_SMS']
|
||||||
config.gridDimY = 1
|
config.gridDimY = 1
|
||||||
config.gridDimZ = 1
|
config.gridDimZ = 1
|
||||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
|
||||||
config.blockDimY = 1
|
config.blockDimY = 1
|
||||||
config.blockDimZ = 1
|
config.blockDimZ = 1
|
||||||
config.sharedMemBytes = smem_size
|
config.sharedMemBytes = kwargs['SMEM_SIZE']
|
||||||
config.hStream = stream
|
config.hStream = kwargs['STREAM']
|
||||||
|
|
||||||
arg_values = (
|
arg_values = (
|
||||||
scales_b.data_ptr(),
|
kwargs['SCALES_B'].data_ptr(),
|
||||||
grouped_layout.data_ptr(),
|
kwargs['GROUPED_LAYOUT'].data_ptr(),
|
||||||
shape_m,
|
kwargs['M'],
|
||||||
tensor_map_a,
|
kwargs['TENSOR_MAP_A'],
|
||||||
tensor_map_b,
|
kwargs['TENSOR_MAP_B'],
|
||||||
tensor_map_scales_a,
|
kwargs['TENSOR_MAP_SCALES_A'],
|
||||||
tensor_map_d,
|
kwargs['TENSOR_MAP_D'],
|
||||||
)
|
)
|
||||||
arg_types = (
|
arg_types = (
|
||||||
ctypes.c_void_p,
|
ctypes.c_void_p,
|
||||||
@@ -269,23 +230,10 @@ static void __instantiate_kernel() {{
|
|||||||
|
|
||||||
class FP8WGradGemmRuntime(Runtime):
|
class FP8WGradGemmRuntime(Runtime):
|
||||||
def __init__(self, path: str) -> None:
|
def __init__(self, path: str) -> None:
|
||||||
super().__init__(path, [
|
super().__init__(path)
|
||||||
'NUM_TMA_MULTICAST',
|
|
||||||
'K',
|
|
||||||
'BLOCK_M',
|
|
||||||
'GMEM_D',
|
|
||||||
'NUM_SMS',
|
|
||||||
'SMEM_SIZE',
|
|
||||||
'TENSOR_MAP_A',
|
|
||||||
'TENSOR_MAP_B',
|
|
||||||
'TENSOR_MAP_SCALES_A',
|
|
||||||
'TENSOR_MAP_SCALES_B',
|
|
||||||
'TENSOR_MAP_D',
|
|
||||||
'STREAM',
|
|
||||||
])
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(**kwargs) -> str:
|
def generate(kwargs: Dict[str, Any]) -> str:
|
||||||
code = f'''
|
code = f'''
|
||||||
#ifdef __CUDACC_RTC__
|
#ifdef __CUDACC_RTC__
|
||||||
#include <deep_gemm/nvrtc_std.cuh>
|
#include <deep_gemm/nvrtc_std.cuh>
|
||||||
@@ -309,7 +257,7 @@ static void __instantiate_kernel() {{
|
|||||||
{kwargs['BLOCK_N']},
|
{kwargs['BLOCK_N']},
|
||||||
{kwargs['BLOCK_K']},
|
{kwargs['BLOCK_K']},
|
||||||
{kwargs['NUM_STAGES']},
|
{kwargs['NUM_STAGES']},
|
||||||
{kwargs['LAST_STAGES']},
|
{kwargs['NUM_LAST_STAGES']},
|
||||||
{kwargs['NUM_TMA_THREADS']},
|
{kwargs['NUM_TMA_THREADS']},
|
||||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||||
{kwargs['NUM_TMA_MULTICAST']},
|
{kwargs['NUM_TMA_MULTICAST']},
|
||||||
@@ -323,21 +271,16 @@ static void __instantiate_kernel() {{
|
|||||||
|
|
||||||
# noinspection PyMethodOverriding
|
# noinspection PyMethodOverriding
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
|
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||||
block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
|
|
||||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
|
||||||
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
|
|
||||||
tensor_map_d: cbd.CUtensorMap,
|
|
||||||
stream: cbd.CUstream) -> cbd.CUresult:
|
|
||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
|
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
|
||||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
|
||||||
|
|
||||||
attr_val = cbd.CUlaunchAttributeValue()
|
attr_val = cbd.CUlaunchAttributeValue()
|
||||||
attr_val.clusterDim.x = num_tma_multicast
|
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
|
||||||
attr_val.clusterDim.y = 1
|
attr_val.clusterDim.y = 1
|
||||||
attr_val.clusterDim.z = 1
|
attr_val.clusterDim.z = 1
|
||||||
attr = cbd.CUlaunchAttribute()
|
attr = cbd.CUlaunchAttribute()
|
||||||
@@ -347,22 +290,22 @@ static void __instantiate_kernel() {{
|
|||||||
config = cbd.CUlaunchConfig()
|
config = cbd.CUlaunchConfig()
|
||||||
config.numAttrs = 1
|
config.numAttrs = 1
|
||||||
config.attrs = [attr]
|
config.attrs = [attr]
|
||||||
config.gridDimX = num_sms
|
config.gridDimX = kwargs['NUM_SMS']
|
||||||
config.gridDimY = 1
|
config.gridDimY = 1
|
||||||
config.gridDimZ = 1
|
config.gridDimZ = 1
|
||||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
|
||||||
config.blockDimY = 1
|
config.blockDimY = 1
|
||||||
config.blockDimZ = 1
|
config.blockDimZ = 1
|
||||||
config.sharedMemBytes = smem_size
|
config.sharedMemBytes = kwargs['SMEM_SIZE']
|
||||||
config.hStream = stream
|
config.hStream = kwargs['STREAM']
|
||||||
|
|
||||||
arg_values = (
|
arg_values = (
|
||||||
shape_k,
|
kwargs['K'],
|
||||||
tensor_map_a,
|
kwargs['TENSOR_MAP_A'],
|
||||||
tensor_map_b,
|
kwargs['TENSOR_MAP_B'],
|
||||||
tensor_map_scales_a,
|
kwargs['TENSOR_MAP_SCALES_A'],
|
||||||
tensor_map_scales_b,
|
kwargs['TENSOR_MAP_SCALES_B'],
|
||||||
tensor_map_d,
|
kwargs['TENSOR_MAP_D'],
|
||||||
)
|
)
|
||||||
arg_types = (
|
arg_types = (
|
||||||
ctypes.c_uint32,
|
ctypes.c_uint32,
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
import copy
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import cuda.bindings.driver as cbd
|
|
||||||
from typing import Any, Callable, Dict, Type, Tuple
|
|
||||||
|
|
||||||
from ..jit import build, Runtime
|
|
||||||
|
|
||||||
|
|
||||||
class JITTuner:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.tuned = {}
|
|
||||||
|
|
||||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
|
||||||
kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
|
|
||||||
# NOTES: we always assume the space, template and GPU devices will not change
|
|
||||||
# NOTES: the function must have no accumulated side effects
|
|
||||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
|
||||||
signature = (name, f'{keys}')
|
|
||||||
if signature in self.tuned:
|
|
||||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
|
||||||
print(f'Using cached JIT kernel {name} with keys {keys}')
|
|
||||||
return self.tuned[signature]
|
|
||||||
|
|
||||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
|
||||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
|
||||||
|
|
||||||
assert signature not in self.tuned
|
|
||||||
assert kwargs is not None
|
|
||||||
space = (dict(), ) if len(space) == 0 else space
|
|
||||||
|
|
||||||
kernels = []
|
|
||||||
for tuned_keys in space:
|
|
||||||
assert isinstance(tuned_keys, dict)
|
|
||||||
full_keys = copy.deepcopy(keys)
|
|
||||||
full_keys.update(tuned_keys)
|
|
||||||
code = runtime_cls.generate(**kwargs, **full_keys)
|
|
||||||
kernels.append((build(name, code, runtime_cls), full_keys))
|
|
||||||
|
|
||||||
# TODO: fix tuning with space > 1
|
|
||||||
best_runtime, best_time, best_keys = None, None, None
|
|
||||||
for runtime, tuned_keys in kernels:
|
|
||||||
if len(space) > 1:
|
|
||||||
# Check kernel validity
|
|
||||||
return_code = runtime(**tuned_keys, **kwargs)
|
|
||||||
if return_code != cbd.CUresult.CUDA_SUCCESS:
|
|
||||||
# Skip illegal kernels, e.g., those with insufficient shared memory capacity
|
|
||||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
|
||||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Measure performance; an L2 flush and a large GEMM kernel are run first to reduce overhead between kernels
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
torch.empty(int(256e6 // 4), dtype=torch.int,
|
|
||||||
device='cuda').zero_()
|
|
||||||
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
|
|
||||||
(8192, 8192), dtype=torch.float, device='cuda')
|
|
||||||
start_event.record()
|
|
||||||
for i in range(20):
|
|
||||||
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
|
|
||||||
end_event.record()
|
|
||||||
end_event.synchronize()
|
|
||||||
elapsed_time = start_event.elapsed_time(end_event)
|
|
||||||
else:
|
|
||||||
elapsed_time = 0
|
|
||||||
|
|
||||||
# Compare if better
|
|
||||||
if best_time is None or elapsed_time < best_time:
|
|
||||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
|
||||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
|
||||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
|
||||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
|
||||||
|
|
||||||
# Cache the best runtime and return
|
|
||||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
|
|
||||||
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
|
||||||
self.tuned[signature] = (best_runtime, best_keys)
|
|
||||||
return best_runtime, best_keys
|
|
||||||
|
|
||||||
|
|
||||||
jit_tuner = JITTuner()
|
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from ..jit import build
|
||||||
from .runtime import (
|
from .runtime import (
|
||||||
FP8WGradGemmRuntime, GemmType,
|
FP8WGradGemmRuntime, GemmType,
|
||||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
|
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||||
from .gemm import get_best_configs
|
from .gemm import get_best_configs
|
||||||
from .tuner import jit_tuner
|
from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
||||||
from .utils import get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
|
||||||
|
|
||||||
|
|
||||||
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
out: Tuple[torch.Tensor, torch.Tensor]):
|
out: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||||
Results will be accumulated into the output tensor.
|
Results will be accumulated into the output tensor.
|
||||||
@@ -21,8 +21,8 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
||||||
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
|
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
|
||||||
RHS and RHS scaling factors are required to be transposed.
|
RHS and RHS scaling factors are required to be transposed.
|
||||||
The LHS scaling and RHS scaling tensor require TMA-aligned transposed format, if your input does not match the requirement,
|
The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format.
|
||||||
this function will transpose it using a set of slow PyTorch operations.
|
If your input does not match this requirement, this function will transpose it using a set of slow PyTorch operations.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||||
@@ -40,71 +40,62 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
# Type and shape checks
|
# Type and shape checks
|
||||||
assert m == m_ and n == n_ and k == k_
|
assert m == m_ and n == n_ and k == k_
|
||||||
assert n > 0 and m > 0
|
assert n > 0 and m > 0
|
||||||
assert lhs_scales.shape == (m, (k + 127) // 128) or lhs_scales.shape == ((k + 127) // 128, m)
|
assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
|
||||||
assert rhs_scales.shape == (n, (k + 127) // 128) or rhs_scales.shape == ((k + 127) // 128, n)
|
assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
|
||||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||||
assert out.dtype == torch.float
|
assert out.dtype == torch.float
|
||||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||||
|
|
||||||
lhs_stride = lhs.stride(0)
|
|
||||||
rhs_stride = rhs.stride(0)
|
|
||||||
out_stride = out.stride(0)
|
|
||||||
|
|
||||||
# The stride(0) of LHS, RHS, and output must be aligned to 16 bytes
|
|
||||||
assert lhs_stride % 16 == 0 and rhs_stride % 16 == 0 and out_stride % 4 == 0
|
|
||||||
|
|
||||||
# LHS and RHS scales must be transposed for TMA load
|
# LHS and RHS scales must be transposed for TMA load
|
||||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||||
if lhs_scales.shape == ((k + 127) // 128, m):
|
def get_valid_scales(scales: torch.Tensor, mn: int):
|
||||||
lhs_scales = lhs_scales.permute(1, 0)
|
if scales.shape == (ceil_div(k, 128), mn):
|
||||||
assert get_tma_aligned_size(m, 4) == m and lhs_scales.stride(1) == m
|
# For k-grouped GEMMs
|
||||||
else:
|
scales = scales.permute(1, 0)
|
||||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
|
||||||
assert lhs_scales.stride(0) == 1
|
else:
|
||||||
|
scales = get_col_major_tma_aligned_tensor(scales)
|
||||||
if rhs_scales.shape == ((k + 127) // 128, n):
|
return scales
|
||||||
rhs_scales = rhs_scales.permute(1, 0)
|
|
||||||
assert get_tma_aligned_size(n, 4) == n and rhs_scales.stride(1) == n
|
lhs_scales = get_valid_scales(lhs_scales, m)
|
||||||
else:
|
rhs_scales = get_valid_scales(rhs_scales, n)
|
||||||
rhs_scales = get_col_major_tma_aligned_tensor(rhs_scales)
|
|
||||||
assert rhs_scales.stride(0) == 1
|
|
||||||
|
|
||||||
# Do nothing if `k` is zero
|
# Do nothing if `k` is zero
|
||||||
if k == 0:
|
if k == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# K must be aligned to 128
|
# K must be aligned to 128
|
||||||
aligned_k = (k + 127) // 128 * 128
|
aligned_k = ceil_div(k, 128) * 128
|
||||||
|
|
||||||
# Auto-tuning with compilation
|
# Auto-tuning with compilation
|
||||||
num_sms = get_num_sms()
|
num_sms = get_num_sms()
|
||||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||||
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||||
last_stages = (k + 127) // 128 % num_stages
|
num_last_stages = ceil_div(k, 128) % num_stages
|
||||||
block_k = 128
|
block_k = 128
|
||||||
num_tma_threads = 128
|
num_tma_threads = 128
|
||||||
num_math_threads_per_group = 128
|
num_math_threads_per_group = 128
|
||||||
|
|
||||||
tensor_map_a = make_2d_tma_a_desc(
|
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
|
||||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1, a_stride=lhs_stride)
|
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
|
||||||
tensor_map_b = make_2d_tma_b_desc(
|
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
|
||||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1, b_stride=rhs_stride)
|
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
|
||||||
tensor_map_d = make_2d_tma_d_desc(
|
tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1)
|
||||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1], d_stride=out_stride)
|
|
||||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
|
||||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
|
||||||
tensor_map_scales_b = make_2d_tma_scales_b_desc(
|
|
||||||
GemmType.Normal, rhs_scales, n, k, block_n, block_k)
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
# Templated arguments
|
||||||
'GEMM_TYPE': GemmType.Normal,
|
'GEMM_TYPE': GemmType.Normal,
|
||||||
'NUM_TMA_THREADS': num_tma_threads,
|
'NUM_TMA_THREADS': num_tma_threads,
|
||||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||||
'K': aligned_k,
|
'M': m, 'N': n, 'K': aligned_k,
|
||||||
'NUM_GROUPS': 1,
|
'NUM_GROUPS': 1,
|
||||||
'BLOCK_K': block_k,
|
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||||
'GMEM_D': out,
|
'NUM_STAGES': num_stages,
|
||||||
|
'NUM_LAST_STAGES': num_last_stages,
|
||||||
|
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||||
|
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||||
|
# Runtime arguments
|
||||||
'NUM_SMS': num_sms,
|
'NUM_SMS': num_sms,
|
||||||
'SMEM_SIZE': smem_config[0],
|
'SMEM_SIZE': smem_config[0],
|
||||||
'TENSOR_MAP_A': tensor_map_a,
|
'TENSOR_MAP_A': tensor_map_a,
|
||||||
@@ -113,23 +104,13 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
|
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
|
||||||
'TENSOR_MAP_D': tensor_map_d,
|
'TENSOR_MAP_D': tensor_map_d,
|
||||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||||
|
'DEVICE_INDEX': out.device.index
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime, best_keys = jit_tuner.compile_and_tune(
|
# Generate, build and run the kernel
|
||||||
name='wgrad_gemm_fp8_fp8_fp32_nt',
|
code = FP8WGradGemmRuntime.generate(kwargs)
|
||||||
keys={'M': m, 'N': n,
|
runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
|
||||||
'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
runtime(kwargs)
|
||||||
'NUM_STAGES': num_stages,
|
|
||||||
'LAST_STAGES': last_stages,
|
|
||||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
|
||||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
|
||||||
space=(),
|
|
||||||
kwargs=kwargs,
|
|
||||||
runtime_cls=FP8WGradGemmRuntime,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the kernel
|
|
||||||
runtime(**best_keys, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -144,16 +125,16 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
|
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
|
||||||
Each batch's LHS, RHS, and output tensors must be contiguous.
|
Each batch's LHS, RHS, and output tensors must be contiguous.
|
||||||
The RHS and RHS scaling factors are required to be transposed.
|
The RHS and RHS scaling factors are required to be transposed.
|
||||||
The LHS scaling and RHS scaling tensors require TMA-aligned transposed format.
|
The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
lhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
|
lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
|
||||||
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
|
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
|
||||||
the second element is an FP32 scaling tensor for LHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), m]`,
|
The second element is an FP32 scaling tensor for LHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), m]`,
|
||||||
representing the per-128-channel scaling factors.
|
representing the per-128-channel scaling factors.
|
||||||
rhs: the first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
|
rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
|
||||||
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
|
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
|
||||||
the second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
|
The second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
|
||||||
representing the per-128-channel scaling factors.
|
representing the per-128-channel scaling factors.
|
||||||
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
|
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
|
||||||
batch_sizes: A list of integers specifying the k-dimension for each batch.
|
batch_sizes: A list of integers specifying the k-dimension for each batch.
|
||||||
@@ -164,16 +145,14 @@ def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
|
|
||||||
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
|
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
|
||||||
|
|
||||||
for idx in range(num_batches):
|
for i in range(num_batches):
|
||||||
k = batch_sizes[idx]
|
k = batch_sizes[i]
|
||||||
A = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
||||||
B = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
||||||
A_scales = lhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
|
||||||
B_scales = rhs_scales[scales_offset:scales_offset + (k + 127) // 128]
|
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
|
||||||
D = out[idx]
|
wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
|
||||||
|
|
||||||
wgrad_gemm_fp8_fp8_fp32_nt((A, A_scales), (B, B_scales), D)
|
|
||||||
|
|
||||||
lhs_offset += m * k
|
lhs_offset += m * k
|
||||||
rhs_offset += n * k
|
rhs_offset += n * k
|
||||||
scales_offset += (k + 127) // 128
|
scales_offset += ceil_div(k, 128)
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
|||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||||
x_fp8 = per_token_cast_to_fp8(x)
|
x_fp8 = per_token_cast_to_fp8(x)
|
||||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||||
for i in range(num_groups):
|
for i in range(num_groups):
|
||||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
|
|||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
||||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||||
for i in range(num_groups):
|
for i in range(num_groups):
|
||||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||||
@@ -137,7 +137,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
|||||||
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
|
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
|
||||||
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
|
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
total_scale_factors = sum((k + 127) // 128 for k in k_sizes)
|
total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
|
||||||
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
|
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
|
||||||
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
|
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
|
||||||
|
|
||||||
@@ -150,7 +150,7 @@ def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
|
|||||||
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
|
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
|
||||||
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
|
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
|
||||||
|
|
||||||
num_scales = (k + 127) // 128
|
num_scales = ceil_div(k, 128)
|
||||||
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
|
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
|
||||||
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
|
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import ctypes
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import cuda.bindings.driver as cbd
|
import cuda.bindings.driver as cbd
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from deep_gemm import jit
|
from deep_gemm import jit
|
||||||
|
|
||||||
@@ -12,15 +13,10 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
|
|||||||
|
|
||||||
class VectorAddRuntime(jit.Runtime):
|
class VectorAddRuntime(jit.Runtime):
|
||||||
def __init__(self, path: str) -> None:
|
def __init__(self, path: str) -> None:
|
||||||
super().__init__(path, [
|
super().__init__(path)
|
||||||
'A',
|
|
||||||
'B',
|
|
||||||
'C',
|
|
||||||
'STREAM',
|
|
||||||
])
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(**kwargs) -> str:
|
def generate(kwargs: Dict[str, Any]) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
#ifdef __CUDACC_RTC__
|
#ifdef __CUDACC_RTC__
|
||||||
#include <deep_gemm/nvrtc_std.cuh>
|
#include <deep_gemm/nvrtc_std.cuh>
|
||||||
@@ -46,27 +42,25 @@ static void __instantiate_kernel() {{
|
|||||||
|
|
||||||
# noinspection PyShadowingNames,PyMethodOverriding
|
# noinspection PyShadowingNames,PyMethodOverriding
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def launch(kernel: cbd.CUkernel,
|
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
|
||||||
stream: cbd.CUstream) -> cbd.CUresult:
|
assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
|
||||||
assert a.shape == b.shape == c.shape
|
assert kwargs['A'].dim() == 1
|
||||||
assert a.device == b.device == c.device
|
|
||||||
assert a.dim() == 1
|
|
||||||
|
|
||||||
config = cbd.CUlaunchConfig()
|
config = cbd.CUlaunchConfig()
|
||||||
config.gridDimX = (a.numel() + 127) // 128
|
config.gridDimX = (kwargs['A'].numel() + 127) // 128
|
||||||
config.gridDimY = 1
|
config.gridDimY = 1
|
||||||
config.gridDimZ = 1
|
config.gridDimZ = 1
|
||||||
config.blockDimX = 128
|
config.blockDimX = 128
|
||||||
config.blockDimY = 1
|
config.blockDimY = 1
|
||||||
config.blockDimZ = 1
|
config.blockDimZ = 1
|
||||||
config.hStream = stream
|
config.hStream = kwargs['STREAM']
|
||||||
|
|
||||||
arg_values = (
|
arg_values = (
|
||||||
a.data_ptr(),
|
kwargs['A'].data_ptr(),
|
||||||
b.data_ptr(),
|
kwargs['B'].data_ptr(),
|
||||||
c.data_ptr(),
|
kwargs['C'].data_ptr(),
|
||||||
a.numel(),
|
kwargs['A'].numel(),
|
||||||
)
|
)
|
||||||
arg_types = (
|
arg_types = (
|
||||||
ctypes.c_void_p,
|
ctypes.c_void_p,
|
||||||
|
|||||||
Reference in New Issue
Block a user