mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Some lints and refactor
This commit is contained in:
@@ -19,6 +19,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
|
||||
- [x] Fix TMA multicast compatibility for indivisible shapes
|
||||
- [ ] Skip useless computation on M
|
||||
- [x] NVRTC as a faster compiler
|
||||
- [ ] Fully remove NVCC compilation
|
||||
- [ ] Sanitizer for testing
|
||||
- [ ] Weight gradient kernels for dense models
|
||||
- [ ] Weight gradient kernels for MoE models
|
||||
@@ -105,12 +106,13 @@ The library provides some utility functions besides the above kernels:
|
||||
The library also provides some environment variables, which may be useful:
|
||||
|
||||
- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
|
||||
- `DG_DISABLE_CACHE`: 0 or 1, disable the use of cache directory, 0 by default
|
||||
- `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default
|
||||
- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler
|
||||
- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization
|
||||
- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
|
||||
- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details
|
||||
- `DG_JIT_PRINT_NVCC_COMMAND`: 0 or 1, print NVCC compilation command
|
||||
- `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print NVCC compilation command
|
||||
- `DG_JIT_DEBUG`: 0 or 1, print more debugging information
|
||||
|
||||
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
|
||||
|
||||
@@ -84,8 +84,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#ifndef __CUDACC_RTC__
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef NVRTC_JIT_COMPILATION
|
||||
#ifdef __CUDACC_RTC__
|
||||
|
||||
using int8_t = signed char;
|
||||
using uint8_t = unsigned char;
|
||||
@@ -32,8 +32,7 @@ using cuuint64_t = unsigned long long;
|
||||
#ifndef CU_TENSOR_MAP_NUM_QWORDS
|
||||
#define CU_TENSOR_MAP_NUM_QWORDS 16
|
||||
|
||||
struct CUtensorMap_st
|
||||
{
|
||||
struct CUtensorMap_st {
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
alignas(64)
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
@@ -46,16 +45,16 @@ using CUtensorMap = CUtensorMap_st;
|
||||
#endif
|
||||
|
||||
namespace std {
|
||||
|
||||
template <class T, T v> struct integral_constant {
|
||||
static constexpr T value = v;
|
||||
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
using type = integral_constant;
|
||||
|
||||
__device__ constexpr operator value_type() const noexcept { return value; }
|
||||
|
||||
__device__ constexpr value_type operator()() const noexcept {
|
||||
return value;
|
||||
} // since c++14
|
||||
__device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
using false_type = integral_constant<bool, false>;
|
||||
@@ -69,6 +68,7 @@ template <class T, class U>
|
||||
inline constexpr bool is_same_v = is_same<T, U>::value;
|
||||
|
||||
namespace index_sequence_impl {
|
||||
|
||||
// Based on https://stackoverflow.com/a/32223343/11717224
|
||||
template <size_t... Ints> struct index_sequence {
|
||||
using type = index_sequence;
|
||||
@@ -89,6 +89,7 @@ struct make_index_sequence
|
||||
|
||||
template <> struct make_index_sequence<0> : index_sequence<> {};
|
||||
template <> struct make_index_sequence<1> : index_sequence<0> {};
|
||||
|
||||
} // namespace index_sequence_impl
|
||||
|
||||
template <size_t... Ns>
|
||||
@@ -96,6 +97,7 @@ using index_sequence = index_sequence_impl::index_sequence<Ns...>;
|
||||
|
||||
template <size_t N>
|
||||
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
|
||||
@@ -46,7 +46,7 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||
if (num_blocks_in_group == 1)
|
||||
return false;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
|
||||
@@ -63,7 +63,8 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
|
||||
uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#endif
|
||||
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// TODO: move this function to other files
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
|
||||
|
||||
@@ -1,39 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <exception>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
|
||||
asm volatile("trap;");
|
||||
}
|
||||
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#ifdef NVRTC_JIT_COMPILATION
|
||||
#define DG_HOST_ASSERT(cond) ((void)0)
|
||||
#else
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
|
||||
from .template import generate
|
||||
from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler
|
||||
from .runtime import Runtime
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import abc
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
@@ -14,7 +13,7 @@ import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@@ -32,11 +31,11 @@ def get_jit_include_dir() -> str:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
md5 = hashlib.md5()
|
||||
|
||||
# Update include directories
|
||||
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
|
||||
assert os.path.exists(
|
||||
include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(os.path.join(include_dir, filename), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
@@ -98,24 +97,20 @@ def make_tmp_dir():
|
||||
|
||||
|
||||
def put(path, data):
|
||||
is_binary = isinstance(data, bytes)
|
||||
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
class Compiler(abc.ABC):
|
||||
class Compiler:
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@@ -132,13 +127,12 @@ class Compiler(abc.ABC):
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
|
||||
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
@@ -147,7 +141,7 @@ class Compiler(abc.ABC):
|
||||
global runtime_cache
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls)
|
||||
if cached_runtime is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return cached_runtime
|
||||
|
||||
@@ -160,9 +154,8 @@ class Compiler(abc.ABC):
|
||||
cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
@@ -177,12 +170,12 @@ class Compiler(abc.ABC):
|
||||
return runtime
|
||||
|
||||
|
||||
class NvccCompiler(Compiler):
|
||||
class NVCCCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
_, version = get_nvcc_compiler()
|
||||
major, minor = map(int, version.split('.'))
|
||||
return (major, minor)
|
||||
return major, minor
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
@@ -197,7 +190,7 @@ class NvccCompiler(Compiler):
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str):
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
# Write the code
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
@@ -205,26 +198,23 @@ class NvccCompiler(Compiler):
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
*cls.flags()]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f'Failed to compile {src_path}'
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
|
||||
assert False, f'Failed to compile {src_path}'
|
||||
|
||||
|
||||
class NvrtcCompiler(Compiler):
|
||||
class NVRTCCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
res, major, minor = nvrtc.nvrtcVersion()
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
# Failed to get actual NVRTC version, use bindings version instead
|
||||
# Failed to get the actual NVRTC version, use cuda-bindings version instead
|
||||
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
|
||||
return (major, minor)
|
||||
return major, minor
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
@@ -238,54 +228,51 @@ class NvrtcCompiler(Compiler):
|
||||
'--gpu-architecture=sm_90a', '-default-device']
|
||||
if cls.__version__() >= (12, 8):
|
||||
base_flags += ['--pch']
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
base_flags += ['--pch-verbose=true']
|
||||
return base_flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> str:
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
# Create program
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
res, program = nvrtc.nvrtcCreateProgram(
|
||||
result, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to create program: {res}')
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'
|
||||
|
||||
# Compile
|
||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||
compile_res = nvrtc.nvrtcCompileProgram(
|
||||
program, len(options), options)[0]
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
|
||||
print(f'Compiling JIT runtime {name} with options: {options}')
|
||||
compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]
|
||||
|
||||
# Print compiler log
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to get program log size: {res}')
|
||||
log_bytes = bytes(log_size)
|
||||
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to get program log: {res}')
|
||||
log_str = log_bytes.decode('utf-8')
|
||||
print(log_str)
|
||||
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
|
||||
print(f'Compiler log: {log_bytes.decode("utf-8")}')
|
||||
|
||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to compile program: {compile_res}')
|
||||
|
||||
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to get CUBIN size: {res}')
|
||||
# Exit if failed
|
||||
assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'
|
||||
|
||||
# Create CUBIN
|
||||
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
|
||||
cubin_bytes = bytes(cubin_size)
|
||||
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to get CUBIN: {res}')
|
||||
result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'
|
||||
|
||||
# Write into the file system
|
||||
put(target_path, cubin_bytes)
|
||||
|
||||
res = nvrtc.nvrtcDestroyProgram(program)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f'Failed to destroy program: {res}')
|
||||
# Destroy handler
|
||||
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
|
||||
|
||||
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
|
||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||
return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls)
|
||||
else:
|
||||
return NvccCompiler.build(name, code, runtime_cls=runtime_cls)
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
||||
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
|
||||
return compiler_cls.build(name, code, runtime_cls=runtime_cls)
|
||||
|
||||
@@ -37,7 +37,7 @@ def extract_ffma(sass):
|
||||
collected.append((f'{arch_name}::{func_name}', current))
|
||||
current = []
|
||||
|
||||
if os.getenv('DG_PRINT_REG_REUSE', None):
|
||||
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
|
||||
print(f'Found {len(collected)} FFMA segments')
|
||||
return collected
|
||||
|
||||
@@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines):
|
||||
dst_reg_set.add(dst_reg)
|
||||
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
||||
last_reused, last_dst_reg = reused, dst_reg
|
||||
if os.getenv('DG_PRINT_REG_REUSE', None):
|
||||
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
|
||||
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
|
||||
|
||||
# Find the offset
|
||||
@@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines):
|
||||
|
||||
|
||||
def process(path):
|
||||
if os.getenv('DG_PRINT_REG_REUSE', None):
|
||||
if int(os.getenv('DG_PRINT_REG_REUSE', 0)):
|
||||
print(f'Processing {path}')
|
||||
output = run_cuobjdump(path)
|
||||
segments = extract_ffma(output)
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
|
||||
def __init__(self, path: str, kernel_name: str = None,
|
||||
caller: Callable[..., cuda.CUresult] = None,
|
||||
args: List[str] = None) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
@@ -27,7 +27,7 @@ class Runtime:
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
def __call__(self, **kwargs) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
start_time = time.time_ns()
|
||||
@@ -59,9 +59,8 @@ class Runtime:
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return self.caller(
|
||||
self.kernel,
|
||||
@@ -75,25 +74,6 @@ class Runtime:
|
||||
raise Exception(f'Failed to unload library {self.path}: {res}')
|
||||
|
||||
|
||||
class Fp8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', run_gemm, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
@@ -101,14 +81,14 @@ class RuntimeCache:
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
|
||||
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
|
||||
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
if not int(os.getenv('DG_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = runtime_cls(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def generate(**kwargs: Dict[str, Any]) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#define NVRTC_JIT_COMPILATION
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
|
||||
#else
|
||||
|
||||
#include <string>
|
||||
#include <cuda.h>
|
||||
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <deep_gemm/fp8_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>;
|
||||
}}
|
||||
'''
|
||||
|
||||
# Debug print
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Generated code:\n{code}')
|
||||
|
||||
return code
|
||||
@@ -1,164 +0,0 @@
|
||||
import ctypes
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
|
||||
class Layout(Enum):
|
||||
RowMajor = 0
|
||||
ColMajor = 1
|
||||
|
||||
|
||||
class GemmType(Enum):
|
||||
Normal = 0
|
||||
GroupedContiguous = 1
|
||||
GroupedMasked = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return {
|
||||
0: 'Normal',
|
||||
1: 'GroupedContiguous',
|
||||
2: 'GroupedMasked',
|
||||
}[self.value]
|
||||
|
||||
|
||||
typename_map: Dict[Any, str] = {
|
||||
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
|
||||
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
|
||||
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
|
||||
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
|
||||
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
|
||||
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
||||
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
}
|
||||
|
||||
swizzle_map = {
|
||||
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
|
||||
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
|
||||
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
}
|
||||
|
||||
def get_num_math_warpgroups(block_m: int) -> int:
|
||||
return 1 if block_m == 64 else 2
|
||||
|
||||
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
|
||||
assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
|
||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||
|
||||
|
||||
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
|
||||
tensor_dtype = typename_map[global_address.dtype]
|
||||
res, tensor_map = cuda.cuTensorMapEncodeTiled(
|
||||
tensor_dtype,
|
||||
2, # tensor rank
|
||||
global_address.data_ptr(),
|
||||
gmem_dim,
|
||||
(stride_in_bytes,), # global strides
|
||||
smem_dim,
|
||||
(cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides
|
||||
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
swizzle_type,
|
||||
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
|
||||
)
|
||||
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to encode tensor map: {res}')
|
||||
|
||||
return tensor_map
|
||||
|
||||
|
||||
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
|
||||
if layout == Layout.RowMajor:
|
||||
gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
|
||||
smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
|
||||
else:
|
||||
gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
|
||||
smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
|
||||
|
||||
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)
|
||||
|
||||
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)
|
||||
|
||||
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
# Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
|
||||
# bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])
|
||||
|
||||
|
||||
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
# Make TMA aligned to 16 bytes
|
||||
kAlignment = 16 / global_address.element_size()
|
||||
shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
||||
|
||||
attr_val = cuda.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cuda.CUlaunchAttribute()
|
||||
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
||||
attr.value = attr_val
|
||||
|
||||
config = cuda.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = num_sms
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = smem_size
|
||||
config.hStream = stream
|
||||
|
||||
kernelValues = (
|
||||
gmem_d.data_ptr(),
|
||||
scales_b.data_ptr(),
|
||||
grouped_layout.data_ptr(),
|
||||
shape_m,
|
||||
tensor_map_a,
|
||||
tensor_map_b,
|
||||
tensor_map_scales_a,
|
||||
tensor_map_d,
|
||||
)
|
||||
kernelTypes = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)
|
||||
@@ -3,8 +3,8 @@ import torch
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
|
||||
from .runtime import FP8GemmRuntime, generate
|
||||
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
@@ -122,7 +122,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
assert best_smem_config is not None
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicast and whether broadcast on A
|
||||
# Decide the number of TMA multicasts and whether broadcast on A
|
||||
best_tma_multicast_config = (1, True)
|
||||
|
||||
# Try to multicast on the larger block side first
|
||||
@@ -155,13 +155,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
@@ -183,7 +183,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
@@ -201,11 +201,11 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k)
|
||||
GemmType.Normal, lhs, m, k, block_m, block_k, 1)
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n)
|
||||
GemmType.Normal, rhs, k, n, block_k, block_n, 1)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
|
||||
GemmType.Normal, out, m, n, block_m, block_n, 1, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
|
||||
|
||||
@@ -237,7 +237,9 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
kwargs=kwargs
|
||||
kwargs=kwargs,
|
||||
generator=generate,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
from .gemm import get_best_configs
|
||||
from .runtime import FP8GemmRuntime, generate
|
||||
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
@@ -15,7 +15,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
@@ -23,11 +23,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`,
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
|
||||
`m_indices[i]` records the group which the i-th row of the LHS belong to,
|
||||
`m_indices[i]` records the group which the i-th row of the LHS belongs to,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
"""
|
||||
@@ -70,7 +70,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
|
||||
GemmType.GroupedContiguous, out, m, n, block_m, block_n, num_groups, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
@@ -103,6 +103,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
'GEMM_TYPE': GemmType.GroupedContiguous},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
generator=generate,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
@@ -116,7 +118,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
@@ -125,7 +127,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
in the i-th group.
|
||||
@@ -157,7 +159,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
@@ -175,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
tensor_map_b = make_2d_tma_b_desc(
|
||||
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(
|
||||
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
|
||||
GemmType.GroupedMasked, out, m, n, block_m, block_n, num_groups, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_desc(
|
||||
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
@@ -208,6 +209,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'GEMM_TYPE': GemmType.GroupedMasked},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
generator=generate,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
|
||||
256
deep_gemm/jit_kernels/runtime.py
Normal file
256
deep_gemm/jit_kernels/runtime.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import ctypes
|
||||
import os
|
||||
import enum
|
||||
import torch
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from ..jit.runtime import Runtime
|
||||
|
||||
|
||||
def generate(**kwargs: Dict[str, Any]) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <deep_gemm/fp8_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>;
|
||||
}}
|
||||
'''
|
||||
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated code:\n{code}')
|
||||
return code
|
||||
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', launch, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
|
||||
class Layout(enum.Enum):
|
||||
RowMajor = 0
|
||||
ColMajor = 1
|
||||
|
||||
|
||||
class GemmType(enum.Enum):
|
||||
Normal = 0
|
||||
GroupedContiguous = 1
|
||||
GroupedMasked = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return {
|
||||
0: 'Normal',
|
||||
1: 'GroupedContiguous',
|
||||
2: 'GroupedMasked',
|
||||
}[self.value]
|
||||
|
||||
|
||||
tmap_type_map: Dict[Any, str] = {
|
||||
torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
|
||||
torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
|
||||
torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
|
||||
torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
|
||||
torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
|
||||
torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
||||
torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
}
|
||||
|
||||
swizzle_type_map = {
|
||||
0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
|
||||
64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
|
||||
128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
}
|
||||
|
||||
|
||||
def get_num_math_warpgroups(block_m: int) -> int:
|
||||
return 1 if block_m == 64 else 2
|
||||
|
||||
|
||||
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
|
||||
assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
|
||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||
|
||||
|
||||
def make_2d_tma_copy_desc(global_address: torch.Tensor,
|
||||
gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t],
|
||||
stride_in_bytes: cbd.cuuint64_t,
|
||||
smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
|
||||
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
|
||||
tensor_dtype = tmap_type_map[global_address.dtype]
|
||||
res, tensor_map = cbd.cuTensorMapEncodeTiled(
|
||||
tensor_dtype,
|
||||
2,
|
||||
global_address.data_ptr(),
|
||||
gmem_dim,
|
||||
(stride_in_bytes, ),
|
||||
smem_dim,
|
||||
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
|
||||
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
swizzle_type,
|
||||
cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
|
||||
)
|
||||
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to encode tensor map: {res}')
|
||||
return tensor_map
|
||||
|
||||
|
||||
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
|
||||
gmem_rows: int, gmem_cols: int,
|
||||
smem_rows: int, smem_cols: int,
|
||||
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
|
||||
if layout == Layout.RowMajor:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
|
||||
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
|
||||
else:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
|
||||
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
|
||||
|
||||
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m: int, shape_k: int,
|
||||
block_m: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k,
|
||||
block_m, block_k)
|
||||
|
||||
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_k: int, shape_n: int,
|
||||
block_k: int, block_n: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1),
|
||||
block_k, block_n)
|
||||
|
||||
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m: int, shape_n: int,
|
||||
block_m: int, block_n: int,
|
||||
num_groups: int, swizzle_mode: int) -> cbd.CUtensorMap:
|
||||
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
|
||||
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
|
||||
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
|
||||
swizzle_type_map[swizzle_mode])
|
||||
|
||||
|
||||
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
|
||||
# Make TMA aligned to 16 bytes
|
||||
tma_alignment = 16 / global_address.element_size()
|
||||
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1),
|
||||
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
|
||||
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
|
||||
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
|
||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
||||
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
||||
|
||||
attr_val = cbd.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cbd.CUlaunchAttribute()
|
||||
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
||||
attr.value = attr_val
|
||||
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = num_sms
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = smem_size
|
||||
config.hStream = stream
|
||||
|
||||
arg_values = (
|
||||
gmem_d.data_ptr(),
|
||||
scales_b.data_ptr(),
|
||||
grouped_layout.data_ptr(),
|
||||
shape_m,
|
||||
tensor_map_a,
|
||||
tensor_map_b,
|
||||
tensor_map_scales_a,
|
||||
tensor_map_d,
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
@@ -1,29 +1,28 @@
|
||||
import copy
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import Any, Callable, Dict, Type, Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from ..jit import build, generate, Runtime
|
||||
from ..jit import build, Runtime
|
||||
|
||||
|
||||
class JITTuner:
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
|
||||
# NOTES: we always assume the space, template and GPU devices will not change
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
signature = (name, f'{keys}')
|
||||
if signature in self.tuned:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Using cached JIT kernel {name} with keys {keys}')
|
||||
return self.tuned[signature]
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
||||
|
||||
assert signature not in self.tuned
|
||||
@@ -35,19 +34,19 @@ class JITTuner:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(**kwargs, **full_keys)
|
||||
kernels.append((build(name, code), full_keys))
|
||||
code = generator(**kwargs, **full_keys)
|
||||
kernels.append((build(name, code, runtime_cls), full_keys))
|
||||
|
||||
# TODO: fix tuning with space > 1
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(**tuned_keys, **kwargs)
|
||||
if return_code != cuda.CUresult.CUDA_SUCCESS:
|
||||
# Pass illegal kernels, e.g. insufficient shared memory capacity
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
if return_code != cbd.CUresult.CUDA_SUCCESS:
|
||||
# Pass illegal kernels, e.g., insufficient shared memory capacity
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
continue
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
@@ -59,7 +58,7 @@ class JITTuner:
|
||||
(8192, 8192), dtype=torch.float, device='cuda')
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
|
||||
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
@@ -69,13 +68,12 @@ class JITTuner:
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
|
||||
print(
|
||||
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = (best_runtime, best_keys)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
# Essential debugging staffs
|
||||
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
|
||||
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
|
||||
|
||||
def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def launch_vector_add(kernel: cuda.CUkernel,
|
||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
||||
stream: cuda.CUstream) -> cuda.CUresult:
|
||||
assert a.shape == b.shape == c.shape
|
||||
assert a.device == b.device == c.device
|
||||
assert a.dim() == 1
|
||||
@@ -24,28 +29,25 @@ def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: t
|
||||
config.blockDimZ = 1
|
||||
config.hStream = stream
|
||||
|
||||
kernelValues = (
|
||||
arg_values = (
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
n,
|
||||
)
|
||||
kernelTypes = (
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
)
|
||||
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
|
||||
|
||||
|
||||
def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
|
||||
def generate_vector_add(**kwargs) -> str:
|
||||
return f"""
|
||||
#ifdef __CUDACC_RTC__
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#define NVRTC_JIT_COMPILATION
|
||||
#endif
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
@@ -63,14 +65,14 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
|
||||
}}
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&vector_add<{kwargs['T']}>;
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', run_vector_add, [
|
||||
super().__init__(path, 'vector_add', launch_vector_add, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
@@ -79,38 +81,25 @@ class VectorAddRuntime(jit.Runtime):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# NVCC
|
||||
print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
|
||||
print('Generated code:')
|
||||
code = generate_vector_add(T='float')
|
||||
print(code)
|
||||
print('Building ...')
|
||||
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
|
||||
print()
|
||||
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
ref_output = a + b
|
||||
torch.testing.assert_close(c, ref_output)
|
||||
for compiler_name in ('NVCC', 'NVRTC'):
|
||||
# Get compiler
|
||||
compiler_cls = getattr(jit, f'{compiler_name}Compiler')
|
||||
print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
|
||||
|
||||
print('JIT test for NVCC passed\n')
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = compiler_cls.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
# NVRTC
|
||||
print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
|
||||
print('Generated code:')
|
||||
code = generate_vector_add(T='__nv_bfloat16')
|
||||
print(code)
|
||||
print('Building ...')
|
||||
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
ref_output = a + b
|
||||
torch.testing.assert_close(c, ref_output)
|
||||
|
||||
print('JIT test for NVRTC passed')
|
||||
# Run and check
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
torch.testing.assert_close(c, a + b)
|
||||
print(f'JIT test for {compiler_name} passed\n')
|
||||
|
||||
Reference in New Issue
Block a user