diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 0a19789..a87e1ab 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -13,6 +13,7 @@ #include "kernels/config.h" #include "kernels/get_mla_metadata.h" #include "kernels/mla_combine.h" +#include "kernels/params.h" #include "kernels/splitkv_mla.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/kernels/get_mla_metadata.h index 5faa665..5130581 100644 --- a/csrc/kernels/get_mla_metadata.h +++ b/csrc/kernels/get_mla_metadata.h @@ -1,5 +1,5 @@ #pragma once -#include "flash_mla.h" +#include "params.h" void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream); diff --git a/csrc/kernels/mla_combine.cu b/csrc/kernels/mla_combine.cu index 681dfe0..b6ba8f8 100644 --- a/csrc/kernels/mla_combine.cu +++ b/csrc/kernels/mla_combine.cu @@ -5,7 +5,7 @@ #include #include -#include "flash_mla.h" +#include "params.h" #include "utils.h" #include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V diff --git a/csrc/kernels/mla_combine.h b/csrc/kernels/mla_combine.h index 3f33b8f..69035e9 100644 --- a/csrc/kernels/mla_combine.h +++ b/csrc/kernels/mla_combine.h @@ -1,6 +1,6 @@ #pragma once -#include "flash_mla.h" +#include "params.h" template void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); diff --git a/csrc/flash_mla.h b/csrc/kernels/params.h similarity index 100% rename from csrc/flash_mla.h rename to csrc/kernels/params.h diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu index 0333605..ff29305 100644 --- a/csrc/kernels/splitkv_mla.cu +++ b/csrc/kernels/splitkv_mla.cu @@ -1,6 +1,6 @@ #include -#include "flash_mla.h" +#include "params.h" #include "utils.h" #include "config.h" #include "traits.h" diff --git a/csrc/kernels/splitkv_mla.h b/csrc/kernels/splitkv_mla.h index 42109d4..479fb50 100644 --- a/csrc/kernels/splitkv_mla.h +++ b/csrc/kernels/splitkv_mla.h @@ -1,6 +1,6 @@ #pragma 
once -#include "flash_mla.h" +#include "params.h" template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);