Update layout.py

fzyzcjy 2025-06-12 16:10:00 +08:00 committed by Ray Wang
parent a437e0b1ca
commit cc416ee4fa

@@ -140,8 +140,14 @@ def transform_sf_into_required_layout(sf: torch.Tensor,
                                       is_sfa: bool = False):
     gran = (recipe[0 if is_sfa else 1], recipe[2])
-    # Pre-transform checks
-    check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
+    should_skip_transform = (
+        (sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == '100a')
+        or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == '100a')
+    )
+    if not should_skip_transform:
+        # Pre-transform checks
+        check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
 
     # (FP32, 1, 128) on Hopper: transform to TMA-aligned and MN-major
     if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == '90a':
@@ -162,8 +168,7 @@ def transform_sf_into_required_layout(sf: torch.Tensor,
         sf = get_col_major_tma_aligned_packed_tensor(sf)
         return check_sf_layout(sf, mn=mn, k=k, gran=(1, 128), num_groups=num_groups, tma_stride_check=True, type_check=torch.int)
 
-    # (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
-    if sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == '100a':
+    if should_skip_transform:
         # TODO: add transpose kernel if SF layout is not satisfied
         return check_sf_layout(sf, mn=mn, k=k, gran=(1, 128), num_groups=num_groups, tma_stride_check=True, type_check=torch.int)
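
In effect, the patch routes INT scale factors with (1, 128) or (128, 128) granularity on SM100 ('100a') past the generic pre-transform check; those tensors are validated only by the stricter tma_stride_check=True call on return, and every other dtype/arch combination keeps the original check_sf_layout call. Note that the two OR'ed clauses in should_skip_transform differ only in granularity, so the predicate collapses to a membership test. A minimal sketch of that condition follows; the get_device_arch stub and the sample tensor are illustrative assumptions, not the real layout.py helpers:

    import torch

    def get_device_arch() -> str:
        # Stand-in for the real layout.py helper; assume an SM100 device here.
        return '100a'

    def should_skip_transform(sf: torch.Tensor, gran: tuple) -> bool:
        # Equivalent to the commit's two OR'ed clauses, which differ only in gran.
        return (sf.dtype == torch.int
                and gran in ((1, 128), (128, 128))
                and get_device_arch() == '100a')

    sf = torch.zeros(128, 4, dtype=torch.int)   # illustrative packed-INT scale factors
    assert should_skip_transform(sf, (1, 128))
    assert should_skip_transform(sf, (128, 128))
    assert not should_skip_transform(sf.float(), (1, 128))

Skipping the eager check is safe here because both SM100 INT branches still end in check_sf_layout with tma_stride_check=True, so the skipped generic check is subsumed by a stricter one at return.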