diff --git a/clearml_agent/helper/package/pytorch.py b/clearml_agent/helper/package/pytorch.py index 5b2ef15..1340985 100644 --- a/clearml_agent/helper/package/pytorch.py +++ b/clearml_agent/helper/package/pytorch.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import re import sys +import platform from furl import furl import urllib.parse from operator import itemgetter @@ -245,10 +246,15 @@ class PytorchRequirement(SimpleSubstitution): return "macos" raise RuntimeError("unrecognized OS") + @staticmethod + def get_arch(): + return str(platform.machine()).lower() + def _get_link_from_torch_page(self, req, torch_url): links_parser = LinksHTMLParser() links_parser.feed(requests.get(torch_url, timeout=10).text) platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform() + arch_wheel = self.get_arch() py_ver = self.python_major_minor_str.replace('.', '') url = None last_v = None @@ -269,8 +275,11 @@ class PytorchRequirement(SimpleSubstitution): continue if len(parts) < 3 or not parts[2].endswith(py_ver): continue - if len(parts) < 5 or platform_wheel not in parts[4]: + if len(parts) < 5 or platform_wheel not in parts[4].lower(): continue + if len(parts) < 5 or arch_wheel not in parts[4].lower(): + continue + # yes this is for linux python 2.7 support, this is the only python 2.7 we support... if py_ver and py_ver[0] == '2' and len(parts) > 3 and not parts[3].endswith('u'): continue