From 594ee5842ecd5f30760969ba30cf1dbf497a0036 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 15 Sep 2022 20:16:41 +0300 Subject: [PATCH] Allow to pverride pytorch lookup page: "agent.package_manager.torch_page / torch_nightly_page / torch_url_template_prefix" --- clearml_agent/helper/package/pytorch.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/clearml_agent/helper/package/pytorch.py b/clearml_agent/helper/package/pytorch.py index 6dbfb32..096746b 100644 --- a/clearml_agent/helper/package/pytorch.py +++ b/clearml_agent/helper/package/pytorch.py @@ -53,17 +53,16 @@ class PytorchWheel(object): python = attr.ib(type=str, converter=lambda x: str(x).replace(".", "")) torch_version = attr.ib(type=str, converter=fix_version) - url_template = ( - "http://download.pytorch.org/whl/" - "{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl" - ) + url_template_prefix = "http://download.pytorch.org/whl/" + url_template = "{0.cuda_version}/torch-{0.torch_version}" \ + "-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl" def __attrs_post_init__(self): self.unicode = "u" if self.python.startswith("2") else "" def make_url(self): # type: () -> Text - return self.url_template.format(self) + return (self.url_template_prefix + self.url_template).format(self) class PytorchResolutionError(FatalSpecsResolutionError): @@ -183,6 +182,19 @@ class PytorchRequirement(SimpleSubstitution): self._fix_setuptools = None self.exceptions = [] self._original_req = [] + # allow override pytorch lookup pages + if self.config.get("agent.package_manager.torch_page", None): + SimplePytorchRequirement.page_lookup_template = \ + self.config.get("agent.package_manager.torch_page", None) + if self.config.get("agent.package_manager.torch_nightly_page", None): + SimplePytorchRequirement.nightly_page_lookup_template = \ + self.config.get("agent.package_manager.torch_nightly_page", None) + if self.config.get("agent.package_manager.torch_url_template_prefix", None): + PytorchWheel.url_template_prefix = \ + self.config.get("agent.package_manager.torch_url_template_prefix", None) + if self.config.get("agent.package_manager.torch_url_template", None): + PytorchWheel.url_template = \ + self.config.get("agent.package_manager.torch_url_template", None) def _init_python_ver_cuda_ver(self): if self.cuda is None: