Initial commit

This commit is contained in:
Chenggang Zhao
2025-02-24 21:56:14 +08:00
commit ebfe47e46f
32 changed files with 8461 additions and 0 deletions

43
csrc/event.hpp Normal file
View File

@@ -0,0 +1,43 @@
#include <ATen/cuda/CUDAContext.h>
#include <memory>
#include "kernels/exception.cuh"
namespace deep_ep {
struct EventHandle {
std::shared_ptr<torch::Event> event;
EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream());
}
explicit EventHandle(const at::cuda::CUDAStream& stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
}
EventHandle(const EventHandle& other) = default;
void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
}
};
torch::Event create_event(const at::cuda::CUDAStream &s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
return event;
}
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
}
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
s.unwrap().wait(*event.event);
}
} // namespace deep_ep