mirror of
https://github.com/clearml/clearml-agent
synced 2025-02-07 13:26:08 +00:00
Update torch version after using system pre-installed version
This commit is contained in:
parent
31a56c71bd
commit
d419fa1e4f
@ -269,26 +269,30 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
|
|
||||||
def get_url_for_platform(self, req):
|
def get_url_for_platform(self, req):
|
||||||
# check if package is already installed with system packages
|
# check if package is already installed with system packages
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if self.config.get("agent.package_manager.system_site_packages", None):
|
if self.config.get("agent.package_manager.system_site_packages", None):
|
||||||
from pip._internal.commands.show import search_packages_info
|
from pip._internal.commands.show import search_packages_info
|
||||||
installed_torch = list(search_packages_info([req.name]))
|
installed_torch = list(search_packages_info([req.name]))
|
||||||
# notice the comparision order, the first part will make sure we have a valid installed package
|
# notice the comparison order, the first part will make sure we have a valid installed package
|
||||||
if installed_torch[0]['version'] and req.compare_version(installed_torch[0]['version']):
|
if installed_torch and installed_torch[0]['version'] and \
|
||||||
|
req.compare_version(installed_torch[0]['version']):
|
||||||
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
|
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
|
||||||
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
|
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
|
||||||
# package already installed, do nothing
|
# package already installed, do nothing
|
||||||
|
req.specs = [('==', str(installed_torch[0]['version']))]
|
||||||
return str(req), True
|
return str(req), True
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# make sure we have a specific version to retrieve
|
# make sure we have a specific version to retrieve
|
||||||
if not req.specs:
|
if not req.specs:
|
||||||
req.specs = [('>', '0')]
|
req.specs = [('>', '0')]
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
|
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
op, version = req.specs[0]
|
op, version = req.specs[0]
|
||||||
# assert op == "=="
|
# assert op == "=="
|
||||||
@ -308,8 +312,8 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
req, previous_cuda_key))
|
req, previous_cuda_key))
|
||||||
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
|
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
|
||||||
req, self.cuda_version))
|
req, self.cuda_version))
|
||||||
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format(
|
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, '
|
||||||
req, previous_cuda_key, torch_url_key))
|
'trying CUDA version {}'.format(req, previous_cuda_key, torch_url_key))
|
||||||
url = self._get_link_from_torch_page(req, torch_url)
|
url = self._get_link_from_torch_page(req, torch_url)
|
||||||
|
|
||||||
if not url:
|
if not url:
|
||||||
|
@ -138,7 +138,8 @@ class MarkerRequirement(object):
|
|||||||
version = self.specs[0][1]
|
version = self.specs[0][1]
|
||||||
op = (op or self.specs[0][0]).strip()
|
op = (op or self.specs[0][0]).strip()
|
||||||
|
|
||||||
return SimpleVersion.compare_versions(requested_version, op, version)
|
return SimpleVersion.compare_versions(
|
||||||
|
version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
|
||||||
|
|
||||||
|
|
||||||
class SimpleVersion:
|
class SimpleVersion:
|
||||||
@ -177,7 +178,7 @@ class SimpleVersion:
|
|||||||
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
|
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
|
||||||
"""
|
"""
|
||||||
Compare two versions based on the op operator
|
Compare two versions based on the op operator
|
||||||
returns bool(version_a op version_b)
|
returns bool(version_a op version_b)
|
||||||
@ -188,12 +189,12 @@ class SimpleVersion:
|
|||||||
:param str version_b:
|
:param str version_b:
|
||||||
:param bool ignore_sub_versions: if true compare only major.minor.patch
|
:param bool ignore_sub_versions: if true compare only major.minor.patch
|
||||||
(ignore a/b/rc/post/dev in the comparison)
|
(ignore a/b/rc/post/dev in the comparison)
|
||||||
|
:param int num_parts: number of parts to compare, split by . (dot)
|
||||||
:return bool: version_a op version_b
|
:return bool: version_a op version_b
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not version_b:
|
if not version_b:
|
||||||
return True
|
return True
|
||||||
num_parts = 3
|
|
||||||
|
|
||||||
if op == '~=':
|
if op == '~=':
|
||||||
num_parts = max(num_parts, 2)
|
num_parts = max(num_parts, 2)
|
||||||
|
Loading…
Reference in New Issue
Block a user