Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-05-10 06:40:39 +00:00

Merge pull request #32 from sijiac/fp16-support: Support FP16 dtype in FlashMLA kernel

Commit b549289fb4
README.md
@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 
 Currently released:
-- BF16
+- BF16, FP16
 - Paged kvcache with block size of 64
 
 ## Quick start
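For orientation: with this PR, FP16 becomes a drop-in alternative to BF16 at the Python level. A minimal sketch, following the call pattern of the test updated later in this diff; the sizes below are illustrative assumptions, and the two return values follow the upstream quick-start:

    import torch
    from flash_mla import flash_mla_with_kvcache, get_mla_metadata

    # Illustrative sizes: 4 sequences, 1 query token, 128 query heads, 1 KV head,
    # head dim 576 with a 512-dim value part, paged KV cache with block size 64.
    b, s_q, h_q, h_kv, d, dv, block_size = 4, 1, 128, 1, 576, 512, 64
    max_seqlen_pad = 1024  # padded KV length, a multiple of block_size

    # FP16 now works wherever BF16 did (unless built with FLASH_MLA_DISABLE_FP16).
    torch.set_default_dtype(torch.float16)
    torch.set_default_device("cuda")

    cache_seqlens = torch.full((b,), 1000, dtype=torch.int32)
    q = torch.randn(b, s_q, h_q, d)
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32
    ).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )
    out, lse = flash_mla_with_kvcache(
        q, blocked_k, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )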
csrc/flash_api.cpp
@@ -61,7 +61,7 @@ std::vector<at::Tensor>
 mha_fwd_kvcache_mla(
         at::Tensor &q,                             // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &kcache,                  // num_blocks x page_block_size x num_heads_k x head_size
-        c10::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
+        std::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
         const int head_size_v,
         const at::Tensor &seqlens_k,               // batch_size
         const at::Tensor &block_table,             // batch_size x max_num_blocks_per_seq
@@ -77,7 +77,6 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16);
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -186,7 +185,18 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+
+    if (q_dtype == torch::kBFloat16) {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    }
+#ifndef FLASH_MLA_DISABLE_FP16
+    else if (q_dtype == torch::kHalf) {
+        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
+    }
+#endif
+    else {
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
+    }
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
csrc/flash_fwd_mla_fp16_sm90.cu (new file, 3 lines)
@@ -0,0 +1,3 @@
+#include "flash_fwd_mla_kernel.h"
+
+template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
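One translation unit per dtype mirrors the existing csrc/flash_fwd_mla_bf16_sm90.cu: each file holds only the explicit instantiation of run_mha_fwd_splitkv_mla for its type, so the instantiations can compile in parallel (see the --threads handling in setup.py) and the FP16 unit can simply be left out of the build when the feature is disabled.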
csrc/flash_fwd_mla_kernel.h
@@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream)
     using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
     run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
 }
-
-static constexpr int MaxBatchSize = 4096;
-
-__global__ void __launch_bounds__(256, 1, 1)
-get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
-    int *seqlens_k_ptr = params.seqlens_k_ptr;
-    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
-    int *num_splits_ptr = params.num_splits_ptr;
-    int batch_size = params.batch_size;
-    int block_size_n = params.block_size_n;
-    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
-    int num_sm_parts = params.num_sm_parts;
-
-    __shared__ int num_blocks_shared[MaxBatchSize];
-    __shared__ int num_splits_shared[MaxBatchSize];
-
-    int total_num_blocks = 0;
-    for (int i = threadIdx.x; i < batch_size; i += 32) {
-        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
-        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
-        num_blocks_shared[i] = num_blocks;
-    }
-    for (int offset = 16; offset >= 1; offset /= 2) {
-        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
-    }
-    __syncwarp();
-
-    if (threadIdx.x == 0) {
-        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
-
-        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
-        num_splits_shared[0] = 0;
-        for (int i = 0; i < num_sm_parts; ++i) {
-            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
-            tile_scheduler_metadata0[0] = now_idx;
-            tile_scheduler_metadata0[1] = now_block * block_size_n;
-            tile_scheduler_metadata1 = now_n_split_idx;
-            int remain_payload = payload;
-            while (now_idx < batch_size) {
-                int num_blocks = num_blocks_shared[now_idx];
-                int now_remain_blocks = num_blocks - now_block;
-                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
-                    cum_num_splits += now_n_split_idx + 1;
-                    num_splits_shared[now_idx + 1] = cum_num_splits;
-                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
-                    ++now_idx;
-                    now_block = 0;
-                    now_n_split_idx = 0;
-                } else {
-                    if (remain_payload - fixed_overhead_num_blocks > 0) {
-                        now_block += remain_payload - fixed_overhead_num_blocks;
-                        ++now_n_split_idx;
-                        remain_payload = 0;
-                    }
-                    break;
-                }
-            }
-            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
-            tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
-            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
-            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
-        }
-        FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
-    }
-    __syncwarp();
-
-    for (int i = threadIdx.x; i <= batch_size; i += 32) {
-        num_splits_ptr[i] = num_splits_shared[i];
-    }
-}
-
-void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
-    FLASH_ASSERT(params.batch_size < MaxBatchSize);
-    get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
-    CHECK_CUDA_KERNEL_LAUNCH();
-}
csrc/flash_fwd_mla_metadata.cu (new file, 77 lines)
@@ -0,0 +1,77 @@
+#include "flash_fwd_mla_kernel.h"
+
+(... followed by the get_mla_metadata_kernel / get_mla_metadata_func code removed above, moved here unchanged ...)
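For review purposes, the partitioning logic of get_mla_metadata_kernel (shown in full in the removal hunk above) can be read as the following single-threaded Python sketch. This is a reconstruction with a hypothetical name (sketch_mla_metadata), not code from the PR; the real kernel runs as a single 32-thread warp on device:

    def ceil_div(a, b):
        return (a + b - 1) // b

    def sketch_mla_metadata(seqlens_k, block_size_n, fixed_overhead_num_blocks, num_sm_parts):
        # KV blocks to process per sequence, plus a fixed per-sequence overhead.
        num_blocks = [ceil_div(s, block_size_n) for s in seqlens_k]
        total = sum(n + fixed_overhead_num_blocks for n in num_blocks)
        # Target amount of work handed to each SM part.
        payload = ceil_div(total, num_sm_parts) + fixed_overhead_num_blocks

        batch_size = len(seqlens_k)
        metadata = []                        # (begin_idx, begin_seqlen, end_idx, end_seqlen, n_split_idx)
        num_splits = [0] * (batch_size + 1)  # cumulative split count per sequence
        now_idx = now_block = now_n_split_idx = cum_num_splits = 0
        for _ in range(num_sm_parts):
            begin_idx, begin_block, n_split_idx = now_idx, now_block, now_n_split_idx
            remain_payload = payload
            while now_idx < batch_size:
                remain_blocks = num_blocks[now_idx] - now_block
                if remain_payload >= remain_blocks + fixed_overhead_num_blocks:
                    # This SM part finishes the current sequence.
                    cum_num_splits += now_n_split_idx + 1
                    num_splits[now_idx + 1] = cum_num_splits
                    remain_payload -= remain_blocks + fixed_overhead_num_blocks
                    now_idx, now_block, now_n_split_idx = now_idx + 1, 0, 0
                else:
                    # Sequence is split; the remainder goes to the next SM part.
                    if remain_payload - fixed_overhead_num_blocks > 0:
                        now_block += remain_payload - fixed_overhead_num_blocks
                        now_n_split_idx += 1
                    break
            end_idx = now_idx if now_block > 0 else now_idx - 1
            end_seqlen = now_block * block_size_n if now_block > 0 else seqlens_k[now_idx - 1]
            metadata.append((begin_idx, begin_block * block_size_n, end_idx, end_seqlen, n_split_idx))
        assert now_idx == batch_size and now_block == 0 and now_n_split_idx == 0
        return metadata, num_splits

    # Example: 4 sequences, 64-token blocks, 5 overhead blocks, 2 SM parts.
    # meta, splits = sketch_mla_metadata([1000, 2000, 300, 64], 64, 5, 2)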
setup.py
@@ -11,11 +11,29 @@ from torch.utils.cpp_extension import (
     IS_WINDOWS,
 )
 
+DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+
+
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
+
+def get_sources():
+    sources = [
+        "csrc/flash_api.cpp",
+        "csrc/flash_fwd_mla_bf16_sm90.cu",
+        "csrc/flash_fwd_mla_metadata.cu",
+    ]
+
+    if not DISABLE_FP16:
+        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+
+    return sources
+
+
+def get_features_args():
+    features_args = []
+    if DISABLE_FP16:
+        features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    return features_args
+
+
 subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
@@ -34,12 +52,9 @@ ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla_cuda",
-        sources=[
-            "csrc/flash_api.cpp",
-            "csrc/flash_fwd_mla_bf16_sm90.cu",
-        ],
+        sources=get_sources(),
         extra_compile_args={
-            "cxx": cxx_args,
+            "cxx": cxx_args + get_features_args(),
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
@@ -57,7 +72,7 @@ ext_modules.append(
                     "--ptxas-options=-v,--register-usage-level=10"
                 ]
                 + cc_flag
-            ),
+            ) + get_features_args(),
         },
         include_dirs=[
            Path(this_dir) / "csrc",
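Taken together: building with FLASH_MLA_DISABLE_FP16=TRUE in the environment drops csrc/flash_fwd_mla_fp16_sm90.cu from the source list and passes -DFLASH_MLA_DISABLE_FP16 to both the C++ and nvcc compile lines, which also compiles the kHalf branch out of mha_fwd_kvcache_mla.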
tests/test_flash_mla.py
@@ -1,10 +1,11 @@
+import argparse
 import math
 import random
 
 import torch
 import triton
 
-from flash_mla import get_mla_metadata, flash_mla_with_kvcache
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
 
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
 
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
-    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
+    print(
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
+    )
 
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
@@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
 
     q = torch.randn(b, s_q, h_q, d)
     block_size = 64
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32
+    ).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
     for i in range(b):
-        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
+        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
+            float("nan")
+        )
     blocked_v = blocked_k[..., :dv]
 
-    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
+    tile_scheduler_metadata, num_splits = get_mla_metadata(
+        cache_seqlens, s_q * h_q // h_kv, h_kv
+    )
 
     def flash_mla():
         return flash_mla_with_kvcache(
-            q, blocked_k, block_table, cache_seqlens, dv,
-            tile_scheduler_metadata, num_splits, causal=causal,
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            dv,
+            tile_scheduler_metadata,
+            num_splits,
+            causal=causal,
        )
 
     def ref_mla():
@@ -91,14 +106,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
 
     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
+        torch.finfo(q.dtype).bits // 8
+    )
+    print(
+        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
+    )
 
 
-if __name__ == "__main__":
-    dtype = torch.bfloat16
+def main(torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    torch.set_default_dtype(torch_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
@@ -114,3 +132,22 @@ if __name__ == "__main__":
             for s_q in [1, 2]:  # MTP = 1, 2
                 for varlen in [False, True]:
                     test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["bf16", "fp16"],
+        default="bf16",
+        help="Data type to use for testing (bf16 or fp16)",
+    )
+
+    args = parser.parse_args()
+
+    torch_dtype = torch.bfloat16
+    if args.dtype == "fp16":
+        torch_dtype = torch.float16
+
+    main(torch_dtype)
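With the argparse entry point in place, the benchmark runs under either dtype, e.g. python tests/test_flash_mla.py --dtype fp16 (assuming the upstream tests/ layout; bf16 remains the default).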