mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support UE8M0 data format. (#206)
This commit is contained in:
@@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
|
||||
|
||||
|
||||
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor) -> torch.Tensor:
    """Dequantize a per-token, 128-element-group scaled tensor to bfloat16.

    Each row of ``x_fp8`` is split into consecutive groups of 128 elements
    and every group is multiplied by its own scale from ``x_scales``.

    Args:
        x_fp8: Quantized values. The number of elements per row (everything
            after dim 0) must be a multiple of 128.
        x_scales: One scale per 128-element group, in either of two layouts:
            * floating point: used as-is, one value per group;
            * ``torch.int``: packed UE8M0 scales, four 8-bit biased exponents
              per 32-bit integer (in memory order). Each exponent byte ``e``
              is decoded to the float32 value whose exponent field is ``e``
              (i.e. ``2 ** (e - 127)`` for normal values).

    Returns:
        A ``torch.bfloat16`` tensor with the same shape as ``x_fp8``.
    """
    # Nothing to rescale; also avoids the ambiguous ``-1`` in ``view`` below,
    # which raises for zero-element tensors.
    if x_fp8.numel() == 0:
        return x_fp8.to(torch.bfloat16)

    if x_scales.dtype == torch.int:
        # UE8M0 exponents are *unsigned* bytes. Reinterpreting them as int8
        # would sign-extend values >= 128 and, after the shift, set the
        # float32 sign bit (e.g. a scale of 2.0 would decode as -2.0).
        exponents = x_scales.view(dtype=torch.uint8).to(torch.int)
        # Place each exponent in the float32 exponent field (bits 23..30).
        x_scales = (exponents << 23).view(dtype=torch.float)

    num_tokens = x_fp8.size(0)
    x_fp32 = x_fp8.to(torch.float32).view(num_tokens, -1, 128)
    x_scales = x_scales.view(num_tokens, -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
|
||||
|
||||
Reference in New Issue
Block a user