import os
import subprocess
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def find_nvshmem_dir():
    """Return the NVSHMEM installation directory named by ``NVSHMEM_DIR``.

    Raises:
        RuntimeError: if ``NVSHMEM_DIR`` is unset, empty, or is not an existing directory.
    """
    nvshmem_dir = os.getenv('NVSHMEM_DIR')
    # Raise explicitly rather than `assert`: assertions are stripped under `python -O`,
    # which would let the build continue with bogus 'None/include' and 'None/lib' paths.
    if not nvshmem_dir or not os.path.isdir(nvshmem_dir):
        raise RuntimeError(f'Failed to find NVSHMEM (NVSHMEM_DIR={nvshmem_dir!r})')
    return nvshmem_dir


def parse_bool_flag(value, name='flag'):
    """Interpret an environment-variable string as an on/off switch.

    Integers keep their usual meaning (zero is off, anything else is on). An empty
    string and the words false/no/off are off; true/yes/on are on. Matching ignores
    case and surrounding whitespace.

    Args:
        value: raw string read from the environment.
        name: variable name, used only in the error message.

    Raises:
        ValueError: if ``value`` is neither an integer nor a recognized word.
    """
    text = value.strip().lower()
    if text in ('', 'false', 'no', 'off'):
        return False
    if text in ('true', 'yes', 'on'):
        return True
    try:
        return bool(int(text))
    except ValueError:
        raise ValueError(
            f'{name} must be an integer or one of true/false/yes/no/on/off, got {value!r}'
        ) from None


def git_revision_suffix():
    """Return ``'+<short hash>'`` for the current git HEAD, or ``''`` if unavailable.

    This is best-effort: a missing git executable, a non-repository working
    directory, or undecodable output all yield an empty suffix.
    """
    cmd = ['git', 'rev-parse', '--short', 'HEAD']
    try:
        # stderr is discarded so an expected failure does not clutter build output
        output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
        short_hash = output.decode('ascii').strip()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        return ''
    # A bare '+' would make the version string invalid, so require a non-empty hash
    return f'+{short_hash}' if short_hash else ''


if __name__ == '__main__':
    nvshmem_dir = find_nvshmem_dir()
    print(f'NVSHMEM directory: {nvshmem_dir}')

    # TODO: currently, we only support Hopper architecture, we may add Ampere support later
    os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
    cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
                 '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
    nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']
    include_dirs = ['csrc/', f'{nvshmem_dir}/include']
    sources = ['csrc/deep_ep.cpp',
               'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu',
               'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
    library_dirs = [f'{nvshmem_dir}/lib']

    # Disable aggressive PTX instructions
    if parse_bool_flag(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0'),
                       name='DISABLE_AGGRESSIVE_PTX_INSTRS'):
        # The same macro is passed to both the host compiler and nvcc
        disable_macro = '-DDISABLE_AGGRESSIVE_PTX_INSTRS'
        cxx_flags.append(disable_macro)
        nvcc_flags.append(disable_macro)

    # Disable DLTO (default by PyTorch)
    nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
    # The rpath entry points at the NVSHMEM lib directory used for linking
    extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
    extra_compile_args = {
        'cxx': cxx_flags,
        'nvcc': nvcc_flags,
        'nvcc_dlink': nvcc_dlink
    }

    # Append the git short hash, when available, to the base version
    revision = git_revision_suffix()

    setuptools.setup(
        name='deep_ep',
        version='1.0.0' + revision,
        packages=setuptools.find_packages(
            include=['deep_ep']
        ),
        ext_modules=[
            CUDAExtension(
                name='deep_ep_cpp',
                include_dirs=include_dirs,
                library_dirs=library_dirs,
                sources=sources,
                extra_compile_args=extra_compile_args,
                extra_link_args=extra_link_args
            )
        ],
        cmdclass={
            'build_ext': BuildExtension
        }
    )