Support bias. (#257)

* Support bias.

* Fix.

* Fix style.
This commit is contained in:
Shangyan Zhou
2025-06-25 13:04:20 +08:00
committed by GitHub
parent b80e55e21f
commit bd429ffefc
7 changed files with 101 additions and 16 deletions

View File

@@ -112,6 +112,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
@@ -127,6 +128,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,