From 2ad929fa00932272829884fe2ab3357974a8cace Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 9 May 2020 20:08:05 +0300 Subject: [PATCH] Add torch_nightly flag support (if torch wheel is not found on stable try the nightly builds), improve support for torch in freeze (add actually used HTTP link as comment to the original package) --- docs/trains.conf | 4 ++ .../backend_api/config/default/agent.conf | 4 ++ trains_agent/helper/package/pytorch.py | 48 +++++++++++++++---- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/docs/trains.conf b/docs/trains.conf index f8df36a..56f68c2 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -55,6 +55,10 @@ agent { # additional conda channels to use when installing with conda package manager conda_channels: ["pytorch", "conda-forge", ] + + # set to True to support torch nightly build installation, + # notice: torch nightly builds are ephemeral and are deleted from time to time + torch_nightly: false, }, # target folder for virtual environments builds, created when executing experiment diff --git a/trains_agent/backend_api/config/default/agent.conf b/trains_agent/backend_api/config/default/agent.conf index 86792c2..e22da79 100644 --- a/trains_agent/backend_api/config/default/agent.conf +++ b/trains_agent/backend_api/config/default/agent.conf @@ -39,6 +39,10 @@ # additional conda channels to use when installing with conda package manager conda_channels: ["defaults", "conda-forge", "pytorch", ] + + # set to True to support torch nightly build installation, + # notice: torch nightly builds are ephemeral and are deleted from time to time + torch_nightly: false, }, # target folder for virtual environments builds, created when executing experiment diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 8b13c57..f492cdc 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -74,6 +74,7 @@ class SimplePytorchRequirement(SimpleSubstitution): packages = ("torch", "torchvision", "torchaudio") page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html' + nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html' torch_page_lookup = { 0: 'https://download.pytorch.org/whl/cpu/torch_stable.html', 80: 'https://download.pytorch.org/whl/cu80/torch_stable.html', @@ -115,11 +116,23 @@ class SimplePytorchRequirement(SimpleSubstitution): package_manager.add_extra_install_flags(('-f', extra_url)) @classmethod - def get_torch_page(cls, cuda_version): + def get_torch_page(cls, cuda_version, nightly=False): try: cuda = int(cuda_version) except: cuda = 0 + + if nightly: + # then try the nightly builds, it might be there... + torch_url = cls.nightly_page_lookup_template.format(cuda) + try: + if requests.get(torch_url, timeout=10).ok: + cls.torch_page_lookup[cuda] = torch_url + return cls.torch_page_lookup[cuda], cuda + except Exception: + pass + return + # first check if key is valid if cuda in cls.torch_page_lookup: return cls.torch_page_lookup[cuda], cuda @@ -180,6 +193,8 @@ class PytorchRequirement(SimpleSubstitution): except PytorchResolutionError as e: self.log.warn("will not be able to install pytorch wheels: %s", e.args[0]) + self._original_req = [] + @property def is_conda(self): return self.package_manager == "conda" @@ -242,6 +257,13 @@ class PytorchRequirement(SimpleSubstitution): continue url = '/'.join(torch_url.split('/')[:-1] + l.split('/')) last_v = v + # if we found an exact match, use it + try: + if req.specs[0][0] == '==' and \ + SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False): + break + except: + pass return url @@ -273,6 +295,9 @@ class PytorchRequirement(SimpleSubstitution): torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) url = self._get_link_from_torch_page(req, torch_url) + if not url and self.config.get("agent.package_manager.torch_nightly"): + torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True) + url = self._get_link_from_torch_page(req, torch_url) # try one more time, with a lower cuda version (never fallback to CPU): while not url and torch_url_key > 0: previous_cuda_key = torch_url_key @@ -363,7 +388,10 @@ class PytorchRequirement(SimpleSubstitution): def replace(self, req): try: - return self._replace(req) + new_req = self._replace(req) + if new_req: + self._original_req.append((req, new_req)) + return new_req except Exception as e: message = "Exception when trying to resolve python wheel" self.log.debug(message, exc_info=True) @@ -378,13 +406,13 @@ class PytorchRequirement(SimpleSubstitution): except: pass - try: - result = self._table_lookup(req) - except Exception as e: - exc = e - else: - self.log.debug('Replacing requirement "%s" with %r', req, result) - return result + # try: + # result = self._table_lookup(req) + # except Exception as e: + # exc = e + # else: + # self.log.debug('Replacing requirement "%s" with %r', req, result) + # return result self.log.debug( "Could not find Pytorch wheel in table, trying manually constructing URL" @@ -399,7 +427,7 @@ class PytorchRequirement(SimpleSubstitution): if result: self.log.debug("URL not found: {}".format(result)) exc = PytorchResolutionError( - "Was not able to find pytorch wheel URL: {}".format(exc) + "Could not find pytorch wheel URL for: {}".format(req) ) # cancel exception chaining six.raise_from(exc, None)