Support UE8M0 data format. (#206)

This commit is contained in:
Shifang Xu
2025-06-12 09:38:19 +08:00
committed by GitHub
parent 9ec061204e
commit 21efbe9b48
14 changed files with 255 additions and 115 deletions

View File

@@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor) -> torch.Tensor:
    """Dequantize a block-scaled tensor back to bfloat16.

    The last dimension of ``x_fp8`` is treated as consecutive groups of 128
    elements; every group is multiplied by its own scale.

    Args:
        x_fp8: Quantized values with tokens along dimension 0. The number of
            elements per token must be a multiple of 128.
        x_scales: One scale per 128-element group. Either a floating-point
            tensor, or a ``torch.int`` tensor in which every 32-bit word packs
            four UE8M0 scales (one unsigned biased-exponent byte each).

    Returns:
        A ``torch.bfloat16`` tensor with the same shape as ``x_fp8``.
    """
    # With zero elements the `-1` dimension in the views below cannot be
    # inferred and `view` raises, so return the (empty) result directly.
    if x_fp8.numel() == 0:
        return x_fp8.to(torch.bfloat16)

    if x_scales.dtype == torch.int:
        # UE8M0 bytes are *unsigned* exponents. Reading them as signed int8
        # would sign-extend every byte >= 128 (i.e. every scale >= 2.0) and
        # set the float32 sign bit after the shift, yielding negative scales.
        exponents = x_scales.view(dtype=torch.uint8).to(torch.int)
        # Moving the byte into bits 23..30 places it in the float32 exponent
        # field (sign and mantissa stay zero), i.e. the value 2 ** (e - 127).
        x_scales = (exponents << 23).view(dtype=torch.float)

    num_tokens = x_fp8.size(0)
    x_fp32 = x_fp8.to(torch.float32).view(num_tokens, -1, 128)
    # The trailing singleton dimension broadcasts one scale over its group.
    x_scales = x_scales.view(num_tokens, -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)