Add torch_nightly flag support (if torch wheel is not found on stable try the nightly builds), improve support for torch in freeze (add actually used HTTP link as comment to the original package)

This commit is contained in:
allegroai 2020-05-09 20:08:05 +03:00
parent 53f511f536
commit 2ad929fa00
3 changed files with 46 additions and 10 deletions

View File

@ -55,6 +55,10 @@ agent {
# additional conda channels to use when installing with conda package manager # additional conda channels to use when installing with conda package manager
conda_channels: ["pytorch", "conda-forge", ] conda_channels: ["pytorch", "conda-forge", ]
# set to True to support torch nightly build installation,
# notice: torch nightly builds are ephemeral and are deleted from time to time
torch_nightly: false,
}, },
# target folder for virtual environments builds, created when executing experiment # target folder for virtual environments builds, created when executing experiment

View File

@ -39,6 +39,10 @@
# additional conda channels to use when installing with conda package manager # additional conda channels to use when installing with conda package manager
conda_channels: ["defaults", "conda-forge", "pytorch", ] conda_channels: ["defaults", "conda-forge", "pytorch", ]
# set to True to support torch nightly build installation,
# notice: torch nightly builds are ephemeral and are deleted from time to time
torch_nightly: false,
}, },
# target folder for virtual environments builds, created when executing experiment # target folder for virtual environments builds, created when executing experiment

View File

@ -74,6 +74,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
packages = ("torch", "torchvision", "torchaudio") packages = ("torch", "torchvision", "torchaudio")
page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html' page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html'
torch_page_lookup = { torch_page_lookup = {
0: 'https://download.pytorch.org/whl/cpu/torch_stable.html', 0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
80: 'https://download.pytorch.org/whl/cu80/torch_stable.html', 80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
@ -115,11 +116,23 @@ class SimplePytorchRequirement(SimpleSubstitution):
package_manager.add_extra_install_flags(('-f', extra_url)) package_manager.add_extra_install_flags(('-f', extra_url))
@classmethod @classmethod
def get_torch_page(cls, cuda_version): def get_torch_page(cls, cuda_version, nightly=False):
try: try:
cuda = int(cuda_version) cuda = int(cuda_version)
except: except:
cuda = 0 cuda = 0
if nightly:
# then try the nightly builds, it might be there...
torch_url = cls.nightly_page_lookup_template.format(cuda)
try:
if requests.get(torch_url, timeout=10).ok:
cls.torch_page_lookup[cuda] = torch_url
return cls.torch_page_lookup[cuda], cuda
except Exception:
pass
return
# first check if key is valid # first check if key is valid
if cuda in cls.torch_page_lookup: if cuda in cls.torch_page_lookup:
return cls.torch_page_lookup[cuda], cuda return cls.torch_page_lookup[cuda], cuda
@ -180,6 +193,8 @@ class PytorchRequirement(SimpleSubstitution):
except PytorchResolutionError as e: except PytorchResolutionError as e:
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0]) self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
self._original_req = []
@property @property
def is_conda(self): def is_conda(self):
return self.package_manager == "conda" return self.package_manager == "conda"
@ -242,6 +257,13 @@ class PytorchRequirement(SimpleSubstitution):
continue continue
url = '/'.join(torch_url.split('/')[:-1] + l.split('/')) url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
last_v = v last_v = v
# if we found an exact match, use it
try:
if req.specs[0][0] == '==' and \
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
break
except:
pass
return url return url
@ -273,6 +295,9 @@ class PytorchRequirement(SimpleSubstitution):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url) url = self._get_link_from_torch_page(req, torch_url)
if not url and self.config.get("agent.package_manager.torch_nightly"):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
url = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version (never fallback to CPU): # try one more time, with a lower cuda version (never fallback to CPU):
while not url and torch_url_key > 0: while not url and torch_url_key > 0:
previous_cuda_key = torch_url_key previous_cuda_key = torch_url_key
@ -363,7 +388,10 @@ class PytorchRequirement(SimpleSubstitution):
def replace(self, req): def replace(self, req):
try: try:
return self._replace(req) new_req = self._replace(req)
if new_req:
self._original_req.append((req, new_req))
return new_req
except Exception as e: except Exception as e:
message = "Exception when trying to resolve python wheel" message = "Exception when trying to resolve python wheel"
self.log.debug(message, exc_info=True) self.log.debug(message, exc_info=True)
@ -378,13 +406,13 @@ class PytorchRequirement(SimpleSubstitution):
except: except:
pass pass
try: # try:
result = self._table_lookup(req) # result = self._table_lookup(req)
except Exception as e: # except Exception as e:
exc = e # exc = e
else: # else:
self.log.debug('Replacing requirement "%s" with %r', req, result) # self.log.debug('Replacing requirement "%s" with %r', req, result)
return result # return result
self.log.debug( self.log.debug(
"Could not find Pytorch wheel in table, trying manually constructing URL" "Could not find Pytorch wheel in table, trying manually constructing URL"
@ -399,7 +427,7 @@ class PytorchRequirement(SimpleSubstitution):
if result: if result:
self.log.debug("URL not found: {}".format(result)) self.log.debug("URL not found: {}".format(result))
exc = PytorchResolutionError( exc = PytorchResolutionError(
"Was not able to find pytorch wheel URL: {}".format(exc) "Could not find pytorch wheel URL for: {}".format(req)
) )
# cancel exception chaining # cancel exception chaining
six.raise_from(exc, None) six.raise_from(exc, None)