diff --git a/.gitignore b/.gitignore index d8d5133..3e6e4e5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ build dist *.egg-info *.pyc + # Third-party links created by `setup.py develop` deep_gemm/include/cute deep_gemm/include/cutlass diff --git a/setup.py b/setup.py index 0c947b5..fec5036 100644 --- a/setup.py +++ b/setup.py @@ -6,10 +6,10 @@ from setuptools.command.build_py import build_py from setuptools.command.develop import develop current_dir = os.path.dirname(os.path.realpath(__file__)) -jit_include_dirs = ("deep_gemm/include/deep_gemm",) +jit_include_dirs = ('deep_gemm/include/deep_gemm', ) third_party_include_dirs = ( - "third-party/cutlass/include/cute", - "third-party/cutlass/include/cutlass", + 'third-party/cutlass/include/cute', + 'third-party/cutlass/include/cutlass', ) @@ -22,9 +22,9 @@ class PostDevelopCommand(develop): def make_jit_include_symlinks(): # Make symbolic links of third-party include directories for d in third_party_include_dirs: - dirname = d.split("/")[-1] - src_dir = f"{current_dir}/{d}" - dst_dir = f"{current_dir}/deep_gemm/include/{dirname}" + dirname = d.split('/')[-1] + src_dir = f'{current_dir}/{d}' + dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' assert os.path.exists(src_dir) if os.path.exists(dst_dir): assert os.path.islink(dst_dir) @@ -36,17 +36,18 @@ class CustomBuildPy(build_py): def run(self): # First, prepare the include directories self.prepare_includes() + # Then run the regular build build_py.run(self) def prepare_includes(self): # Create temporary build directory instead of modifying package directory - build_include_dir = os.path.join(self.build_lib, "deep_gemm/include") + build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') os.makedirs(build_include_dir, exist_ok=True) # Copy third-party includes to the build directory for d in third_party_include_dirs: - dirname = d.split("/")[-1] + dirname = d.split('/')[-1] src_dir = os.path.join(current_dir, d) dst_dir = os.path.join(build_include_dir, dirname) @@ -58,27 +59,27 @@ class CustomBuildPy(build_py): shutil.copytree(src_dir, dst_dir) -if __name__ == "__main__": +if __name__ == '__main__': # noinspection PyBroadException try: - cmd = ["git", "rev-parse", "--short", "HEAD"] - revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip() + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() except: - revision = "" + revision = '' setuptools.setup( - name="deep_gemm", - version="1.0.0" + revision, - packages=["deep_gemm", "deep_gemm/jit", "deep_gemm/jit_kernels"], + name='deep_gemm', + version='1.0.0' + revision, + packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], package_data={ - "deep_gemm": [ - "include/deep_gemm/**/*", - "include/cute/**/*", - "include/cutlass/**/*", + 'deep_gemm': [ + 'include/deep_gemm/**/*', + 'include/cute/**/*', + 'include/cutlass/**/*', ] }, cmdclass={ - "develop": PostDevelopCommand, - "build_py": CustomBuildPy, + 'develop': PostDevelopCommand, + 'build_py': CustomBuildPy, }, )