Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00

Update some comments and docs

commit 77bb07aa20 (parent 3885404ffb)

README.md: 11 lines changed
@@ -282,8 +282,15 @@ For two micro-batch overlapping, you can refer to the following figure. With our
 ## Notices
 
-- For extreme performance, we discover and use a behavior-out-of-doc PTX instruction: `ld.global.nc.L1::no_allocate.L2::256B`. This instruction will lead to an undefined behavior: accessing volatile GPU memory with non-coherent read-only PTX modifiers `.nc`. But the correctness is tested to be guaranteed with `.L1::no_allocate` on Hopper architectures, and performance will be much better. If you find kernels not working on some other platforms, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` and disable this, or file an issue.
-- For better performance on your cluster, we recommend to run all the tests and use the best auto-tuned configuration. The default configurations are optimized on the DeepSeek's internal cluster.
+#### Undefined-behavior PTX usage
+
+- For extreme performance, we discover and use an out-of-doc, undefined-behavior PTX usage: the read-only instruction `ld.global.nc.L1::no_allocate.L2::256B` to **read volatile data**. The PTX modifier `.nc` indicates that a non-coherent cache is used. Correctness is tested to hold with `.L1::no_allocate` on Hopper architectures, and performance is much better. Our guess at the reason: the non-coherent cache is unified with L1 on Hopper, and `.L1::no_allocate` is not just a hint but a strong directive, so correctness is guaranteed because no dirty data can stay in L1.
+- Initially, because NVCC could not automatically unroll volatile-read PTX, we tried using `__ldg` (i.e., `ld.nc`). Even compared to manually unrolled volatile reads, it was significantly faster (likely thanks to additional compiler optimizations). However, the results could be incorrect or stale. After consulting the PTX documentation, we found that L1 and the non-coherent cache are unified on Hopper architectures and speculated that `.L1::no_allocate` might resolve the issue, which led to this discovery.
+- If you find the kernels not working on some other platforms, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` to disable this, or file an issue.
+
+#### Auto-tuning on your cluster
+
+For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized for DeepSeek's internal cluster.
 
 ## License
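The first new bullet above describes a compile-time escape hatch: unless `DISABLE_AGGRESSIVE_PTX_INSTRS` is defined, the kernels use the aggressive non-coherent load; otherwise they fall back to a documented `ld.volatile.global` (that fallback define is visible in the utility-header hunk further down). Below is a minimal sketch of that selection plus a 32-bit wrapper; the wrapper name `ld_nc_u32` and the exact macro layout are illustrative assumptions, not the repository's literal code.

```cuda
#include <cstdint>

// Sketch only: pick the aggressive, out-of-doc load unless the build defines
// DISABLE_AGGRESSIVE_PTX_INSTRS (the real header's layout may differ).
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
// Non-coherent read, no L1 allocation, 256-byte L2 prefetch: the exact
// instruction quoted in the README bullet above.
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
// Documented, safe fallback: a plain volatile global load.
#define LD_NC_FUNC "ld.volatile.global"
#endif

// Illustrative 32-bit wrapper (hypothetical name) around the selected load.
__device__ __forceinline__ uint32_t ld_nc_u32(const uint32_t* ptr) {
    uint32_t ret;
    asm volatile(LD_NC_FUNC ".u32 %0, [%1];" : "=r"(ret) : "l"(ptr) : "memory");
    return ret;
}
```

Setting `DISABLE_AGGRESSIVE_PTX_INSTRS=1` for `setup.py` then presumably just becomes a `-DDISABLE_AGGRESSIVE_PTX_INSTRS` flag on the NVCC command line.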
@@ -378,7 +378,7 @@ combine(void* combined_x,
         atomic_add_release_global(atomic_clean_flag, num_experts);
     }
 
-    // FP8 cast and issue IBGDA sends
+    // Issue IBGDA sends
     if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
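The unchanged context lines in this hunk show how the kernel decomposes a flat `responsible_expert_idx` into a destination rank and a local expert slot. A tiny, self-contained illustration of that arithmetic follows; the sizes (8 ranks, 32 experts) are made up for the example.

```cuda
#include <cassert>
#include <cstdio>

int main() {
    // Assumed sizes for illustration only: 8 ranks, 32 experts in total.
    const int num_ranks = 8;
    const int num_experts = 32;
    const int num_local_experts = num_experts / num_ranks;  // 4 experts per rank

    // Same decomposition as the kernel's context lines: the flat index
    // encodes (destination rank, local expert index).
    for (int responsible_expert_idx = 0; responsible_expert_idx < num_experts; ++responsible_expert_idx) {
        const int dst_rank = responsible_expert_idx / num_local_experts;
        const int local_expert_idx = responsible_expert_idx % num_local_experts;
        assert(dst_rank * num_local_experts + local_expert_idx == responsible_expert_idx);
        if (responsible_expert_idx < 5)
            std::printf("expert %2d -> rank %d, local slot %d\n",
                        responsible_expert_idx, dst_rank, local_expert_idx);
    }
    return 0;
}
```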
@@ -148,9 +148,7 @@ __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
 #define LD_NC_FUNC "ld.volatile.global"
 #endif
 
-// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS,
-// which does not have cache allocation, and `CONSTANT` memory does not have coherence control,
-// so we have to control them by queue semantics
+// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS
 template <typename dtype_t>
 __device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
     auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
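The `ld_nc_global` template in this hunk routes an arbitrary `dtype_t` through a sized integer vector type (`VecInt`) before issuing the PTX load. A hedged sketch of how such a dispatch can bottom out in a width-specific inline-asm overload is shown below; the `VecInt` stand-in, the `int` overload, and the final bit-cast back to `dtype_t` are assumed completions of the truncated context, not the repository's literal code.

```cuda
#include <cstdint>

// Self-contained fallback so the sketch stands alone; the real header picks
// between this and the aggressive `.nc` form (see the earlier sketch).
#ifndef LD_NC_FUNC
#define LD_NC_FUNC "ld.volatile.global"
#endif

// Illustrative stand-in for the VecInt trait: map a byte width to an integer
// type of the same size (the real trait presumably covers more widths; only
// the 4-byte case is shown here).
template <int kBytes> struct VecInt {};
template <> struct VecInt<4> { using vec_t = int; };

// Width-specific overload: expand the selected PTX opcode with a matching type.
__device__ __forceinline__ int ld_nc_global(const int* ptr) {
    int ret;
    asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
    return ret;
}

// Generic entry point, mirroring the hunk's context: reinterpret the pointer
// as the sized vector type, load through it, then bit-cast the result back.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t* ptr) {
    auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
    return *reinterpret_cast<dtype_t*>(&ret);
}
```

Because the `.nc` path carries no coherence control (the point made by the comment lines being removed here), ordering against remote writers has to come from the surrounding queue protocol, i.e. paired release stores and flag checks, rather than from the load itself.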
@@ -234,8 +232,7 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
     asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
 }
 
-// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS,
-// which does not have cache allocation (obviously in L1, I guess not in L2 too)
+// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS
 #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
 #define ST_NA_FUNC "st.global.L1::no_allocate"
 #else
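The store side mirrors the load side: `ST_NA_FUNC` selects the `L1::no_allocate` store unless `DISABLE_AGGRESSIVE_PTX_INSTRS` is defined, and `st_na_release` (in the context above) adds `.release` ordering so that a flag write publishes the writer's earlier stores. A hedged sketch of the producer/consumer pairing this enables is below; `st_release_u64`, `ld_acquire_u64`, and `wait_flag` are illustrative names, not necessarily helpers that exist in the repository.

```cuda
#include <cstdint>

// Release store of a 64-bit flag, using the same PTX as the hunk above: the
// .release ordering makes the writer's earlier global stores visible to any
// thread that later observes the flag with acquire semantics.
__device__ __forceinline__ void st_release_u64(uint64_t* ptr, uint64_t val) {
    asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;"
                 : : "l"(ptr), "l"(val) : "memory");
}

// Matching acquire load on the consumer side (illustrative helper): once the
// expected flag value is seen, the producer's earlier stores are visible too.
__device__ __forceinline__ uint64_t ld_acquire_u64(const uint64_t* ptr) {
    uint64_t ret;
    asm volatile("ld.acquire.gpu.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr) : "memory");
    return ret;
}

// Consumer-side spin on the flag; real kernels typically bound or interleave
// this wait with other work.
__device__ __forceinline__ void wait_flag(const uint64_t* flag, uint64_t expected) {
    while (ld_acquire_u64(flag) != expected)
        ;
}
```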