Minor fix

This commit is contained in:
Chenggang Zhao
2025-06-19 10:38:42 +08:00
parent 24453275e3
commit 3a3398f686

View File

@@ -17,12 +17,12 @@ public:
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
ptr = static_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
@@ -51,8 +51,8 @@ public:
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
@@ -63,8 +63,8 @@ public:
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++ i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
}
}
@@ -75,14 +75,14 @@ public:
}
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
template<int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++ i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
return *this;
}
@@ -114,9 +114,9 @@ public:
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
send_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {