Upgrade pynvml add detect CUDA version from driver level

This commit is contained in:
allegroai 2021-02-17 00:03:16 +02:00
parent 22d5892b12
commit 58cb344ee6
3 changed files with 1938 additions and 44 deletions

View File

@ -20,6 +20,7 @@ import platform
import sys
import time
from datetime import datetime
from typing import Optional
import psutil
from ..gpu import pynvml as N
@ -390,3 +391,34 @@ def new_query(shutdown=False, per_process_stats=False, get_driver_info=False):
'''
return GPUStatCollection.new_query(shutdown=shutdown, per_process_stats=per_process_stats,
get_driver_info=get_driver_info)
def get_driver_cuda_version():
# type: () -> Optional[str]
"""
:return: Return detected CUDA version from driver. On fail return value is None.
Example: `110` is cuda version 11.0
"""
# noinspection PyBroadException
try:
N.nvmlInit()
except BaseException:
return None
# noinspection PyBroadException
try:
cuda_version = str(N.nvmlSystemGetCudaDriverVersion())
except BaseException:
# noinspection PyBroadException
try:
cuda_version = str(N.nvmlSystemGetCudaDriverVersion_v2())
except BaseException:
cuda_version = ''
# noinspection PyBroadException
try:
N.nvmlShutdown()
except BaseException:
return None
return cuda_version[:3] if cuda_version else None

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ import six
from clearml_agent.definitions import PIP_EXTRA_INDICES
from clearml_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
from clearml_agent.helper.process import Argv, PathLike
from clearml_agent.helper.gpu.gpustat import get_driver_cuda_version
from clearml_agent.session import Session, normalize_cuda_version
from clearml_agent.external.requirements_parser import parse
from clearml_agent.external.requirements_parser.requirement import Requirement
@ -537,6 +538,9 @@ class RequirementsManager(object):
if cuda_version and cudnn_version:
return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
if not cuda_version:
cuda_version = get_driver_cuda_version()
if not cuda_version and is_windows_platform():
try:
cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()