Allow to pverride pytorch lookup page: "agent.package_manager.torch_page / torch_nightly_page / torch_url_template_prefix"

This commit is contained in:
allegroai 2022-09-15 20:16:41 +03:00
parent a69766bd8b
commit 594ee5842e

View File

@ -53,17 +53,16 @@ class PytorchWheel(object):
python = attr.ib(type=str, converter=lambda x: str(x).replace(".", ""))
torch_version = attr.ib(type=str, converter=fix_version)
url_template = (
"http://download.pytorch.org/whl/"
"{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl"
)
url_template_prefix = "http://download.pytorch.org/whl/"
url_template = "{0.cuda_version}/torch-{0.torch_version}" \
"-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl"
def __attrs_post_init__(self):
self.unicode = "u" if self.python.startswith("2") else ""
def make_url(self):
# type: () -> Text
return self.url_template.format(self)
return (self.url_template_prefix + self.url_template).format(self)
class PytorchResolutionError(FatalSpecsResolutionError):
@ -183,6 +182,19 @@ class PytorchRequirement(SimpleSubstitution):
self._fix_setuptools = None
self.exceptions = []
self._original_req = []
# allow override pytorch lookup pages
if self.config.get("agent.package_manager.torch_page", None):
SimplePytorchRequirement.page_lookup_template = \
self.config.get("agent.package_manager.torch_page", None)
if self.config.get("agent.package_manager.torch_nightly_page", None):
SimplePytorchRequirement.nightly_page_lookup_template = \
self.config.get("agent.package_manager.torch_nightly_page", None)
if self.config.get("agent.package_manager.torch_url_template_prefix", None):
PytorchWheel.url_template_prefix = \
self.config.get("agent.package_manager.torch_url_template_prefix", None)
if self.config.get("agent.package_manager.torch_url_template", None):
PytorchWheel.url_template = \
self.config.get("agent.package_manager.torch_url_template", None)
def _init_python_ver_cuda_ver(self):
if self.cuda is None: