#include <ATen/cuda/CUDAContext.h>
#include <memory>

#include "kernels/exception.cuh"

namespace deep_ep {
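// Handle to a CUDA event that is recorded on a stream at construction time.
// The underlying torch::Event is held through a shared_ptr, so copies of the
// handle refer to the same recorded event; current_stream_wait() makes the
// current CUDA stream wait until that event has completed.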
struct EventHandle {
    std::shared_ptr<torch::Event> event;

    EventHandle() {
        event = std::make_shared<torch::Event>(torch::kCUDA);
        event->record(at::cuda::getCurrentCUDAStream());
    }

    explicit EventHandle(const at::cuda::CUDAStream& stream) {
        event = std::make_shared<torch::Event>(torch::kCUDA);
        event->record(stream);
    }

    EventHandle(const EventHandle& other) = default;

    void current_stream_wait() const {
        at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
    }
};
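// Creates a torch CUDA event and records it on the given stream.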
torch::Event create_event(const at::cuda::CUDAStream& s) {
    auto event = torch::Event(torch::kCUDA);
    event.record(s);
    return event;
}
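// Makes stream `s_0` wait until all work currently enqueued on `s_1` has
// completed, by recording an event on `s_1` and waiting on it from `s_0`.
// The two streams must be distinct.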
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
    EP_HOST_ASSERT(s_0.id() != s_1.id());
    s_0.unwrap().wait(create_event(s_1));
}
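// Makes stream `s` wait until the event held by `event` has completed.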
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
    s.unwrap().wait(*event.event);
}
} // namespace deep_ep