Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)
Update layout.py

Commit: cc416ee4fa
Parent: a437e0b1ca
@@ -140,8 +140,14 @@ def transform_sf_into_required_layout(sf: torch.Tensor,
                                       is_sfa: bool = False):
     gran = (recipe[0 if is_sfa else 1], recipe[2])
 
-    # Pre-transform checks
-    check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
+    should_skip_transform = (
+        (sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == '100a')
+        or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == '100a')
+    )
+
+    if not should_skip_transform:
+        # Pre-transform checks
+        check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
 
     # (FP32, 1, 128) on Hopper: transform to TMA-aligned and MN-major
     if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == '90a':
@@ -162,8 +168,7 @@ def transform_sf_into_required_layout(sf: torch.Tensor,
         sf = get_col_major_tma_aligned_packed_tensor(sf)
         return check_sf_layout(sf, mn=mn, k=k, gran=(1, 128), num_groups=num_groups, tma_stride_check=True, type_check=torch.int)
 
-    # (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
-    if sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == '100a':
+    if should_skip_transform:
         # TODO: add transpose kernel if SF layout is not satisfied
         return check_sf_layout(sf, mn=mn, k=k, gran=(1, 128), num_groups=num_groups, tma_stride_check=True, type_check=torch.int)
 
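To make the intent of the change easier to follow, below is a minimal, self-contained Python sketch of how the updated control flow reads after this commit. It is an illustration rather than the repository implementation: get_device_arch() and check_sf_layout() are stubbed here purely so the snippet runs on its own, and the architecture/dtype-specific transform branches between the two edited regions are elided.

# Sketch only: stubs stand in for DeepGEMM's real helpers.
import torch

def get_device_arch() -> str:
    # Stub: the real helper inspects the CUDA device; assume SM100 here.
    return '100a'

def check_sf_layout(sf: torch.Tensor, **kwargs) -> torch.Tensor:
    # Stub: the real helper validates shape/stride/dtype and returns `sf`.
    return sf

def transform_sf_into_required_layout_sketch(sf: torch.Tensor,
                                             mn: int, k: int,
                                             recipe, num_groups=None,
                                             is_sfa: bool = False):
    gran = (recipe[0 if is_sfa else 1], recipe[2])

    # New in this commit: INT scale factors with (1, 128) or (128, 128)
    # granularity on SM100 need no repacking, so the pre-transform layout
    # check is skipped for them.
    should_skip_transform = (
        (sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == '100a')
        or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == '100a')
    )

    if not should_skip_transform:
        # Pre-transform checks only run when a transform will actually happen
        check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)

    # ... architecture/dtype-specific transform branches elided ...

    if should_skip_transform:
        # TODO (from the source): add transpose kernel if SF layout is not satisfied
        return check_sf_layout(sf, mn=mn, k=k, gran=(1, 128), num_groups=num_groups,
                               tma_stride_check=True, type_check=torch.int)

The net effect of the commit is that the early check_sf_layout call is guarded by the new should_skip_transform flag, and the final branch for already-packed INT scale factors on SM100 tests that same flag instead of repeating the dtype/granularity/architecture condition.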