fix: compiler version

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu 2025-04-23 00:06:18 +00:00
parent c14cad0c06
commit 78c7fa347e

View File

@ -176,7 +176,8 @@ class Compiler(abc.ABC):
class NvccCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
major, minor, _ = map(int, cuda.bindings.__version__.split('.'))
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
return (major, minor)
@classmethod
@ -203,8 +204,7 @@ class NvccCompiler(Compiler):
class NvrtcCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return (major, minor)
@staticmethod