mirror of
https://github.com/clearml/clearml-agent
synced 2025-04-07 22:14:18 +00:00
Add torch_nightly flag support (if a torch wheel is not found in the stable builds, try the nightly builds); improve torch support in freeze (add the actually used HTTP link as a comment next to the original package)
This commit is contained in:
parent
53f511f536
commit
2ad929fa00
@ -55,6 +55,10 @@ agent {
|
||||
|
||||
# additional conda channels to use when installing with conda package manager
|
||||
conda_channels: ["pytorch", "conda-forge", ]
|
||||
|
||||
# set to true to fall back to torch nightly builds when no matching stable wheel is found,
|
||||
# notice: torch nightly builds are ephemeral and are deleted from time to time
|
||||
torch_nightly: false,
|
||||
},
|
||||
|
||||
# target folder for virtual environments builds, created when executing experiment
|
||||
|
@ -39,6 +39,10 @@
|
||||
|
||||
# additional conda channels to use when installing with conda package manager
|
||||
conda_channels: ["defaults", "conda-forge", "pytorch", ]
|
||||
|
||||
# set to true to fall back to torch nightly builds when no matching stable wheel is found,
|
||||
# notice: torch nightly builds are ephemeral and are deleted from time to time
|
||||
torch_nightly: false,
|
||||
},
|
||||
|
||||
# target folder for virtual environments builds, created when executing experiment
|
||||
|
@ -74,6 +74,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
||||
packages = ("torch", "torchvision", "torchaudio")
|
||||
|
||||
page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
|
||||
nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html'
|
||||
torch_page_lookup = {
|
||||
0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
|
||||
80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
|
||||
@ -115,11 +116,23 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
||||
package_manager.add_extra_install_flags(('-f', extra_url))
|
||||
|
||||
@classmethod
|
||||
def get_torch_page(cls, cuda_version):
|
||||
def get_torch_page(cls, cuda_version, nightly=False):
|
||||
try:
|
||||
cuda = int(cuda_version)
|
||||
except:
|
||||
cuda = 0
|
||||
|
||||
if nightly:
|
||||
# then try the nightly builds, it might be there...
|
||||
torch_url = cls.nightly_page_lookup_template.format(cuda)
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
cls.torch_page_lookup[cuda] = torch_url
|
||||
return cls.torch_page_lookup[cuda], cuda
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# first check if key is valid
|
||||
if cuda in cls.torch_page_lookup:
|
||||
return cls.torch_page_lookup[cuda], cuda
|
||||
@ -180,6 +193,8 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
except PytorchResolutionError as e:
|
||||
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
|
||||
|
||||
self._original_req = []
|
||||
|
||||
@property
|
||||
def is_conda(self):
|
||||
return self.package_manager == "conda"
|
||||
@ -242,6 +257,13 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
continue
|
||||
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
|
||||
last_v = v
|
||||
# if we found an exact match, use it
|
||||
try:
|
||||
if req.specs[0][0] == '==' and \
|
||||
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
return url
|
||||
|
||||
@ -273,6 +295,9 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
if not url and self.config.get("agent.package_manager.torch_nightly"):
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
# try one more time, with a lower cuda version (never fallback to CPU):
|
||||
while not url and torch_url_key > 0:
|
||||
previous_cuda_key = torch_url_key
|
||||
@ -363,7 +388,10 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
|
||||
def replace(self, req):
|
||||
try:
|
||||
return self._replace(req)
|
||||
new_req = self._replace(req)
|
||||
if new_req:
|
||||
self._original_req.append((req, new_req))
|
||||
return new_req
|
||||
except Exception as e:
|
||||
message = "Exception when trying to resolve python wheel"
|
||||
self.log.debug(message, exc_info=True)
|
||||
@ -378,13 +406,13 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
result = self._table_lookup(req)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
else:
|
||||
self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
return result
|
||||
# try:
|
||||
# result = self._table_lookup(req)
|
||||
# except Exception as e:
|
||||
# exc = e
|
||||
# else:
|
||||
# self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
# return result
|
||||
|
||||
self.log.debug(
|
||||
"Could not find Pytorch wheel in table, trying manually constructing URL"
|
||||
@ -399,7 +427,7 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
if result:
|
||||
self.log.debug("URL not found: {}".format(result))
|
||||
exc = PytorchResolutionError(
|
||||
"Was not able to find pytorch wheel URL: {}".format(exc)
|
||||
"Could not find pytorch wheel URL for: {}".format(req)
|
||||
)
|
||||
# cancel exception chaining
|
||||
six.raise_from(exc, None)
|
||||
|
Loading…
Reference in New Issue
Block a user