Add assertions

This commit is contained in:
Chenggang Zhao 2025-05-27 13:21:19 +08:00
parent 780b4098e4
commit 3bd234e79c

View File

@ -99,6 +99,7 @@ def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group:
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
for j in range(num_groups):
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
assert masked_m.amax().item() <= max_m
return x_fp8, y_fp8, masked_m, out, ref_out