mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
[wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -504,63 +504,73 @@ public:
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
||||
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
|
||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
|
||||
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
|
||||
BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, GemmType kGemmType>
|
||||
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
|
||||
return make_2d_tma_desc(
|
||||
global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
||||
shape_k, block_m, block_k);
|
||||
}
|
||||
|
||||
template <typename T, GemmType kGemmType>
|
||||
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
|
||||
shape_n * (kGemmType != GemmType::Normal ? num_groups : 1),
|
||||
block_k, block_n);
|
||||
}
|
||||
|
||||
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
|
||||
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
|
||||
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleDMode == 32)
|
||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
if constexpr (kSwizzleDMode == 64)
|
||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
if constexpr (kSwizzleDMode == 128)
|
||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
|
||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
|
||||
// bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(
|
||||
global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
||||
shape_n, block_m,
|
||||
kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode);
|
||||
}
|
||||
|
||||
template <typename T, GemmType kGemmType>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(
|
||||
global_address, Layout::ColMajor, shape_m,
|
||||
ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
||||
block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap
|
||||
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
|
||||
uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type =
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim,
|
||||
gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim,
|
||||
gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma_ext.hpp>
|
||||
|
||||
69
deep_gemm/include/deep_gemm/nvrtc_std.cuh
Normal file
69
deep_gemm/include/deep_gemm/nvrtc_std.cuh
Normal file
@@ -0,0 +1,69 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef NVRTC_JIT_COMPILATION

// NVRTC compiles without the host standard headers, so provide the
// fixed-width integer aliases and the minimal <type_traits> pieces
// (integral_constant / is_same) that the kernel headers rely on.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;

namespace std {

// Wraps a compile-time constant `v` of type `T`, mirroring
// std::integral_constant from <type_traits>.
template <class T, T v>
struct integral_constant {
    static constexpr T value = v;
    using value_type = T;
    using type = integral_constant;  // injected-class-name

    // Implicit conversion to the wrapped value.
    __device__ constexpr operator value_type() const noexcept { return value; }
    // Call syntax (C++14 std::integral_constant::operator()).
    __device__ constexpr value_type operator()() const noexcept { return value; }
};

using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>;

// Primary template: two distinct types.
template <class T, class U>
struct is_same : false_type {};

// Partial specialization: a type compared with itself.
template <class T>
struct is_same<T, T> : true_type {};

template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;

} // namespace std

#endif
|
||||
@@ -1,6 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <cassert>
|
||||
#endif
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <exception>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
@@ -16,8 +17,12 @@ public:
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#ifdef NVRTC_JIT_COMPILATION
|
||||
#define DG_HOST_ASSERT(cond) ((void)0)
|
||||
#else
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
@@ -27,6 +32,7 @@ do { \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
|
||||
Reference in New Issue
Block a user