mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Add torchcsprng and torchtext to PyTorch resolving. Improve debug prints on auto cuda version resolving.
This commit is contained in:
parent
448e23825c
commit
1afa3a3914
@ -166,7 +166,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
|||||||
class PytorchRequirement(SimpleSubstitution):
|
class PytorchRequirement(SimpleSubstitution):
|
||||||
|
|
||||||
name = "torch"
|
name = "torch"
|
||||||
packages = ("torch", "torchvision", "torchaudio")
|
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os_name = kwargs.pop("os_override", None)
|
os_name = kwargs.pop("os_override", None)
|
||||||
@ -244,6 +244,7 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
py_ver = self.python_major_minor_str.replace('.', '')
|
py_ver = self.python_major_minor_str.replace('.', '')
|
||||||
url = None
|
url = None
|
||||||
last_v = None
|
last_v = None
|
||||||
|
closest_v = None
|
||||||
# search for our package
|
# search for our package
|
||||||
for l in links_parser.links:
|
for l in links_parser.links:
|
||||||
parts = l.split('/')[-1].split('-')
|
parts = l.split('/')[-1].split('-')
|
||||||
@ -253,28 +254,40 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
continue
|
continue
|
||||||
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
||||||
# version ignore .postX suffix (treat as regular version)
|
# version ignore .postX suffix (treat as regular version)
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
v = str(parts[1].split('%')[0].split('+')[0])
|
v = str(parts[1].split('%')[0].split('+')[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
if len(parts) < 3 or not parts[2].endswith(py_ver):
|
||||||
|
continue
|
||||||
|
if len(parts) < 5 or platform_wheel not in parts[4]:
|
||||||
|
continue
|
||||||
|
# update the closest matched version (from above)
|
||||||
|
if not closest_v:
|
||||||
|
closest_v = v
|
||||||
|
elif SimpleVersion.compare_versions(
|
||||||
|
version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
|
||||||
|
SimpleVersion.compare_versions(
|
||||||
|
version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
|
||||||
|
closest_v = v
|
||||||
|
# check if this an actual match
|
||||||
if not req.compare_version(v) or \
|
if not req.compare_version(v) or \
|
||||||
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
|
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
|
||||||
continue
|
continue
|
||||||
if not parts[2].endswith(py_ver):
|
|
||||||
continue
|
|
||||||
if platform_wheel not in parts[4]:
|
|
||||||
continue
|
|
||||||
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
|
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
|
||||||
last_v = v
|
last_v = v
|
||||||
# if we found an exact match, use it
|
# if we found an exact match, use it
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if req.specs[0][0] == '==' and \
|
if req.specs[0][0] == '==' and \
|
||||||
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
|
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
|
||||||
break
|
break
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return url
|
return url, last_v or closest_v
|
||||||
|
|
||||||
def get_url_for_platform(self, req):
|
def get_url_for_platform(self, req):
|
||||||
# check if package is already installed with system packages
|
# check if package is already installed with system packages
|
||||||
@ -307,23 +320,28 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
# assert op == "=="
|
# assert op == "=="
|
||||||
|
|
||||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
|
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
|
||||||
url = self._get_link_from_torch_page(req, torch_url)
|
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||||
if not url and self.config.get("agent.package_manager.torch_nightly", None):
|
if not url and self.config.get("agent.package_manager.torch_nightly", None):
|
||||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
|
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
|
||||||
url = self._get_link_from_torch_page(req, torch_url)
|
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||||
# try one more time, with a lower cuda version (never fallback to CPU):
|
# try one more time, with a lower cuda version (never fallback to CPU):
|
||||||
while not url and torch_url_key > 0:
|
while not url and torch_url_key > 0:
|
||||||
previous_cuda_key = torch_url_key
|
previous_cuda_key = torch_url_key
|
||||||
|
print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
|
||||||
|
req, previous_cuda_key, closest_matched_version))
|
||||||
|
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||||
|
if url:
|
||||||
|
break
|
||||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
|
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
|
||||||
# never fallback to CPU
|
# never fallback to CPU
|
||||||
if torch_url_key < 1:
|
if torch_url_key < 1:
|
||||||
print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format(
|
print(
|
||||||
|
'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
|
||||||
req, previous_cuda_key))
|
req, previous_cuda_key))
|
||||||
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
|
raise ValueError(
|
||||||
req, self.cuda_version))
|
'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
|
||||||
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, '
|
else:
|
||||||
'trying CUDA version {}'.format(req, previous_cuda_key, torch_url_key))
|
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
|
||||||
url = self._get_link_from_torch_page(req, torch_url)
|
|
||||||
|
|
||||||
if not url:
|
if not url:
|
||||||
url = PytorchWheel(
|
url = PytorchWheel(
|
||||||
@ -335,6 +353,8 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
if url:
|
if url:
|
||||||
# normalize url (sometimes we will get ../ which we should not...
|
# normalize url (sometimes we will get ../ which we should not...
|
||||||
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
|
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
|
||||||
|
# print found
|
||||||
|
print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))
|
||||||
|
|
||||||
self.log.debug("checking url: %s", url)
|
self.log.debug("checking url: %s", url)
|
||||||
return url, requests.head(url, timeout=10).ok
|
return url, requests.head(url, timeout=10).ok
|
||||||
|
Loading…
Reference in New Issue
Block a user