Fix torch resolver settings applied to PytorchRequirement instance are not used

This commit is contained in:
allegroai 2024-03-17 18:56:47 +02:00
parent 2de1c926bf
commit f1f9278928

View File

@ -670,8 +670,7 @@ class PytorchRequirement(SimpleSubstitution):
return MarkerRequirement(Requirement.parse(self._fix_setuptools)) return MarkerRequirement(Requirement.parse(self._fix_setuptools))
return None return None
@classmethod def get_torch_index_url(self, cuda_version, nightly=False):
def get_torch_index_url(cls, cuda_version, nightly=False):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
cuda = int(cuda_version) cuda = int(cuda_version)
@ -681,39 +680,39 @@ class PytorchRequirement(SimpleSubstitution):
if nightly: if nightly:
for c in range(cuda, max(-1, cuda-15), -1): for c in range(cuda, max(-1, cuda-15), -1):
# then try the nightly builds, it might be there... # then try the nightly builds, it might be there...
torch_url = cls.nightly_extra_index_url_template.format(c) torch_url = self.nightly_extra_index_url_template.format(c)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if requests.get(torch_url, timeout=10).ok: if requests.get(torch_url, timeout=10).ok:
print('Torch nightly CUDA {} index page found'.format(c)) print('Torch nightly CUDA {} index page found'.format(c))
cls.torch_index_url_lookup[c] = torch_url self.torch_index_url_lookup[c] = torch_url
return cls.torch_index_url_lookup[c], c return self.torch_index_url_lookup[c], c
except Exception: except Exception:
pass pass
return return
# first check if key is valid # first check if key is valid
if cuda in cls.torch_index_url_lookup: if cuda in self.torch_index_url_lookup:
return cls.torch_index_url_lookup[cuda], cuda return self.torch_index_url_lookup[cuda], cuda
# then try a new cuda version page # then try a new cuda version page
for c in range(cuda, max(-1, cuda-15), -1): for c in range(cuda, max(-1, cuda-15), -1):
torch_url = cls.extra_index_url_template.format(c) torch_url = self.extra_index_url_template.format(c)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if requests.get(torch_url, timeout=10).ok: if requests.get(torch_url, timeout=10).ok:
print('Torch CUDA {} index page found, adding `{}`'.format(c, torch_url)) print('Torch CUDA {} index page found, adding `{}`'.format(c, torch_url))
cls.torch_index_url_lookup[c] = torch_url self.torch_index_url_lookup[c] = torch_url
return cls.torch_index_url_lookup[c], c return self.torch_index_url_lookup[c], c
except Exception: except Exception:
pass pass
keys = sorted(cls.torch_index_url_lookup.keys(), reverse=True) keys = sorted(self.torch_index_url_lookup.keys(), reverse=True)
for k in keys: for k in keys:
if k <= cuda: if k <= cuda:
return cls.torch_index_url_lookup[k], k return self.torch_index_url_lookup[k], k
# return default - zero # return default - zero
return cls.torch_index_url_lookup[0], 0 return self.torch_index_url_lookup[0], 0
MAP = { MAP = {
"windows": { "windows": {