Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-05 21:44:21 +00:00
Merge 8aff6309d4 into d374456787
Commit 5acb5a9ec5
@@ -18,7 +18,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [ ] Skip useless computation on M
- [ ] NVRTC as a faster compiler
- [x] NVRTC as a faster compiler
- [ ] Sanitizer for testing
- [ ] Weight gradient kernels for dense models
- [ ] Weight gradient kernels for MoE models
@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,

    // Prefetch TMA descriptors at very beginning
    if (threadIdx.x == kNumMathThreads) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
        cute::prefetch_tma_descriptor(&tensor_map_scales_a);
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));

        // `tensor_map_d` is only used in swizzling mode
        // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
        if constexpr (kSwizzleDMode > 0)
            cute::prefetch_tma_descriptor(&tensor_map_d);
            cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
    }
    __syncwarp();
@@ -447,119 +447,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#endif
}

template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t BLOCK_N_PADDING,
          uint32_t kSwizzleDMode,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
          GemmType kGemmType>
class Gemm {
private:
    using Barrier = cuda::barrier<cuda::thread_scope_block>;

public:
    Gemm() = default;

    static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                    uint32_t shape_m,
                    const CUtensorMap& tma_a_desc,
                    const CUtensorMap& tma_b_desc,
                    const CUtensorMap& tma_scales_a_desc,
                    const CUtensorMap& tma_d_desc,
                    cudaStream_t stream,
                    int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
        constexpr uint32_t kNumTMAThreads = 128;
        constexpr uint32_t kNumMathThreadsPerGroup = 128;
        auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
                                      BLOCK_M, BLOCK_N, BLOCK_K,
                                      BLOCK_N_PADDING,
                                      kSwizzleDMode,
                                      kNumGroups, kNumStages,
                                      kNumTMAThreads, kNumMathThreadsPerGroup,
                                      kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
        DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

        // Cluster launch
        cudaLaunchConfig_t config;
        config.gridDim = num_sms;
        config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
        config.dynamicSmemBytes = smem_size;
        config.stream = stream;

        // Clusters for TMA multicast
        // NOTES: `>= 4` cluster size will cause performance degradation
        cudaLaunchAttribute attr;
        attr.id = cudaLaunchAttributeClusterDimension;
        attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
        config.attrs = &attr;
        config.numAttrs = 1;

        // Launch
        auto status = cudaLaunchKernelEx(&config, kernel,
                                         gmem_d, scales_b, grouped_layout,
                                         shape_m,
                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
        DG_HOST_ASSERT(status == cudaSuccess);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_b_desc(T* global_address) {
        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
        auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
        if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
        if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
        if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;

        // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
        // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
                                BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
                                swizzle_mode);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
        // Make TMA aligned to 16 bytes
        constexpr uint32_t kAlignment = 16 / sizeof(T);
        shape_m = ceil_div(shape_m, kAlignment) * kAlignment;

        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_desc(
            T* global_address, Layout layout,
            uint32_t gmem_rows, uint32_t gmem_cols,
            uint32_t smem_rows, uint32_t smem_cols,
            CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
        if (layout == Layout::RowMajor) {
            uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
            uint32_t smem_dim[2] = {smem_cols, smem_rows};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
        } else {
            uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
            uint32_t smem_dim[2] = {smem_rows, smem_cols};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
        }
    }
};

}; // namespace deep_gemm

#pragma clang diagnostic pop
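The two descriptor helpers removed above carry the only real arithmetic in this block: the D-tensor descriptor shrinks its inner box to `kSwizzleDMode / sizeof(T)` elements when swizzling is enabled, and the scales descriptor rounds `shape_m` up so the TMA base address stays 16-byte aligned. A minimal Python sketch of that arithmetic, with purely illustrative sizes (BF16 output, FP32 scales) that are not taken from this commit:

# Hedged sketch of the descriptor arithmetic described in the comments above; values are illustrative.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

block_n, elem_size, swizzle_d_mode = 128, 2, 128   # BF16 output tile with a 128-byte swizzle
# Swizzling limits the inner TMA box to `kSwizzleDMode` bytes, so the store splits into pieces
inner_box_dim = block_n if swizzle_d_mode == 0 else swizzle_d_mode // elem_size        # 64 elements
num_tma_stores = 1 if swizzle_d_mode == 0 else (block_n * elem_size) // swizzle_d_mode  # 2 stores

shape_m, scale_size = 1001, 4                      # FP32 scales
alignment = 16 // scale_size                       # round M up so the TMA base stays 16-byte aligned
aligned_m = ceil_div(shape_m, alignment) * alignment   # 1004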
@@ -1,6 +1,8 @@
#pragma once

#ifndef NVRTC_JIT_COMPILATION
#include <cuda.h>
#endif

#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
deep_gemm/include/deep_gemm/nvrtc_std.cuh (new file, 101 lines)
@@ -0,0 +1,101 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#ifdef NVRTC_JIT_COMPILATION

using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;

#ifndef CU_TENSOR_MAP_NUM_QWORDS
#define CU_TENSOR_MAP_NUM_QWORDS 16

struct CUtensorMap_st
{
#if defined(__cplusplus) && (__cplusplus >= 201103L)
    alignas(64)
#elif __STDC_VERSION__ >= 201112L
    _Alignas(64)
#endif
    cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
};

using CUtensorMap = CUtensorMap_st;
#endif

namespace std {
template <class T, T v> struct integral_constant {
    static constexpr T value = v;
    using value_type = T;
    using type = integral_constant; // using injected-class-name

    __device__ constexpr operator value_type() const noexcept { return value; }

    __device__ constexpr value_type operator()() const noexcept {
        return value;
    } // since c++14
};

using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>;

template <class T, class U> struct is_same : false_type {};

template <class T> struct is_same<T, T> : true_type {};

template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;

namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
template <size_t... Ints> struct index_sequence {
    using type = index_sequence;
    using value_type = size_t;
    static constexpr size_t size() noexcept { return sizeof...(Ints); }
};

template <class Sequence1, class Sequence2> struct _merge_and_renumber;

template <size_t... I1, size_t... I2>
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
    : index_sequence<I1..., (sizeof...(I1) + I2)...> {};

template <size_t N>
struct make_index_sequence
    : _merge_and_renumber<typename make_index_sequence<N / 2>::type,
                          typename make_index_sequence<N - N / 2>::type> {};

template <> struct make_index_sequence<0> : index_sequence<> {};
template <> struct make_index_sequence<1> : index_sequence<0> {};
} // namespace index_sequence_impl

template <size_t... Ns>
using index_sequence = index_sequence_impl::index_sequence<Ns...>;

template <size_t N>
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
} // namespace std

#endif
@@ -1,85 +1,18 @@
#pragma once

#ifndef NVRTC_JIT_COMPILATION
#include <cassert>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cudaTypedefs.h>
#endif

#include <cuda/barrier>

#include "utils.cuh"

namespace deep_gemm {

template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
    if constexpr (std::is_same<T, uint8_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, uint16_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT16;
    } else if constexpr (std::is_same<T, uint32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT32;
    } else if constexpr (std::is_same<T, uint64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT64;
    } else if constexpr (std::is_same<T, int32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    } else if constexpr (std::is_same<T, int64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT64;
    } else if constexpr (std::is_same<T, __half>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if constexpr (std::is_same<T, float>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if constexpr (std::is_same<T, double>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
    }
}

inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
    // Get pointer to `cuTensorMapEncodeTiled`
    cudaDriverEntryPointQueryResult driver_status;
    void* cuTensorMapEncodeTiled_ptr = nullptr;

#if CUDA_VERSION >= 12050
    cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
                                     cudaEnableDefault, &driver_status);
#else
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
                            cudaEnableDefault, &driver_status);
#endif

    if (driver_status != cudaDriverEntryPointSuccess)
        throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}

template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
                                  uint64_t stride_in_bytes, uint32_t smem_dim[2],
                                  CUtensorMapSwizzle swizzle_type,
                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
    CUtensorMap tensor_map = {};
    uint64_t global_stride[1] = {stride_in_bytes};
    uint32_t elem_strides[2] = {1, 1};

    if (encode_func == nullptr)
        encode_func = get_cuTensorMapEncodeTiled();

    auto result = encode_func(
        &tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
        global_address, gmem_dim, global_stride, smem_dim, elem_strides,
        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    DG_HOST_ASSERT(result == CUDA_SUCCESS);
    return tensor_map;
}

__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
         int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
@@ -1,5 +1,6 @@
#pragma once

#ifndef NVRTC_JIT_COMPILATION
#include <exception>

#ifdef __CLION_IDE__
@@ -16,8 +17,12 @@ public:

    const char *what() const noexcept override { return message.c_str(); }
};
#endif

#ifndef DG_HOST_ASSERT
#ifdef NVRTC_JIT_COMPILATION
#define DG_HOST_ASSERT(cond) ((void)0)
#else
#define DG_HOST_ASSERT(cond) \
do { \
    if (not (cond)) { \
@@ -27,6 +32,7 @@ do { \
    } \
} while (0)
#endif
#endif

#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
@@ -1,3 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
from .template import generate
from .runtime import Runtime
@@ -1,15 +1,20 @@
import hashlib
import abc
import functools
import hashlib
import os
import platform
import re
import subprocess
import time
import uuid
from typing import List, Tuple, Type

import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple

from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache

runtime_cache = RuntimeCache()

@@ -22,21 +27,22 @@ def hash_to_hex(s: str) -> str:

@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
    return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')


@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
    # Update include directories
    include_dir = f'{get_jit_include_dir()}/deep_gemm'
    assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
    include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
    assert os.path.exists(
        include_dir), f'Cannot find GEMM include directory {include_dir}'
    md5 = hashlib.md5()
    for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
        with open(f'{include_dir}/{filename}', 'rb') as f:
        with open(os.path.join(include_dir, filename), 'rb') as f:
            md5.update(f.read())

    # Update `interleave_ffma.py`
    with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
        md5.update(f.read())
    return md5.hexdigest()[0:12]

@@ -46,14 +52,19 @@ def get_nvcc_compiler() -> Tuple[str, str]:
    paths = []
    if os.getenv('DG_NVCC_COMPILER'):
        paths.append(os.getenv('DG_NVCC_COMPILER'))
    paths.append(f'{CUDA_HOME}/bin/nvcc')

    nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc'
    paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin))

    # Try to find the first available NVCC compiler
    least_version_required = '12.3'
    version_pattern = re.compile(r'release (\d+\.\d+)')
    for path in paths:
        if os.path.exists(path):
            match = version_pattern.search(os.popen(f'{path} --version').read())
            command = [path, '--version']
            result = subprocess.run(command, stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE, text=True)
            match = version_pattern.search(result.stdout)
            version = match.group(1)
            assert match, f'Cannot get the version of NVCC compiler {path}'
            assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@@ -67,17 +78,17 @@ def get_default_user_dir():
        path = os.getenv('DG_CACHE_DIR')
        os.makedirs(path, exist_ok=True)
        return path
    return os.path.expanduser('~') + '/.deep_gemm'
    return os.path.join(os.path.expanduser('~'), '.deep_gemm')


@functools.lru_cache(maxsize=None)
def get_tmp_dir():
    return f'{get_default_user_dir()}/tmp'
    return os.path.join(get_default_user_dir(), 'tmp')


@functools.lru_cache(maxsize=None)
def get_cache_dir():
    return f'{get_default_user_dir()}/cache'
    return os.path.join(get_default_user_dir(), 'cache')


def make_tmp_dir():
@@ -86,67 +97,195 @@ def make_tmp_dir():
    return tmp_dir


def put(path, data, is_binary=False):
def put(path, data):
    is_binary = isinstance(data, bytes)

    # Write and do POSIX atomic replace
    tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
    tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
    with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
        f.write(data)
    os.replace(tmp_file_path, path)


def build(name: str, arg_defs: tuple, code: str) -> Runtime:
    # Compiler flags
    cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
    nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                  '-gencode=arch=compute_90a,code=sm_90a',
                  '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                  # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                  '--diag-suppress=39,174,177,940']
    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
    flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
    include_dirs = [get_jit_include_dir()]
class Compiler(abc.ABC):
    @staticmethod
    @abc.abstractmethod
    def __version__() -> Tuple[int, int]:
        pass

    # Build signature
    enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
    signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
    name = f'kernel.{name}.{hash_to_hex(signature)}'
    path = f'{get_cache_dir()}/{name}'
    @classmethod
    @abc.abstractmethod
    def compile(cls, name: str, code: str, target_path: str) -> str:
        pass

    # Check runtime cache or file system hit
    global runtime_cache
    if runtime_cache[path] is not None:
    @staticmethod
    def flags() -> List[str]:
        cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
        return [f'-std=c++{cpp_standard}',
                '--ptxas-options=--register-usage-level=10' +
                (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                '--diag-suppress=39,161,174,177,940']

    @staticmethod
    def include_dirs() -> List[str]:
        return [get_jit_include_dir()]

    @classmethod
    def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
        # Compiler flags
        flags = cls.flags()

        # Build signature
        enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
            os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
        signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
        name = f'kernel.{name}.{hash_to_hex(signature)}'
        path = os.path.join(get_cache_dir(), name)

        # Check runtime cache or file system hit
        global runtime_cache
        cached_runtime = runtime_cache.get(path, runtime_cls)
        if cached_runtime is not None:
            if os.getenv('DG_JIT_DEBUG', None):
                print(f'Using cached JIT runtime {name} during build')
            return cached_runtime

        # Compile into a temporary CU file
        os.makedirs(path, exist_ok=True)
        cubin_path = os.path.join(path, 'kernel.cubin')
        tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

        start_time = time.time()
        cls.compile(name, code, tmp_cubin_path)
        end_time = time.time()
        elapsed_time = end_time - start_time
        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Using cached JIT runtime {name} during build')
            return runtime_cache[path]
            print(
                f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')

    # Write the code
    os.makedirs(path, exist_ok=True)
    args_path = f'{path}/kernel.args'
    src_path = f'{path}/kernel.cu'
    put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
    put(src_path, code)
        # Interleave FFMA reuse
        if enable_sass_opt:
            interleave_ffma.process(tmp_cubin_path)

        # Atomic replace files
        os.replace(tmp_cubin_path, cubin_path)

    # Compile into a temporary SO file
    so_path = f'{path}/kernel.so'
    tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
        # Put cache and return
        runtime = runtime_cls(path)
        runtime_cache[path] = runtime
        return runtime

    # Compile
    command = [get_nvcc_compiler()[0],
               src_path, '-o', tmp_so_path,
               *flags,
               *[f'-I{d}' for d in include_dirs]]
    if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
        print(f'Compiling JIT runtime {name} with command {command}')
    return_code = subprocess.check_call(command)
    assert return_code == 0, f'Failed to compile {src_path}'

    # Interleave FFMA reuse
    if enable_sass_opt:
        interleave_ffma.process(tmp_so_path)
class NvccCompiler(Compiler):
    @staticmethod
    def __version__() -> Tuple[int, int]:
        _, version = get_nvcc_compiler()
        major, minor = map(int, version.split('.'))
        return (major, minor)

    # Atomic replace SO file
    os.replace(tmp_so_path, so_path)
    @classmethod
    def flags(cls) -> List[str]:
        if platform.system() != 'Windows':
            cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
        else:
            cxx_flags = ['/O2', '/std:c++20']

        return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                '-gencode=arch=compute_90a,code=sm_90a',
                '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                f'--compiler-options={",".join(cxx_flags)}']

    # Put cache and return
    runtime_cache[path] = Runtime(path)
    return runtime_cache[path]
    @classmethod
    def compile(cls, name: str, code: str, target_path: str):
        # Write the code
        path = os.path.join(get_cache_dir(), name)
        src_path = os.path.join(path, 'kernel.cu')
        put(src_path, code)
        command = [get_nvcc_compiler()[0],
                   src_path, '-o', target_path,
                   *cls.flags()]
        if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
            print(f'Compiling JIT runtime {name} with command {command}')

        result = subprocess.run(command, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, text=True)
        if os.getenv('DG_JIT_DEBUG', None):
            print(result.stdout)
            print(result.stderr)

        assert result.returncode == 0, f'Failed to compile {src_path}'


class NvrtcCompiler(Compiler):
    @staticmethod
    def __version__() -> Tuple[int, int]:
        res, major, minor = nvrtc.nvrtcVersion()
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            # Failed to get actual NVRTC version, use bindings version instead
            major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
        return (major, minor)

    @staticmethod
    def include_dirs() -> List[str]:
        if CUDA_HOME is None:
            raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
        return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]

    @classmethod
    def flags(cls) -> List[str]:
        base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                      '--gpu-architecture=sm_90a', '-default-device']
        if cls.__version__() >= (12, 8):
            base_flags += ['--pch']
            if os.getenv('DG_JIT_DEBUG', None):
                base_flags += ['--pch-verbose=true']
        return base_flags

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> str:
        code_bytes = bytes(code, 'utf-8')
        res, program = nvrtc.nvrtcCreateProgram(
            code_bytes, bytes(name, 'utf-8'), 0, [], [])
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f'Failed to create program: {res}')

        options = [bytes(flag, 'utf-8') for flag in cls.flags()]
        compile_res = nvrtc.nvrtcCompileProgram(
            program, len(options), options)[0]

        if os.getenv('DG_JIT_DEBUG', None):
            res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
            if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f'Failed to get program log size: {res}')
            log_bytes = bytes(log_size)
            res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
            if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f'Failed to get program log: {res}')
            log_str = log_bytes.decode('utf-8')
            print(log_str)

        if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f'Failed to compile program: {compile_res}')

        res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f'Failed to get CUBIN size: {res}')

        cubin_bytes = bytes(cubin_size)
        res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f'Failed to get CUBIN: {res}')

        put(target_path, cubin_bytes)

        res = nvrtc.nvrtcDestroyProgram(program)[0]
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f'Failed to destroy program: {res}')


def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
    if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
        return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls)
    else:
        return NvccCompiler.build(name, code, runtime_cls=runtime_cls)
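The module-level `build()` above now just dispatches between the two backends, so switching to NVRTC is purely an environment-variable decision. A hedged usage sketch (the kernel name and the `code` string are placeholders; real callers get both from the JIT tuner):

import os
from deep_gemm.jit import build   # re-exported from deep_gemm/jit/__init__.py

os.environ['DG_JIT_USE_NVRTC'] = '1'    # opt in to the NVRTC backend (default '0' keeps NVCC)
os.environ['DG_JIT_DEBUG'] = '1'        # print compile commands and NVRTC logs
# os.environ['DG_NVCC_COMPILER'] = '/path/to/nvcc'   # optional explicit nvcc for the NVCC backend

code = '...'                            # CUDA source normally produced by deep_gemm.jit.template.generate(...)
runtime = build('gemm_fp8_fp8_bf16_nt', code)   # returns an Fp8GemmRuntime backed by kernel.cubin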
@@ -1,17 +1,20 @@
import ctypes
import os
import torch
from typing import Optional
import time
from typing import Any, Callable, Dict, List, Optional, Type

from .template import map_ctype
import cuda.bindings.driver as cuda

from .utils import run_gemm


class Runtime:
    def __init__(self, path: str) -> None:
    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
        self.path = path
        self.lib = None
        self.args = None

        self.kernel = None
        self.kernel_name = kernel_name
        self.caller = caller
        self.args = args
        assert self.is_path_valid(self.path)

    @staticmethod
@@ -21,46 +24,91 @@ class Runtime:
            return False

        # Contains all necessary files
        files = ['kernel.cu', 'kernel.args', 'kernel.so']
        files = ['kernel.cubin']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    def __call__(self, *args) -> int:
        # Load SO file
        if self.lib is None or self.args is None:
            self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
            with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
                self.args = eval(f.read())
    def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
        # Load CUBIN
        if self.kernel is None:
            start_time = time.time_ns()
            res, lib = cuda.cuLibraryLoadFromFile(
                bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to load library: {res}')

        # Check args and launch
        assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
        cargs = []
        for arg, (name, dtype) in zip(args, self.args):
            if isinstance(arg, torch.Tensor):
                assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
            res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to get kernel count: {res}')

            res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to enumerate kernels: {res}')

            for kernel in kernels:
                res, kernel_name = cuda.cuKernelGetName(kernel)
                if res != cuda.CUresult.CUDA_SUCCESS:
                    raise Exception(f'Failed to get kernel name: {res}')
                if bytes(self.kernel_name, encoding='utf-8') in kernel_name:
                    self.kernel = kernel
                    break

            if self.kernel is not None:
                self.lib = lib
            else:
                assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
            cargs.append(map_ctype(arg))
                raise Exception('Failed to find required kernel')

        return_code = ctypes.c_int(0)
        self.lib.launch(*cargs, ctypes.byref(return_code))
        return return_code.value
            end_time = time.time_ns()
            elapsed_time = (end_time - start_time) / 1000
            if os.getenv('DG_JIT_DEBUG', None):
                print(
                    f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

        return self.caller(
            self.kernel,
            *[kwargs[arg] for arg in self.args]
        )

    def __del__(self) -> None:
        if self.lib is not None:
            res = cuda.cuLibraryUnload(self.lib)[0]
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to unload library {self.path}: {res}')


class Fp8GemmRuntime(Runtime):
    def __init__(self, path: str) -> None:
        super().__init__(path, 'fp8_gemm', run_gemm, [
            'NUM_TMA_MULTICAST',
            'M',
            'BLOCK_M',
            'GMEM_D',
            'SCALES_B',
            'GROUPED_LAYOUT',
            'NUM_SMS',
            'SMEM_SIZE',
            'TENSOR_MAP_A',
            'TENSOR_MAP_B',
            'TENSOR_MAP_SCALES_A',
            'TENSOR_MAP_D',
            'STREAM',
        ])


class RuntimeCache:
    def __init__(self) -> None:
        self.cache = {}

    def __getitem__(self, path: str) -> Optional[Runtime]:
    def __setitem__(self, path, runtime) -> None:
        self.cache[path] = runtime

    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
        # In Python runtime
        if path in self.cache:
            return self.cache[path]

        # Already compiled
        if os.path.exists(path) and Runtime.is_path_valid(path):
            runtime = Runtime(path)
            runtime = runtime_cls(path)
            self.cache[path] = runtime
            return runtime
        return None

    def __setitem__(self, path, runtime) -> None:
        self.cache[path] = runtime
        return None
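With the ctypes/SO path gone, a runtime is just a cached cubin plus an ordered argument list, invoked with keyword arguments. A hedged sketch of driving `Fp8GemmRuntime` directly; the cache path and every tensor/descriptor value are placeholders prepared elsewhere (for example by `deep_gemm.jit.utils.make_2d_tma_*_desc`), so this is a call-shape illustration rather than runnable code:

import torch
from deep_gemm.jit.runtime import Fp8GemmRuntime

# Hypothetical cache directory produced by a previous build(); it must contain kernel.cubin
runtime = Fp8GemmRuntime('/home/user/.deep_gemm/cache/kernel.gemm_fp8_fp8_bf16_nt.0123456789ab')

out, rhs_scales = ..., ...                     # BF16 output and FP32 RHS-scale tensors from the caller
result = runtime(                              # keyword names must match the argument list declared above
    NUM_TMA_MULTICAST=1, M=4096, BLOCK_M=128,
    GMEM_D=out, SCALES_B=rhs_scales,
    GROUPED_LAYOUT=torch.empty(0, dtype=torch.int32, device='cuda'),
    NUM_SMS=..., SMEM_SIZE=...,
    TENSOR_MAP_A=..., TENSOR_MAP_B=..., TENSOR_MAP_SCALES_A=..., TENSOR_MAP_D=...,
    STREAM=torch.cuda.current_stream().cuda_stream,
)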
@@ -1,111 +1,48 @@
import copy
import ctypes
import os
import torch
from typing import Any, Dict, Iterable, Tuple
from typing import Any, Dict


# Name map for Python `eval`
typename_map: Dict[Any, str] = {
    **{t: t.__name__ for t in (bool, int, float)},
    torch.int: 'torch.int',
    torch.float: 'torch.float',
    torch.bfloat16: 'torch.bfloat16',
    torch.float8_e4m3fn: 'torch.float8_e4m3fn',
    torch.cuda.Stream: 'torch.cuda.Stream',
}
def generate(**kwargs: Dict[str, Any]) -> str:
    code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif

# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
    **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
    **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
#include <deep_gemm/nvrtc_std.cuh>

#else

# Type map for both Python API and source code usages
genc_map = {
    bool: ('bool', 'bool'),
    int: ('int', 'int'),
    float: ('float', 'float'),
    torch.int: ('void*', 'int*'),
    torch.float: ('void*', 'float*'),
    torch.bfloat16: ('void*', '__nv_bfloat16*'),
    torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
    torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
#include <string>
#include <cuda.h>

#endif

def map_ctype(value: Any) -> Any:
    if hasattr(value, 'data_ptr'):
        if value.dtype == torch.int:
            return ctypes.c_void_p(value.data_ptr())
        elif value.dtype == torch.float:
            return ctypes.c_void_p(value.data_ptr())
        elif value.dtype == torch.bfloat16:
            return ctypes.c_void_p(value.data_ptr())
        elif value.dtype == torch.float16:
            return ctypes.c_void_p(value.data_ptr())
        elif value.dtype == torch.float8_e4m3fn:
            return ctypes.c_void_p(value.data_ptr())
        else:
            return ctypes.c_void_p(value.data_ptr())
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>

    if hasattr(value, 'cuda_stream'):
        return ctypes.c_void_p(value.cuda_stream)
using namespace deep_gemm;

    if isinstance(value, bool):
        return ctypes.c_bool(value)
    elif isinstance(value, int):
        return ctypes.c_int(value)
    elif isinstance(value, float):
        return ctypes.c_float(value)

    return ctype_map[type(value)](value)


def cpp_format(template: str, keys: Dict[str, Any]) -> str:
    # We don't use `str.format` because it's not safe for C++ {} braces
    new_template = copy.deepcopy(template)
    for key, value in keys.items():
        value_str = str(value)
        if isinstance(value, bool):
            value_str = value_str.lower()
        new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
    return new_template


def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
    # Common prefix
    code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'

    # Includes
    preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
    preload_package_includes = ['"cutlass/cutlass.h"']

    assert isinstance(includes, list) or isinstance(includes, tuple)
    sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
    package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
    code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
    code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'

    # Function signature
    raw = '__raw_'
    get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
    code += f'extern "C" void launch('
    code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
    code += ') {\n'

    # Cast raw types
    code += ' // Cast raw types (if needed)\n'
    for arg_name, arg_type in arg_defs:
        if genc_map[arg_type][0] != genc_map[arg_type][1]:
            code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'

    # Function body
    code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])

    # End the function
    code += '}\n\n'
__global__ void dummy_kernel() {{
    void *ptr = (void *)&fp8_gemm_kernel<
        {kwargs['N']},
        {kwargs['K']},
        {kwargs['BLOCK_M']},
        {kwargs['BLOCK_N']},
        {kwargs['BLOCK_K']},
        {kwargs['BLOCK_N_PADDING']},
        {kwargs['SWIZZLE_D_MODE']},
        {kwargs['NUM_GROUPS']},
        {kwargs['NUM_STAGES']},
        {kwargs['NUM_TMA_THREADS']},
        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
        {kwargs['NUM_TMA_MULTICAST']},
        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
        GemmType::{kwargs['GEMM_TYPE']}
    >;
}}
'''

    # Debug print
    if os.getenv('DG_JIT_DEBUG', None):
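The rewritten `generate()` above no longer emits an extern "C" launch wrapper; it emits a tiny translation unit whose `dummy_kernel` takes the address of a fully specialized `fp8_gemm_kernel`, which is enough to force NVCC/NVRTC to place that instantiation in the cubin that `Runtime` later looks up by name. A hedged sketch of calling it directly; the parameter values are illustrative and in practice come from `get_best_configs()`:

from deep_gemm.jit.template import generate

code = generate(
    N=4096, K=7168, BLOCK_M=128, BLOCK_N=128, BLOCK_K=128,
    BLOCK_N_PADDING=0, SWIZZLE_D_MODE=128,
    NUM_GROUPS=1, NUM_STAGES=4,
    NUM_TMA_THREADS=128, NUM_MATH_THREADS_PER_GROUP=128,
    NUM_TMA_MULTICAST=1, IS_TMA_MULTICAST_ON_A=False,
    GEMM_TYPE='Normal',   # GemmType.Normal formats to the same string
)
print(code)   # a short .cu file: includes, then dummy_kernel() instantiating fp8_gemm_kernel<...>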
deep_gemm/jit/utils.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import ctypes
from enum import Enum
from typing import Any, Dict, Tuple

import cuda.bindings.driver as cuda
import torch


class Layout(Enum):
    RowMajor = 0
    ColMajor = 1


class GemmType(Enum):
    Normal = 0
    GroupedContiguous = 1
    GroupedMasked = 2

    def __str__(self) -> str:
        return {
            0: 'Normal',
            1: 'GroupedContiguous',
            2: 'GroupedMasked',
        }[self.value]


typename_map: Dict[Any, str] = {
    torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
    torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
    torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
    torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
    torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
    torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
    torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}

swizzle_map = {
    128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
    64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
    32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
    0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
}


def get_num_math_warpgroups(block_m: int) -> int:
    return 1 if block_m == 64 else 2


def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
    assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
    return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads


def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
    tensor_dtype = typename_map[global_address.dtype]
    res, tensor_map = cuda.cuTensorMapEncodeTiled(
        tensor_dtype,
        2,  # tensor rank
        global_address.data_ptr(),
        gmem_dim,
        (stride_in_bytes,),  # global strides
        smem_dim,
        (cuda.cuuint32_t(1), cuda.cuuint32_t(1)),  # element strides
        cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
        swizzle_type,
        cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    )

    if res != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to encode tensor map: {res}')

    return tensor_map


def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
    if layout == Layout.RowMajor:
        gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
        smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
        return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
    else:
        gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
        smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
        return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)


def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
    return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)


def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
    return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)


def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
    # Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
    # So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
    return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])


def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
    # Make TMA aligned to 16 bytes
    kAlignment = 16 / global_address.element_size()
    shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment

    return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)


def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
    num_tma_threads = 128
    num_math_threads_per_group = 128

    res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
    if res != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to set max dynamic shared memory size: {res}')

    attr_val = cuda.CUlaunchAttributeValue()
    attr_val.clusterDim.x = num_tma_multicast
    attr_val.clusterDim.y = 1
    attr_val.clusterDim.z = 1
    attr = cuda.CUlaunchAttribute()
    attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
    attr.value = attr_val

    config = cuda.CUlaunchConfig()
    config.numAttrs = 1
    config.attrs = [attr]
    config.gridDimX = num_sms
    config.gridDimY = 1
    config.gridDimZ = 1
    config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
    config.blockDimY = 1
    config.blockDimZ = 1
    config.sharedMemBytes = smem_size
    config.hStream = stream

    kernelValues = (
        gmem_d.data_ptr(),
        scales_b.data_ptr(),
        grouped_layout.data_ptr(),
        shape_m,
        tensor_map_a,
        tensor_map_b,
        tensor_map_scales_a,
        tensor_map_d,
    )
    kernelTypes = (
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_uint32,
        None,
        None,
        None,
        None,
    )

    return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)
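The launch-geometry helpers above fix the block size implied by `BLOCK_M`: one 128-thread math warpgroup when `BLOCK_M == 64`, two otherwise, plus the dedicated 128-thread TMA warpgroup. A quick check, assuming the module is importable as `deep_gemm.jit.utils`:

from deep_gemm.jit.utils import get_num_threads_per_sm

assert get_num_threads_per_sm(128, 128, 64) == 256    # 1 math warpgroup + TMA warpgroup
assert get_num_threads_per_sm(128, 128, 128) == 384   # 2 math warpgroups + TMA warpgroup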
@@ -3,40 +3,11 @@ import torch
from functools import lru_cache
from typing import Tuple

from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc

from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout

# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;

// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = 1;
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};

// Make a templated GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;

// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, nullptr,
            m,
            tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
            stream, num_sms, smem_size);
"""


def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
                           require_divisible: bool = False) -> bool:
@@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
    # Try swizzle first, as it does not waste shared memory
    swizzle_mode = get_swizzle_mode(block_n)
    block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
    block_n_padding = get_block_n_padding_for_smem_d(
        block_n) if swizzle_mode == 0 else 0

    smem_d = block_m * (block_n + block_n_padding) * 2
    smem_a_per_stage = block_m * block_k
@@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
    smem_size += num_stages * smem_a_per_stage
    smem_size += num_stages * smem_scales_a_per_stage
    smem_size += num_stages * smem_b_per_stage
    smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
    smem_size += ceil_div(smem_scales_b * (1 if block_k %
                          block_n == 0 else 2), 8) * 8
    smem_size += smem_barrier

    # Swizzle and padding are not compatible
@@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
    block_ms = (get_m_alignment_for_contiguous_layout(), )
    block_ns = tuple(range(16, 129, 8)) + (144, 160, )

    fix_wave_saturate = lambda x: num_sms if x == 0 else x
    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
    def fix_wave_saturate(x): return num_sms if x == 0 else x

    def get_num_waves(bm, bn): return (ceil_div(
        ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)

    def get_last_wave_util(bm, bn): return fix_wave_saturate(
        (ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

    # Decide block sizes by waves
    best_block_m, best_block_n = None, None
@@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
        # NOTES: the block sizes can not be too large, so at least one dim less than 128
        for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
            success = False
            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
            num_waves, best_num_waves = get_num_waves(
                block_m, block_n), get_num_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                success = True
            elif num_waves < best_num_waves:
@@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                    success |= block_n == best_block_n and block_m < best_block_m
                    # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
                    success |= block_m != best_block_m and block_n > best_block_n
            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
            best_block_m, best_block_n = (block_m, block_n) if success else (
                best_block_m, best_block_n)
    assert best_block_m is not None and best_block_n is not None

    # Always pick the longest one
@@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
    # Unrolling both stages and `num_former_iters` will cause large code size
    stage_candidates = (4, 3)
    for num_stages in stage_candidates:
        best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
        best_smem_config = get_smem_config(
            num_stages, k, best_block_m, best_block_n)
        if best_smem_config[0] <= sm90_capacity:
            best_num_stages = num_stages
            break
@@ -159,8 +139,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
    # Recompute the minimal number of SMs required
    # NOTES: less L2 cache usage and less GPU frequency drop
    num_waves = get_num_waves(best_block_m, best_block_n)
    num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
    num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
    num_min_sms = ceil_div(ceil_div(m, best_block_m) *
                           ceil_div(n, best_block_n) * num_groups, num_waves)
    num_min_sms = ceil_div(
        num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
    assert num_min_sms <= num_sms

    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
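The reformatted lambdas-turned-functions above implement the wave bookkeeping that drives block-size selection. A worked example with illustrative numbers (the 132-SM count matches an H800-class part and is an assumption, not part of this diff):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

m, n, num_groups, num_sms = 4096, 4096, 1, 132
block_m, block_n = 128, 128
num_blocks = ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups   # 32 * 32 = 1024 blocks
num_waves = ceil_div(num_blocks, num_sms)                               # ceil(1024 / 132) = 8 waves
last_wave_util = num_blocks % num_sms or num_sms                        # 100 blocks in the last wave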
@ -211,11 +193,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
return

# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128

tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)

kwargs = {
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'NUM_GROUPS': 1,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}

runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -224,14 +237,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs
)

# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
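The hunk above switches the launch interface from a positional `args` tuple to a `kwargs` dict plus the tuned keys returned by the tuner. A minimal mock of that control flow is sketched below; the classes are illustrative stand-ins, not the real DeepGEMM JITTuner or Runtime, and the dictionary contents are an arbitrary subset.

from typing import Any, Dict, Tuple

class MockRuntime:
    def __call__(self, **kwargs: Any) -> int:
        # The real runtime launches the compiled kernel; here we only report success (0)
        print(f'launching with arguments: {sorted(kwargs)}')
        return 0

class MockTuner:
    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
                         kwargs: Dict[str, Any]) -> Tuple[MockRuntime, Dict[str, Any]]:
        # An empty `space` means a single candidate whose tuned keys are empty
        return MockRuntime(), {}

kwargs = {'M': 4096, 'NUM_SMS': 132, 'SMEM_SIZE': 196608}   # illustrative subset of the real dict
runtime, best_keys = MockTuner().compile_and_tune(
    'gemm_fp8_fp8_bf16_nt', keys={'N': 7168, 'K': 2048}, space=(), kwargs=kwargs)
assert runtime(**best_keys, **kwargs) == 0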
@ -1,41 +1,12 @@
import torch
from typing import Tuple

from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc

from .gemm import get_best_configs, get_block_n_padding_for_smem_d
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms

# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;

// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = {NUM_GROUPS};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};

// Make a templated grouped GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;

// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""


def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
return

# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms, is_grouped_contiguous=True)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128

tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)

kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': m_indices,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}

runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedContiguous'},
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)

# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)


def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)

# Extra checks for TMA store
if num_groups > 1 and m > block_m:
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'

args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128

tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)

kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': masked_m,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}

runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedMasked'},
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)

# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
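For the two grouped variants above, the tensors passed as GROUPED_LAYOUT differ: `m_indices` for the contiguous layout and `masked_m` for the masked layout. The sketch below shows how such tensors might be constructed; the shapes, counts, and alignment rules are my assumptions based on reading the diff, not taken from it.

import torch

# Contiguous grouped GEMM: rows of all groups are packed along M, and `m_indices`
# stores, per row, the index of the group that row belongs to (hypothetical sizes)
group_row_counts = [256, 384, 128]
m_indices = torch.cat([torch.full((rows, ), g, dtype=torch.int32)
                       for g, rows in enumerate(group_row_counts)]).to('cuda')

# Masked grouped GEMM: each group owns a fixed-size slice of M, and `masked_m`
# records how many rows of each slice are actually valid (hypothetical counts)
masked_m = torch.tensor([200, 384, 96], dtype=torch.int32, device='cuda')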
@ -3,15 +3,16 @@ import os
import torch
from typing import Any, Dict

from ..jit import build, cpp_format, generate, Runtime
import cuda.bindings.driver as cuda

from ..jit import build, generate, Runtime


class JITTuner:
def __init__(self) -> None:
self.tuned = {}

def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
# NOTES: the function must have no accumulated side effects
@ -26,7 +27,7 @@ class JITTuner:
print(f'Auto-tuning JIT kernel {name} with keys {keys}')

assert signature not in self.tuned
assert args is not None
assert kwargs is not None
space = (dict(), ) if len(space) == 0 else space

kernels = []
@ -34,30 +35,31 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(includes, arg_defs, cpp_format(template, full_keys))

# Illegal build must raise errors
kernels.append((build(name, arg_defs, code), tuned_keys))
code = generate(**kwargs, **full_keys)
kernels.append((build(name, code), full_keys))

best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(*args)
if return_code != 0:
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cuda.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g. insufficient shared memory capacity
if os.getenv('DG_JIT_DEBUG', None):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
print(
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue

# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
torch.empty(int(256e6 // 4), dtype=torch.int,
device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(*args) == 0
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
@ -68,14 +70,16 @@ class JITTuner:
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv('DG_JIT_DEBUG', None):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
print(
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'

# Cache the best runtime and return
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = best_runtime
return best_runtime
print(
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)
return best_runtime, best_keys


jit_tuner = JITTuner()
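The timing loop in `compile_and_tune` above flushes L2 with a large fill, runs one big GEMM to stabilise the GPU, and then times 20 launches with CUDA events. Pulled out as a standalone helper it looks roughly like the sketch below; `launch` is a hypothetical stand-in for `runtime(**tuned_keys, **kwargs)`.

import torch

def benchmark(launch, num_iters: int = 20) -> float:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # Flush L2 and run a large GEMM beforehand to reduce overhead between kernels
    torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
    torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ \
        torch.randn((8192, 8192), dtype=torch.float, device='cuda')
    start_event.record()
    for _ in range(num_iters):
        launch()
    end_event.record()
    end_event.synchronize()
    return start_event.elapsed_time(end_event) / num_iters   # milliseconds per launch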
@ -1,64 +1,116 @@
import ctypes
import os
import torch
from typing import Any
from typing import Any, Dict

import cuda.bindings.driver as cuda

from deep_gemm import jit


class Capture:
def __init__(self) -> None:
self.read_fd = None
self.write_fd = None
self.saved_stdout = None
self.captured = None
def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1

def __enter__(self) -> Any:
self.read_fd, self.write_fd = os.pipe()
self.saved_stdout = os.dup(1)
os.dup2(self.write_fd, 1)
return self
n = a.numel()

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
os.dup2(self.saved_stdout, 1)
os.close(self.write_fd)
with os.fdopen(self.read_fd, 'r') as f:
self.captured = f.read()
config = cuda.CUlaunchConfig()
config.gridDimX = (n + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream

def capture(self) -> str:
return self.captured
kernelValues = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
n,
)
kernelTypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)

return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]


def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
return f"""
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#endif

#include <cuda_fp8.h>
#include <cuda_bf16.h>

template<typename T>
__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < N) {{
c[i] = a[i] + b[i];
}}
}}

__global__ void dummy_kernel() {{
void *ptr = (void *)&vector_add<{kwargs['T']}>;
}}
"""


class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', run_vector_add, [
'A',
'B',
'C',
'STREAM',
])


if __name__ == '__main__':
# Runtime
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')

# Templates
# NVCC
print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
print('Generated code:')
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
body = "\n"
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
body += 'std::cout << enable_double_streams << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
code = jit.generate((), args, body)
code = generate_vector_add(T='float')
print(code)

# Build
print('Building ...')
func = jit.build('test_func', args, code)
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)

# Test correctness
print('Running ...')
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
with Capture() as capture:
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
output = capture.capture()
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
assert output == ref_output, f'{output=}, {ref_output=}'
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
ref_output = a + b
torch.testing.assert_close(c, ref_output)

print('JIT test passed')
print('JIT test for NVCC passed\n')

# NVRTC
print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
print('Generated code:')
code = generate_vector_add(T='__nv_bfloat16')
print(code)
print('Building ...')
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)

a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
ref_output = a + b
torch.testing.assert_close(c, ref_output)

print('JIT test for NVRTC passed')
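`run_vector_add` above follows the cuda-python convention for `cuLaunchKernelEx`: kernel parameters are passed as two parallel tuples, the runtime values and their ctypes types. A small helper in the same style (hypothetical, not part of the test) that packs device-pointer arguments plus a trailing element count:

import ctypes
import torch

def pack_kernel_args(*tensors: torch.Tensor, n: int):
    # One (value, type) pair per kernel parameter: raw device pointers, then a uint32 count
    values = tuple(t.data_ptr() for t in tensors) + (n, )
    types = (ctypes.c_void_p, ) * len(tensors) + (ctypes.c_uint32, )
    return values, types

# Usage mirroring run_vector_add: cuda.cuLaunchKernelEx(config, kernel, pack_kernel_args(a, b, c, n=a.numel()), 0)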