From 8712c5e636d9a02e939a9759348d29521a3939a9 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 16 Mar 2022 17:40:21 +0200 Subject: [PATCH] Fix PyTorch aarch64 and windows support --- clearml_agent/helper/package/pytorch.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/clearml_agent/helper/package/pytorch.py b/clearml_agent/helper/package/pytorch.py index 5b2ef15..1340985 100644 --- a/clearml_agent/helper/package/pytorch.py +++ b/clearml_agent/helper/package/pytorch.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import re import sys +import platform from furl import furl import urllib.parse from operator import itemgetter @@ -245,10 +246,15 @@ class PytorchRequirement(SimpleSubstitution): return "macos" raise RuntimeError("unrecognized OS") + @staticmethod + def get_arch(): + return str(platform.machine()).lower() + def _get_link_from_torch_page(self, req, torch_url): links_parser = LinksHTMLParser() links_parser.feed(requests.get(torch_url, timeout=10).text) platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform() + arch_wheel = self.get_arch() py_ver = self.python_major_minor_str.replace('.', '') url = None last_v = None @@ -269,8 +275,11 @@ class PytorchRequirement(SimpleSubstitution): continue if len(parts) < 3 or not parts[2].endswith(py_ver): continue - if len(parts) < 5 or platform_wheel not in parts[4]: + if len(parts) < 5 or platform_wheel not in parts[4].lower(): continue + if len(parts) < 5 or arch_wheel not in parts[4].lower(): + continue + # yes this is for linux python 2.7 support, this is the only python 2.7 we support... if py_ver and py_ver[0] == '2' and len(parts) > 3 and not parts[3].endswith('u'): continue