Update torch version after using system pre-installed version

This commit is contained in:
allegroai 2020-10-04 19:29:47 +03:00
parent 31a56c71bd
commit d419fa1e4f
2 changed files with 14 additions and 9 deletions

View File

@ -269,26 +269,30 @@ class PytorchRequirement(SimpleSubstitution):
def get_url_for_platform(self, req):
# check if package is already installed with system packages
# noinspection PyBroadException
try:
if self.config.get("agent.package_manager.system_site_packages", None):
from pip._internal.commands.show import search_packages_info
installed_torch = list(search_packages_info([req.name]))
# notice the comparision order, the first part will make sure we have a valid installed package
if installed_torch[0]['version'] and req.compare_version(installed_torch[0]['version']):
# notice the comparison order, the first part will make sure we have a valid installed package
if installed_torch and installed_torch[0]['version'] and \
req.compare_version(installed_torch[0]['version']):
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
# package already installed, do nothing
req.specs = [('==', str(installed_torch[0]['version']))]
return str(req), True
except:
except Exception:
pass
# make sure we have a specific version to retrieve
if not req.specs:
req.specs = [('>', '0')]
# noinspection PyBroadException
try:
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
except:
except Exception:
pass
op, version = req.specs[0]
# assert op == "=="
@ -308,8 +312,8 @@ class PytorchRequirement(SimpleSubstitution):
req, previous_cuda_key))
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
req, self.cuda_version))
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format(
req, previous_cuda_key, torch_url_key))
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, '
'trying CUDA version {}'.format(req, previous_cuda_key, torch_url_key))
url = self._get_link_from_torch_page(req, torch_url)
if not url:

View File

@ -138,7 +138,8 @@ class MarkerRequirement(object):
version = self.specs[0][1]
op = (op or self.specs[0][0]).strip()
return SimpleVersion.compare_versions(requested_version, op, version)
return SimpleVersion.compare_versions(
version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
class SimpleVersion:
@ -177,7 +178,7 @@ class SimpleVersion:
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
@classmethod
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
"""
Compare two versions based on the op operator
returns bool(version_a op version_b)
@ -188,12 +189,12 @@ class SimpleVersion:
:param str version_b:
:param bool ignore_sub_versions: if true compare only major.minor.patch
(ignore a/b/rc/post/dev in the comparison)
:param int num_parts: number of parts to compare, split by . (dot)
:return bool: version_a op version_b
"""
if not version_b:
return True
num_parts = 3
if op == '~=':
num_parts = max(num_parts, 2)