From 3a3398f6864785b7ae6818ae775f662f6115ed8c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 19 Jun 2025 10:38:42 +0800 Subject: [PATCH] Minor fix --- csrc/kernels/buffer.cuh | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/csrc/kernels/buffer.cuh b/csrc/kernels/buffer.cuh index 7c243d3..19400a5 100644 --- a/csrc/kernels/buffer.cuh +++ b/csrc/kernels/buffer.cuh @@ -17,12 +17,12 @@ public: __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) { total_bytes = num_elems * sizeof(dtype_t); - ptr = reinterpret_cast(gbl_ptr) + offset * sizeof(dtype_t); - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + ptr = static_cast(gbl_ptr) + offset * sizeof(dtype_t); + gbl_ptr = static_cast(gbl_ptr) + total_bytes; } __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) { - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + gbl_ptr = static_cast(gbl_ptr) + total_bytes; return *this; } @@ -51,8 +51,8 @@ public: int per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; - ptrs[0] = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + ptrs[0] = static_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptr = static_cast(gbl_ptr) + total_bytes; } __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, @@ -63,8 +63,8 @@ public: int per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; for (int i = 0; i < kNumRanks; ++ i) { - ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; - gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; + ptrs[i] = static_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptrs[i] = static_cast(gbl_ptrs[i]) + total_bytes; } } @@ -75,14 +75,14 @@ public: } __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) { - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + gbl_ptr = static_cast(gbl_ptr) + total_bytes; return *this; } template __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { for (int i = 0; i < kNumAlsoRanks; ++ i) - gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; + gbl_ptrs[i] = static_cast(gbl_ptrs[i]) + total_bytes; return *this; } @@ -114,9 +114,9 @@ public: int per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); - send_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id; - recv_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); - gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + send_ptr = static_cast(gbl_ptr) + per_channel_bytes * sm_id; + recv_ptr = static_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); + gbl_ptr = static_cast(gbl_ptr) + total_bytes; } __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {