From 39c10e6c3129d1afcfc2e50fe0544befc3656ab4 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 10 Mar 2025 09:47:02 +0800
Subject: [PATCH] Revert "Merge pull request #49 from A-transformer/maximum_fp8_e4m3_value"

This reverts commit 4d4f2342febb829e88f9cc7cbf4b955823b395fc, reversing
changes made to 9d3222a93e9637f6fb1cf2538199738c1eee23aa.
---
 tests/test_core.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/tests/test_core.py b/tests/test_core.py
index 7ba3e91..68d9b79 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -5,17 +5,14 @@ from typing import Tuple
 import deep_gemm
 from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
-FP8_E4M3_MAX = 448.0  # Maximum representable value in FP8 E4M3 format
 
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2 and x.size(1) % 128 == 0
     m, n = x.shape
     x_view = x.view(m, -1, 128)
     x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (
-        (x_view * (FP8_E4M3_MAX / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n),
-        (x_amax / FP8_E4M3_MAX).view(m, -1)
-    )
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+
 
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
@@ -24,8 +21,8 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (FP8_E4M3_MAX / x_amax)).to(torch.float8_e4m3fn)
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / FP8_E4M3_MAX).view(x_view.size(0), x_view.size(2))
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
 
 
 def construct(m: int, k: int, n: int) -> \
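
The two helpers touched by this diff quantize to FP8 E4M3 by scaling each
128-element group so that its absolute maximum maps to 448.0, the largest
finite E4M3 value, and they return the inverse scales needed to dequantize.
Below is a minimal round-trip sketch of the reverted-to per_token_cast_to_fp8
(a standalone approximation, assuming a PyTorch build with
torch.float8_e4m3fn support; deep_gemm itself is not required):

    import torch
    from typing import Tuple

    def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Scale each 128-wide group so its absolute maximum maps to 448.0,
        # the largest finite FP8 E4M3 value (same logic as the diff above).
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    x = torch.randn(4, 256)
    x_fp8, scales = per_token_cast_to_fp8(x)
    # Dequantize: multiply each 128-wide group by its stored inverse scale.
    x_deq = (x_fp8.float().view(4, -1, 128) * scales.unsqueeze(2)).view(4, 256)
    print((x - x_deq).abs().max())  # small quantization error from the 3-bit mantissa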