diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index f58997f..6b5096e 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -269,26 +269,30 @@ class PytorchRequirement(SimpleSubstitution): def get_url_for_platform(self, req): # check if package is already installed with system packages + # noinspection PyBroadException try: if self.config.get("agent.package_manager.system_site_packages", None): from pip._internal.commands.show import search_packages_info installed_torch = list(search_packages_info([req.name])) - # notice the comparision order, the first part will make sure we have a valid installed package - if installed_torch[0]['version'] and req.compare_version(installed_torch[0]['version']): + # notice the comparison order, the first part will make sure we have a valid installed package + if installed_torch and installed_torch[0]['version'] and \ + req.compare_version(installed_torch[0]['version']): print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format( req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version'])) # package already installed, do nothing + req.specs = [('==', str(installed_torch[0]['version']))] return str(req), True - except: + except Exception: pass # make sure we have a specific version to retrieve if not req.specs: req.specs = [('>', '0')] + # noinspection PyBroadException try: req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0]) - except: + except Exception: pass op, version = req.specs[0] # assert op == "==" @@ -308,8 +312,8 @@ class PytorchRequirement(SimpleSubstitution): req, previous_cuda_key)) raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format( req, self.cuda_version)) - print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format( - req, previous_cuda_key, torch_url_key)) + print('Warning! Could not locate PyTorch version {} matching CUDA version {}, ' + 'trying CUDA version {}'.format(req, previous_cuda_key, torch_url_key)) url = self._get_link_from_torch_page(req, torch_url) if not url: diff --git a/trains_agent/helper/package/requirements.py b/trains_agent/helper/package/requirements.py index 187b379..76bf637 100644 --- a/trains_agent/helper/package/requirements.py +++ b/trains_agent/helper/package/requirements.py @@ -138,7 +138,8 @@ class MarkerRequirement(object): version = self.specs[0][1] op = (op or self.specs[0][0]).strip() - return SimpleVersion.compare_versions(requested_version, op, version) + return SimpleVersion.compare_versions( + version_a=requested_version, op=op, version_b=version, num_parts=num_parts) class SimpleVersion: @@ -177,7 +178,7 @@ class SimpleVersion: _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE) @classmethod - def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True): + def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3): """ Compare two versions based on the op operator returns bool(version_a op version_b) @@ -188,12 +189,12 @@ class SimpleVersion: :param str version_b: :param bool ignore_sub_versions: if true compare only major.minor.patch (ignore a/b/rc/post/dev in the comparison) + :param int num_parts: number of parts to compare, split by . (dot) :return bool: version_a op version_b """ if not version_b: return True - num_parts = 3 if op == '~=': num_parts = max(num_parts, 2)