Add torch_nightly flag support (if torch wheel is not found on stable try the nightly builds), improve support for torch in freeze (add actually used HTTP link as comment to the original package)

This commit is contained in:
allegroai 2020-05-09 20:08:05 +03:00
parent 53f511f536
commit 2ad929fa00
3 changed files with 46 additions and 10 deletions

View File

@ -55,6 +55,10 @@ agent {
# additional conda channels to use when installing with conda package manager
conda_channels: ["pytorch", "conda-forge", ]
# set to True to support torch nightly build installation,
# notice: torch nightly builds are ephemeral and are deleted from time to time
torch_nightly: false,
},
# target folder for virtual environments builds, created when executing experiment

View File

@ -39,6 +39,10 @@
# additional conda channels to use when installing with conda package manager
conda_channels: ["defaults", "conda-forge", "pytorch", ]
# set to True to support torch nightly build installation,
# notice: torch nightly builds are ephemeral and are deleted from time to time
torch_nightly: false,
},
# target folder for virtual environments builds, created when executing experiment

View File

@ -74,6 +74,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
packages = ("torch", "torchvision", "torchaudio")
page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html'
torch_page_lookup = {
0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
@ -115,11 +116,23 @@ class SimplePytorchRequirement(SimpleSubstitution):
package_manager.add_extra_install_flags(('-f', extra_url))
@classmethod
def get_torch_page(cls, cuda_version):
def get_torch_page(cls, cuda_version, nightly=False):
try:
cuda = int(cuda_version)
except:
cuda = 0
if nightly:
# then try the nightly builds, it might be there...
torch_url = cls.nightly_page_lookup_template.format(cuda)
try:
if requests.get(torch_url, timeout=10).ok:
cls.torch_page_lookup[cuda] = torch_url
return cls.torch_page_lookup[cuda], cuda
except Exception:
pass
return
# first check if key is valid
if cuda in cls.torch_page_lookup:
return cls.torch_page_lookup[cuda], cuda
@ -180,6 +193,8 @@ class PytorchRequirement(SimpleSubstitution):
except PytorchResolutionError as e:
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
self._original_req = []
@property
def is_conda(self):
return self.package_manager == "conda"
@ -242,6 +257,13 @@ class PytorchRequirement(SimpleSubstitution):
continue
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
last_v = v
# if we found an exact match, use it
try:
if req.specs[0][0] == '==' and \
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
break
except:
pass
return url
@ -273,6 +295,9 @@ class PytorchRequirement(SimpleSubstitution):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url)
if not url and self.config.get("agent.package_manager.torch_nightly"):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
url = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version (never fallback to CPU):
while not url and torch_url_key > 0:
previous_cuda_key = torch_url_key
@ -363,7 +388,10 @@ class PytorchRequirement(SimpleSubstitution):
def replace(self, req):
try:
return self._replace(req)
new_req = self._replace(req)
if new_req:
self._original_req.append((req, new_req))
return new_req
except Exception as e:
message = "Exception when trying to resolve python wheel"
self.log.debug(message, exc_info=True)
@ -378,13 +406,13 @@ class PytorchRequirement(SimpleSubstitution):
except:
pass
try:
result = self._table_lookup(req)
except Exception as e:
exc = e
else:
self.log.debug('Replacing requirement "%s" with %r', req, result)
return result
# try:
# result = self._table_lookup(req)
# except Exception as e:
# exc = e
# else:
# self.log.debug('Replacing requirement "%s" with %r', req, result)
# return result
self.log.debug(
"Could not find Pytorch wheel in table, trying manually constructing URL"
@ -399,7 +427,7 @@ class PytorchRequirement(SimpleSubstitution):
if result:
self.log.debug("URL not found: {}".format(result))
exc = PytorchResolutionError(
"Was not able to find pytorch wheel URL: {}".format(exc)
"Could not find pytorch wheel URL for: {}".format(req)
)
# cancel exception chaining
six.raise_from(exc, None)