diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 10b682f..64934f9 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -239,7 +239,23 @@ class PytorchRequirement(SimpleSubstitution): def get_url_for_platform(self, req): assert self.package_manager == "pip" assert self.os != "mac" + + # check if package is already installed with system packages + try: + if self.config.get("agent.package_manager.system_site_packages"): + from pip._internal.commands.show import search_packages_info + installed_torch = list(search_packages_info([req.name])) + op, version = req.specs[0] if req.specs else (None, None) + # notice the comparision order, the first part will make sure we have a valid installed package + if installed_torch[0]['version'] and (installed_torch[0]['version'] == version or not version): + # package already installed, do nothing + return str(req), True + except: + pass + + # make sure we have a specific version to retrieve assert req.specs + try: req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0]) except: