from __future__ import unicode_literals import re import sys from furl import furl import urllib.parse from operator import itemgetter from html.parser import HTMLParser from typing import Text import attr import requests from semantic_version import Version, Spec import six from .requirements import SimpleSubstitution, FatalSpecsResolutionError OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"} def os_to_wheel_name(x): return OS_TO_WHEEL_NAME[x] def fix_version(version): def replace(nums, prerelease): if prerelease: return "{}-{}".format(nums, prerelease) return nums return re.sub( r"(\d+(?:\.\d+){,2})(?:\.(.*))?", lambda match: replace(*match.groups()), version, ) class LinksHTMLParser(HTMLParser): def __init__(self): super(LinksHTMLParser, self).__init__() self.links = [] def handle_data(self, data): if data and data.strip(): self.links += [data] @attr.s class PytorchWheel(object): os_name = attr.ib(type=str, converter=os_to_wheel_name) cuda_version = attr.ib(converter=lambda x: "cu{}".format(x) if x else "cpu") python = attr.ib(type=str, converter=lambda x: str(x).replace(".", "")) torch_version = attr.ib(type=str, converter=fix_version) url_template = ( "http://download.pytorch.org/whl/" "{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl" ) def __attrs_post_init__(self): self.unicode = "u" if self.python.startswith("2") else "" def make_url(self): # type: () -> Text return self.url_template.format(self) class PytorchResolutionError(FatalSpecsResolutionError): pass class SimplePytorchRequirement(SimpleSubstitution): name = "torch" packages = ("torch", "torchvision", "torchaudio") page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html' torch_page_lookup = { 0: 'https://download.pytorch.org/whl/cpu/torch_stable.html', 80: 'https://download.pytorch.org/whl/cu80/torch_stable.html', 90: 'https://download.pytorch.org/whl/cu90/torch_stable.html', 92: 'https://download.pytorch.org/whl/cu92/torch_stable.html', 100: 'https://download.pytorch.org/whl/cu100/torch_stable.html', 101: 'https://download.pytorch.org/whl/cu101/torch_stable.html', } def __init__(self, *args, **kwargs): super(SimplePytorchRequirement, self).__init__(*args, **kwargs) self._matched = False def match(self, req): # match both any of out packages return req.name in self.packages def replace(self, req): """ Replace a requirement :raises: ValueError if version is pre-release """ # Get rid of +cpu +cu?? etc. try: req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0]) except: pass self._matched = True return Text(req) def matching_done(self, reqs, package_manager): # type: (Sequence[MarkerRequirement], object) -> () if not self._matched: return # TODO: add conda channel support from .pip_api.system import SystemPip if package_manager and isinstance(package_manager, SystemPip): extra_url, _ = self.get_torch_page(self.cuda_version) package_manager.add_extra_install_flags(('-f', extra_url)) @classmethod def get_torch_page(cls, cuda_version): try: cuda = int(cuda_version) except: cuda = 0 # first check if key is valid if cuda in cls.torch_page_lookup: return cls.torch_page_lookup[cuda], cuda # then try a new cuda version page torch_url = cls.page_lookup_template.format(cuda) try: if requests.get(torch_url, timeout=10).ok: cls.torch_page_lookup[cuda] = torch_url return cls.torch_page_lookup[cuda], cuda except Exception: pass keys = sorted(cls.torch_page_lookup.keys(), reverse=True) for k in keys: if k <= cuda: return cls.torch_page_lookup[k], k # return default - zero return cls.torch_page_lookup[0], 0 class PytorchRequirement(SimpleSubstitution): name = "torch" packages = ("torch", "torchvision", "torchaudio") def __init__(self, *args, **kwargs): os_name = kwargs.pop("os_override", None) super(PytorchRequirement, self).__init__(*args, **kwargs) self.log = self._session.get_logger(__name__) self.package_manager = self.config["agent.package_manager.type"].lower() self.os = os_name or self.get_platform() self.cuda = "cuda{}".format(self.cuda_version).lower() self.python_version_string = str(self.config["agent.default_python"]) self.python_semantic_version = Version.coerce( self.python_version_string, partial=True ) self.python = "python{}.{}".format(self.python_semantic_version.major, self.python_semantic_version.minor) self.exceptions = [ PytorchResolutionError(message) for message in ( None, 'cuda version "{}" is not supported'.format(self.cuda), 'python version "{}" is not supported'.format( self.python_version_string ), ) ] try: self.validate_python_version() except PytorchResolutionError as e: self.log.warn("will not be able to install pytorch wheels: %s", e.args[0]) @property def is_conda(self): return self.package_manager == "conda" @property def is_pip(self): return not self.is_conda def validate_python_version(self): """ Make sure python version has both major and minor versions as required for choosing pytorch wheel """ if self.is_pip and not ( self.python_semantic_version.major and self.python_semantic_version.minor ): raise PytorchResolutionError( "invalid python version {!r} defined in configuration file, key 'agent.default_python': " "must have both major and minor parts of the version (for example: '3.7')".format( self.python_version_string ) ) def match(self, req): return req.name in self.packages @staticmethod def get_platform(): if sys.platform == "linux": return "linux" if sys.platform == "win32" or sys.platform == "cygwin": return "windows" if sys.platform == "darwin": return "macos" raise RuntimeError("unrecognized OS") def _get_link_from_torch_page(self, req, torch_url): links_parser = LinksHTMLParser() links_parser.feed(requests.get(torch_url, timeout=10).text) platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform() py_ver = "{0.major}{0.minor}".format(self.python_semantic_version) url = None # search for our package for l in links_parser.links: parts = l.split('/')[-1].split('-') if len(parts) < 5: continue if parts[0] != req.name: continue # version (ignore +cpu +cu92 etc. + is %2B in the file link) if parts[1].split('%')[0].split('+')[0] != req.specs[0][1]: continue if not parts[2].endswith(py_ver): continue if platform_wheel not in parts[4]: continue url = '/'.join(torch_url.split('/')[:-1] + l.split('/')) break return url def get_url_for_platform(self, req): assert self.package_manager == "pip" assert self.os != "mac" # check if package is already installed with system packages try: if self.config.get("agent.package_manager.system_site_packages"): from pip._internal.commands.show import search_packages_info installed_torch = list(search_packages_info([req.name])) op, version = req.specs[0] if req.specs else (None, None) # notice the comparision order, the first part will make sure we have a valid installed package if installed_torch[0]['version'] and (installed_torch[0]['version'] == version or not version): # package already installed, do nothing return str(req), True except: pass # make sure we have a specific version to retrieve assert req.specs try: req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0]) except: pass op, version = req.specs[0] # assert op == "==" torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) url = self._get_link_from_torch_page(req, torch_url) # try one more time, with a lower cuda version (never fallback to CPU): while not url and torch_url_key > 0: previous_cuda_key = torch_url_key torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1) # never fallback to CPU if torch_url_key < 1: print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format( req, previous_cuda_key)) raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format( req, self.cuda_version)) print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format( req, previous_cuda_key, torch_url_key)) url = self._get_link_from_torch_page(req, torch_url) if not url: url = PytorchWheel( torch_version=fix_version(version), python="{0.major}{0.minor}".format(self.python_semantic_version), os_name=self.os, cuda_version=self.cuda_version, ).make_url() if url: # normalize url (sometimes we will get ../ which we should not... url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize())) self.log.debug("checking url: %s", url) return url, requests.head(url, timeout=10).ok @staticmethod def match_version(req, options): versioned_options = sorted( ((Version(fix_version(key)), value) for key, value in options.items()), key=itemgetter(0), reverse=True, ) req.specs = [(op, fix_version(version)) for op, version in req.specs] if req.specs: specs = Spec(req.format_specs()) else: specs = None try: return next( replacement for version, replacement in versioned_options if not specs or version in specs ) except StopIteration: raise PytorchResolutionError( 'Could not find wheel for "{}", ' "Available versions: {}".format(req, list(options)) ) def replace_conda(self, req): spec = "".join(req.specs[0]) if req.specs else "" if not self.cuda_version: return "pytorch-cpu{spec}\ntorchvision-cpu".format(spec=spec) return "pytorch{spec}\ntorchvision\ncuda{self.cuda_version}".format( self=self, spec=spec ) def _table_lookup(self, req): """ Look for pytorch wheel matching `req` in table :param req: python requirement """ def check(base_, key_, exception_): result = base_.get(key_) if not result: if key_.startswith('cuda'): print('Could not locate, {}'.format(exception_)) ver = sorted([float(a.replace('cuda', '').replace('none', '0')) for a in base_.keys()], reverse=True)[0] key_ = 'cuda'+str(int(ver)) result = base_.get(key_) print('Reverting to \"{}\"'.format(key_)) if not result: raise exception_ return result raise exception_ if isinstance(result, Exception): raise result return result if self.is_conda: return self.replace_conda(req) base = self.MAP for key, exception in zip((self.os, self.cuda, self.python), self.exceptions): base = check(base, key, exception) return self.match_version(req, base).replace(" ", "\n") def replace(self, req): try: return self._replace(req) except Exception as e: message = "Exception when trying to resolve python wheel" self.log.debug(message, exc_info=True) raise PytorchResolutionError("{}: {}".format(message, e)) def _replace(self, req): self.validate_python_version() try: result, ok = self.get_url_for_platform(req) self.log.debug('Replacing requirement "%s" with %r', req, result) return result except: pass try: result = self._table_lookup(req) except Exception as e: exc = e else: self.log.debug('Replacing requirement "%s" with %r', req, result) return result self.log.debug( "Could not find Pytorch wheel in table, trying manually constructing URL" ) result = ok = None # try: # result, ok = self.get_url_for_platform(req) # except Exception: # pass if not ok: if result: self.log.debug("URL not found: {}".format(result)) exc = PytorchResolutionError( "Was not able to find pytorch wheel URL: {}".format(exc) ) # cancel exception chaining six.raise_from(exc, None) self.log.debug('Replacing requirement "%s" with %r', req, result) return result MAP = { "windows": { "cuda100": { "python3.7": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-win_amd64.whl" }, "python3.6": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-win_amd64.whl" }, "python3.5": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-win_amd64.whl" }, "python2.7": PytorchResolutionError( "PyTorch does not support Python 2.7 on Windows" ), }, "cuda92": { "python3.7": { "0.4.1", "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp37-cp37m-win_amd64.whl", }, "python3.6": { "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-win_amd64.whl" }, "python3.5": { "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-win_amd64.whl" }, "python2.7": PytorchResolutionError( "PyTorch does not support Python 2.7 on Windows" ), }, "cuda91": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-win_amd64.whl" }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-win_amd64.whl" }, "python2.7": PytorchResolutionError( "PyTorch does not support Python 2.7 on Windows" ), }, "cuda90": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-win_amd64.whl", "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-win_amd64.whl", }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-win_amd64.whl", "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-win_amd64.whl", }, "python2.7": PytorchResolutionError( "PyTorch does not support Python 2.7 on Windows" ), }, "cuda80": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-win_amd64.whl", "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-win_amd64.whl", }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp35-cp35m-win_amd64.whl", "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-win_amd64.whl", }, "python2.7": PytorchResolutionError( "PyTorch does not support Python 2.7 on Windows" ), }, "cudanone": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-win_amd64.whl", "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-win_amd64.whl", }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-win_amd64.whl", "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-win_amd64.whl", }, "python2.7": PytorchResolutionError( "PyTorch does not support Python 2.7 on Windows" ), }, }, "macos": { "cuda100": PytorchResolutionError( "MacOS Binaries dont support CUDA, install from source if CUDA is needed" ), "cuda92": PytorchResolutionError( "MacOS Binaries dont support CUDA, install from source if CUDA is needed" ), "cuda91": PytorchResolutionError( "MacOS Binaries dont support CUDA, install from source if CUDA is needed" ), "cuda90": PytorchResolutionError( "MacOS Binaries dont support CUDA, install from source if CUDA is needed" ), "cuda80": PytorchResolutionError( "MacOS Binaries dont support CUDA, install from source if CUDA is needed" ), "cudanone": { "python3.6": {"0.4.0": "torch"}, "python3.5": {"0.4.0": "torch"}, "python2.7": {"0.4.0": "torch"}, }, }, "linux": { "cuda100": { "python3.7": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-linux_x86_64.whl", "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp37-cp37m-linux_x86_64.whl", "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl", "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl", }, "python3.6": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl", "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp36-cp36m-linux_x86_64.whl", "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl", "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl", }, "python3.5": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-linux_x86_64.whl", "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp35-cp35m-linux_x86_64.whl", "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp35-cp35m-linux_x86_64.whl", "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp35-cp35m-manylinux1_x86_64.whl", }, "python2.7": { "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl", "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp27-cp27mu-linux_x86_64.whl", "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl", "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp27-cp27mu-manylinux1_x86_64.whl", }, }, "cuda92": { "python3.7": { "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl", "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp37-cp37m-manylinux1_x86_64.whl" }, "python3.6": { "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl", "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl" }, "python3.5": { "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-linux_x86_64.whl", "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp35-cp35m-manylinux1_x86_64.whl" }, "python2.7": { "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl", "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp27-cp27mu-manylinux1_x86_64.whl" }, }, "cuda91": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-linux_x86_64.whl" }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl" }, "python2.7": { "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl" }, }, "cuda90": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl", "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-linux_x86_64.whl", }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-linux_x86_64.whl", "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-linux_x86_64.whl", }, "python2.7": { "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl", "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl", }, }, "cuda80": { "python3.6": { "0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl", "0.3.1": "torch==0.3.1", "0.3.0.post4": "torch==0.3.0.post4", "0.1.2.post1": "torch==0.1.2.post1", "0.1.2": "torch==0.1.2", "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl", }, "python3.5": { "0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp35-cp35m-linux_x86_64.whl", "0.3.1": "torch==0.3.1", "0.3.0.post4": "torch==0.3.0.post4", "0.1.2.post1": "torch==0.1.2.post1", "0.1.2": "torch==0.1.2", "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-linux_x86_64.whl", }, "python2.7": { "0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl", "0.3.1": "torch==0.3.1", "0.3.0.post4": "torch==0.3.0.post4", "0.1.2.post1": "torch==0.1.2.post1", "0.1.2": "torch==0.1.2", "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl", }, }, "cudanone": { "python3.6": { "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl", "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl", }, "python3.5": { "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl", "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl", }, "python2.7": { "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl", "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl", }, }, }, }