From 828a19c7204483164478ce50a1a34be8c4ea34f4 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Tue, 22 Apr 2025 17:46:35 +0800 Subject: [PATCH] Move flash_mla.h to kernels/params.h --- csrc/flash_api.cpp | 1 + csrc/kernels/get_mla_metadata.h | 2 +- csrc/kernels/mla_combine.cu | 2 +- csrc/kernels/mla_combine.h | 2 +- csrc/{flash_mla.h => kernels/params.h} | 0 csrc/kernels/splitkv_mla.cu | 2 +- csrc/kernels/splitkv_mla.h | 2 +- 7 files changed, 6 insertions(+), 5 deletions(-) rename csrc/{flash_mla.h => kernels/params.h} (100%) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 0a19789..a87e1ab 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -13,6 +13,7 @@ #include "kernels/config.h" #include "kernels/get_mla_metadata.h" #include "kernels/mla_combine.h" +#include "kernels/params.h" #include "kernels/splitkv_mla.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/kernels/get_mla_metadata.h index 5faa665..5130581 100644 --- a/csrc/kernels/get_mla_metadata.h +++ b/csrc/kernels/get_mla_metadata.h @@ -1,5 +1,5 @@ #pragma once -#include "flash_mla.h" +#include "params.h" void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/mla_combine.cu b/csrc/kernels/mla_combine.cu index 681dfe0..b6ba8f8 100644 --- a/csrc/kernels/mla_combine.cu +++ b/csrc/kernels/mla_combine.cu @@ -5,7 +5,7 @@ #include #include -#include "flash_mla.h" +#include "params.h" #include "utils.h" #include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V diff --git a/csrc/kernels/mla_combine.h b/csrc/kernels/mla_combine.h index 3f33b8f..69035e9 100644 --- a/csrc/kernels/mla_combine.h +++ b/csrc/kernels/mla_combine.h @@ -1,6 +1,6 @@ #pragma once -#include "flash_mla.h" +#include "params.h" template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_mla.h b/csrc/kernels/params.h similarity index 100% rename from csrc/flash_mla.h rename to csrc/kernels/params.h diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu index 0333605..ff29305 100644 --- a/csrc/kernels/splitkv_mla.cu +++ b/csrc/kernels/splitkv_mla.cu @@ -1,6 +1,6 @@ #include -#include "flash_mla.h" +#include "params.h" #include "utils.h" #include "config.h" #include "traits.h" diff --git a/csrc/kernels/splitkv_mla.h b/csrc/kernels/splitkv_mla.h index 42109d4..479fb50 100644 --- a/csrc/kernels/splitkv_mla.h +++ b/csrc/kernels/splitkv_mla.h @@ -1,6 +1,6 @@ #pragma once -#include "flash_mla.h" +#include "params.h" template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);