Rename trains-agent -> clearml-agent

This commit is contained in:
allegroai
2020-12-22 21:21:29 +02:00
parent 090327234a
commit 6e1f74402e
112 changed files with 0 additions and 0 deletions

View File

View File

@@ -0,0 +1,155 @@
from __future__ import unicode_literals
import abc
from contextlib import contextmanager
from typing import Text, Iterable, Union
import six
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines, select_for_platform
from trains_agent.helper.process import Executable, Argv, PathLike
@six.add_metaclass(abc.ABCMeta)
class PackageManager(object):
"""
ABC for classes providing python package management interface
"""
_selected_manager = None
_cwd = None
_pip_version = None
@abc.abstractproperty
def bin(self):
# type: () -> PathLike
pass
@abc.abstractmethod
def create(self):
pass
@abc.abstractmethod
def remove(self):
pass
@abc.abstractmethod
def install_from_file(self, path):
pass
@abc.abstractmethod
def freeze(self):
pass
@abc.abstractmethod
def load_requirements(self, requirements):
pass
@abc.abstractmethod
def install_packages(self, *packages):
# type: (Iterable[Text]) -> None
"""
Install packages, upgrading depends on config
"""
pass
@abc.abstractmethod
def _install(self, *packages):
# type: (Iterable[Text]) -> None
"""
Run install command
"""
pass
@abc.abstractmethod
def uninstall_packages(self, *packages):
# type: (Iterable[Text]) -> None
pass
def upgrade_pip(self):
result = self._install(
select_for_platform(windows='"pip{}"', linux='pip{}').format(self.get_pip_version()), "--upgrade")
packages = self.run_with_env(('list',), output=True).splitlines()
# p.split is ('pip', 'x.y.z')
pip = [p.split() for p in packages if len(p.split()) == 2 and p.split()[0] == 'pip']
if pip:
# noinspection PyBroadException
try:
from .requirements import MarkerRequirement
pip = pip[0][1].split('.')
MarkerRequirement.pip_new_version = bool(int(pip[0]) >= 20)
except Exception:
pass
return result
def get_python_command(self, extra=()):
# type: (...) -> Executable
return Argv(self.bin, *extra)
@contextmanager
def temp_file(self, prefix, contents, suffix=".txt"):
# type: (Union[Text, Iterable[Text]], Iterable[Text], Text) -> Text
"""
Write contents to a temporary file, yielding its path. Finally, delete it.
:param prefix: file name prefix
:param contents: text lines to write
:param suffix: file name suffix
"""
f, temp_path = mkstemp(suffix=suffix, prefix=prefix)
with f:
f.write(
contents
if isinstance(contents, six.text_type)
else join_lines(contents)
)
try:
yield temp_path
finally:
if not self.session.debug_mode:
safe_remove_file(temp_path)
def set_selected_package_manager(self):
# set this instance as the selected package manager
# this is helpful when we want out of context requirement installations
PackageManager._selected_manager = self
@property
def cwd(self):
return self._cwd
@cwd.setter
def cwd(self, value):
self._cwd = value
@classmethod
def out_of_scope_install_package(cls, package_name, *args):
if PackageManager._selected_manager is not None:
try:
result = PackageManager._selected_manager._install(package_name, *args)
if result not in (0, None, True):
return False
except Exception:
return False
return True
@classmethod
def out_of_scope_freeze(cls):
if PackageManager._selected_manager is not None:
try:
return PackageManager._selected_manager.freeze()
except Exception:
pass
return []
@classmethod
def set_pip_version(cls, version):
if not version:
return
version = version.replace(' ', '')
if ('=' in version) or ('~' in version) or ('<' in version) or ('>' in version):
cls._pip_version = version
else:
cls._pip_version = "=="+version
@classmethod
def get_pip_version(cls):
return cls._pip_version or ''

View File

@@ -0,0 +1,722 @@
from __future__ import unicode_literals
import json
import re
import os
import subprocess
from collections import OrderedDict
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
import six
import yaml
from time import time
from attr import attrs, attrib, Factory
from pathlib2 import Path
from trains_agent.external.requirements_parser import parse
from trains_agent.external.requirements_parser.requirement import Requirement
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform, ExecutionInfo
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.helper.package.requirements import SimpleVersion
from trains_agent.session import Session
from .base import PackageManager
from .pip_api.venv import VirtualenvPip
from .requirements import RequirementsManager, MarkerRequirement
from ...backend_api.session.defs import ENV_CONDA_ENV_PACKAGE
package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")
def package_set(packages):
return set(map(package_normalize, packages))
def _package_diff(path, packages):
# type: (Union[Path, Text], Iterable[Text]) -> Set[Text]
return package_set(Path(path).read_text().splitlines()) - package_set(packages)
class CondaPip(VirtualenvPip):
def __init__(self, source=None, *args, **kwargs):
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe")
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
self.source = source
def run_with_env(self, command, output=False, **kwargs):
if not self.source:
return super(CondaPip, self).run_with_env(command, output=output, **kwargs)
command = CommandSequence(self.source, Argv("pip", *command))
return (command.get_output if output else command.check_call)(
stdin=DEVNULL, **kwargs
)
class CondaAPI(PackageManager):
"""
A programmatic interface for controlling conda
"""
MINIMUM_VERSION = "4.3.30"
def __init__(self, session, path, python, requirements_manager, execution_info=None, **kwargs):
# type: (Session, PathLike, float, RequirementsManager, ExecutionInfo, Any) -> None
"""
:param python: base python version to use (e.g python3.6)
:param path: path of env
"""
self.session = session
self.python = python
self.source = None
self.requirements_manager = requirements_manager
self.path = path
self.env_read_only = False
self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
self.conda_env_as_base_docker = \
self.session.config.get('agent.package_manager.conda_env_as_base_docker', None) or \
bool(ENV_CONDA_ENV_PACKAGE.get())
if ENV_CONDA_ENV_PACKAGE.get():
self.conda_pre_build_env_path = ENV_CONDA_ENV_PACKAGE.get()
else:
self.conda_pre_build_env_path = execution_info.docker_cmd if execution_info else None
self.pip = CondaPip(
session=self.session,
source=self.source,
python=self.python,
requirements_manager=self.requirements_manager,
path=self.path,
)
try:
self.conda = (
find_executable("conda") or
Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(
shell=select_for_platform(windows=True, linux=False)).strip()
)
except Exception:
raise ValueError("ERROR: package manager \"conda\" selected, "
"but \'conda\' executable could not be located")
try:
output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as ex:
raise CommandFailedError(
"Unable to determine conda version: {ex}, output={ex.output}".format(
ex=ex
)
)
self.conda_version = self.get_conda_version(output)
if SimpleVersion.compare_versions(self.conda_version, '<', self.MINIMUM_VERSION):
raise CommandFailedError(
"conda version '{}' is smaller than minimum supported conda version '{}'".format(
self.conda_version, self.MINIMUM_VERSION
)
)
@staticmethod
def get_conda_version(output):
match = re.search(r"(\d+\.){0,2}\d+", output)
if not match:
raise CommandFailedError("Unidentified conda version string:", output)
return match.group(0)
@property
def bin(self):
return self.pip.bin
# noinspection SpellCheckingInspection
def upgrade_pip(self):
# do not change pip version if pre built environement is used
if self.env_read_only:
print('Conda environment in read-only mode, skipping pip upgrade.')
return ''
return self._install(select_for_platform(windows='"pip{}"', linux='pip{}').format(self.pip.get_pip_version()))
def create(self):
"""
Create a new environment
"""
if self.conda_env_as_base_docker and self.conda_pre_build_env_path:
if Path(self.conda_pre_build_env_path).is_dir():
print("Using pre-existing Conda environment from {}".format(self.conda_pre_build_env_path))
self.path = Path(self.conda_pre_build_env_path)
self.source = ("conda", "activate", self.path.as_posix())
self.pip = CondaPip(
session=self.session,
source=self.source,
python=self.python,
requirements_manager=self.requirements_manager,
path=self.path,
)
conda_env = self._get_conda_sh()
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
self.env_read_only = True
return self
elif Path(self.conda_pre_build_env_path).is_file():
print("Restoring Conda environment from {}".format(self.conda_pre_build_env_path))
tar_path = find_executable("tar")
self.path.mkdir(parents=True, exist_ok=True)
output = Argv(
tar_path,
"-xzf",
self.conda_pre_build_env_path,
"-C",
self.path,
).get_output()
self.source = self.pip.source = ("conda", "activate", self.path.as_posix())
conda_env = self._get_conda_sh()
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
# unpack cleanup
print("Fixing prefix in Conda environment {}".format(self.path))
CommandSequence(('source', conda_env.as_posix()),
((self.path / 'bin' / 'conda-unpack').as_posix(), )).get_output()
return self
else:
raise ValueError("Could not restore Conda environment, cannot find {}".format(
self.conda_pre_build_env_path))
output = Argv(
self.conda,
"create",
"--yes",
"--mkdir",
"--prefix",
self.path,
"python={}".format(self.python),
).get_output(stderr=DEVNULL)
match = re.search(
r"\W*(.*activate) ({})".format(re.escape(str(self.path))), output
)
self.source = self.pip.source = (
tuple(match.group(1).split()) + (match.group(2),)
if match
else ("conda", "activate", self.path.as_posix())
)
conda_env = self._get_conda_sh()
if conda_env.is_file() and not is_windows_platform():
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
# install cuda toolkit
# noinspection PyBroadException
try:
cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
if cuda_version > 0:
self._install('cudatoolkit={:.1f}'.format(cuda_version))
except Exception:
pass
return self
def remove(self):
"""
Delete a conda environment.
Use 'conda env remove', then 'rm_tree' to be safe.
Conda seems to load "vcruntime140.dll" from all its environment on startup.
This means environment have to be deleted using 'conda env remove'.
If necessary, conda can be fooled into deleting a partially-deleted environment by creating an empty file
in '<ENV>\conda-meta\history' (value found in 'conda.gateways.disk.test.PREFIX_MAGIC_FILE').
Otherwise, it complains that said directory is not a conda environment.
See: https://github.com/conda/conda/issues/7682
"""
try:
self._run_command(("env", "remove", "-p", self.path))
except Exception:
pass
rm_tree(self.path)
# if we failed removing the path, change it's name
if is_windows_platform() and Path(self.path).exists():
try:
Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time()))
except Exception:
pass
def _install_from_file(self, path):
"""
Install packages from requirement file.
"""
self._install("--file", path)
def _install(self, *args):
# type: (*PathLike) -> ()
# if we are in read only mode, do not install anything
if self.env_read_only:
print('Conda environment in read-only mode, skipping package installing: {}'.format(args))
return
channels_args = tuple(
chain.from_iterable(("-c", channel) for channel in self.extra_channels)
)
self._run_command(("install", "-p", self.path) + channels_args + args)
def _get_pip_packages(self, packages):
# type: (Iterable[Text]) -> Sequence[Text]
"""
Return subset of ``packages`` which are not available on conda
"""
pips = []
while True:
with self.temp_file("conda_reqs", packages) as path:
try:
self._install_from_file(path)
except PackageNotFoundError as e:
pips.append(e.pkg)
packages = _package_diff(path, {e.pkg})
else:
break
return pips
def install_packages(self, *packages):
# type: (*Text) -> ()
return self._install(*packages)
def uninstall_packages(self, *packages):
# if we are in read only mode, do not uninstall anything
if self.env_read_only:
print('Conda environment in read-only mode, skipping package uninstalling: {}'.format(packages))
return ''
return self._run_command(("uninstall", "-p", self.path))
def install_from_file(self, path):
"""
Try to install packages from conda. Install packages which are not available from conda with pip.
"""
try:
self._install_from_file(path)
return
except PackageNotFoundError as e:
pip_packages = [e.pkg]
except PackagesNotFoundError as e:
pip_packages = package_set(e.packages)
with self.temp_file("conda_reqs", _package_diff(path, pip_packages)) as reqs:
self.install_from_file(reqs)
with self.temp_file("pip_reqs", pip_packages) as reqs:
self.pip.install_from_file(reqs)
def freeze(self, freeze_full_environment=False):
requirements = self.pip.freeze()
req_lines = []
conda_lines = []
# noinspection PyBroadException
try:
pip_lines = requirements['pip']
conda_packages_json = json.loads(
self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
for r in conda_packages_json:
# check if this is a pypi package, if it is, leave it outside
if not r.get('channel') or r.get('channel') == 'pypi':
name = (r['name'].replace('-', '_'), r['name'])
pip_req_line = [l for l in pip_lines
if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
if pip_req_line and \
('@' not in pip_req_line[0] or
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
req_lines.append(pip_req_line[0])
continue
req_lines.append(
'{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
continue
# check if we have it in our required packages
name = r['name']
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
# skip over packages with _
if name.startswith('_'):
continue
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
# make sure we see the conda packages, put them into the pip as well
if conda_lines:
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
requirements['pip'] = req_lines
requirements['conda'] = conda_lines
except Exception:
pass
if freeze_full_environment:
# noinspection PyBroadException
try:
conda_env_json = json.loads(
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
conda_env_json.pop('name', None)
conda_env_json.pop('prefix', None)
conda_env_json.pop('channels', None)
requirements['conda_env_json'] = json.dumps(conda_env_json)
except Exception:
pass
return requirements
def _load_conda_full_env(self, conda_env_dict, requirements):
# noinspection PyBroadException
try:
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
except Exception:
cuda_version = 0
conda_env_dict['channels'] = self.extra_channels
if 'dependencies' not in conda_env_dict:
conda_env_dict['dependencies'] = []
new_dependencies = OrderedDict()
pip_requirements = None
for line in conda_env_dict['dependencies']:
if isinstance(line, dict):
pip_requirements = line.pop('pip', None)
continue
name = line.strip().split('=', 1)[0].lower()
if name == 'pip':
continue
elif name == 'python':
line = 'python={}'.format('.'.join(line.split('=')[1].split('.')[:2]))
elif name == 'tensorflow-gpu' and cuda_version == 0:
line = 'tensorflow={}'.format(line.split('=')[1])
elif name == 'tensorflow' and cuda_version > 0:
line = 'tensorflow-gpu={}'.format(line.split('=')[1])
elif name in ('cupti', 'cudnn'):
# cudatoolkit should pull them based on the cudatoolkit version
continue
elif name.startswith('_'):
continue
new_dependencies[line.split('=', 1)[0].strip()] = line
# fix packages:
conda_env_dict['dependencies'] = list(new_dependencies.values())
with self.temp_file("conda_env", yaml.dump(conda_env_dict), suffix=".yml") as name:
print('Conda: Trying to install requirements:\n{}'.format(conda_env_dict['dependencies']))
result = self._run_command(
("env", "update", "-p", self.path, "--file", name)
)
# check if we need to remove specific packages
bad_req = self._parse_conda_result_bad_packges(result)
if bad_req:
print('failed installing the following conda packages: {}'.format(bad_req))
return False
if pip_requirements:
# create a list of vcs packages that we need to replace in the pip section
vcs_reqs = {}
if 'pip' in requirements:
pip_lines = requirements['pip'].splitlines() \
if isinstance(requirements['pip'], six.string_types) else requirements['pip']
for line in pip_lines:
try:
marker = list(parse(line))
except Exception:
marker = None
if not marker:
continue
m = MarkerRequirement(marker[0])
if m.vcs:
vcs_reqs[m.name] = m
try:
pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
if not r.startswith('pip=') and not r.startswith('virtualenv=')]
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
PackageManager._selected_manager = self.pip
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
except Exception as e:
print(e)
raise e
finally:
PackageManager._selected_manager = self
self.requirements_manager.post_install(self.session)
def load_requirements(self, requirements):
# if we are in read only mode, do not uninstall anything
if self.env_read_only:
print('Conda environment in read-only mode, skipping requirements installation.')
return None
# if we have a full conda environment, use it and pass the pip to pip
if requirements.get('conda_env_json'):
# noinspection PyBroadException
try:
conda_env_json = json.loads(requirements.get('conda_env_json'))
print('Conda restoring full yaml environment')
return self._load_conda_full_env(conda_env_json, requirements)
except Exception:
print('Could not load fully stored conda environment, falling back to requirements')
# create new environment file
conda_env = dict()
conda_env['channels'] = self.extra_channels
reqs = []
if isinstance(requirements['pip'], six.string_types):
requirements['pip'] = requirements['pip'].split('\n')
if isinstance(requirements.get('conda'), six.string_types):
requirements['conda'] = requirements['conda'].split('\n')
has_torch = False
has_matplotlib = False
try:
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
except:
cuda_version = 0
# notice 'conda' entry with empty string is a valid conda requirements list, it means pip only
# this should happen if experiment was executed on non-conda machine or old trains client
conda_supported_req = requirements['pip'] if requirements.get('conda', None) is None else requirements['conda']
conda_supported_req_names = []
pip_requirements = []
for r in conda_supported_req:
try:
marker = list(parse(r))
except:
marker = None
if not marker:
continue
m = MarkerRequirement(marker[0])
# conda does not support version control links
if m.vcs:
pip_requirements.append(m)
continue
# Skip over pip
if m.name in ('pip', 'virtualenv', ):
continue
# python version, only major.minor
if m.name == 'python' and m.specs:
m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
if '.' not in m.specs[0][1]:
continue
conda_supported_req_names.append(m.name.lower())
if m.req.name.lower() == 'matplotlib':
has_matplotlib = True
elif m.req.name.lower().startswith('torch'):
has_torch = True
if m.req.name.lower() in ('torch', 'pytorch'):
has_torch = True
m.req.name = 'pytorch'
if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'):
has_torch = True
m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'
reqs.append(m)
# if we have a conda list, the rest should be installed with pip,
if requirements.get('conda', None) is not None:
for r in requirements['pip']:
try:
marker = list(parse(r))
except:
marker = None
if not marker:
continue
m = MarkerRequirement(marker[0])
# skip over local files (we cannot change the version to a local file)
if m.local_file:
continue
m_name = m.name.lower()
if m_name in conda_supported_req_names:
# this package is in the conda list,
# make sure that if we changed version and we match it in conda
## conda_supported_req_names.remove(m_name)
for cr in reqs:
if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
# match versions
cr.specs = m.specs
# # conda always likes "-" not "_" but only on pypi packages
# cr.name = cr.name.lower().replace('_', '-')
break
else:
# not in conda, it is a pip package
pip_requirements.append(m)
if m_name == 'matplotlib':
has_matplotlib = True
# Conda requirements Hacks:
if has_matplotlib:
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
# remove specific cudatoolkit, it should have being preinstalled.
# allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them
reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]
if has_torch and cuda_version == 0:
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
# make sure we have no double entries
reqs = list(OrderedDict((r.name, r) for r in reqs).values())
# conform conda packages (version/name)
for r in reqs:
# change _ to - in name but not the prefix _ (as this is conda prefix)
if not r.name.startswith('_') and not requirements.get('conda', None):
r.name = r.name.replace('_', '-')
# remove .post from version numbers, it fails ~= version, and change == to ~=
if r.specs and r.specs[0]:
r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
while reqs:
# notice, we give conda more freedom in version selection, to help it choose best combination
def clean_ver(ar):
if not ar.specs:
return ar.tostr()
ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
return ar.tostr()
conda_env['dependencies'] = [clean_ver(r) for r in reqs]
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
result = self._run_command(
("env", "update", "-p", self.path, "--file", name)
)
# check if we need to remove specific packages
bad_req = self._parse_conda_result_bad_packges(result)
if not bad_req:
break
solved = False
for bad_r in bad_req:
name = bad_r.split('[')[0].split('=')[0].split('~')[0].split('<')[0].split('>')[0]
# look for name in requirements
for r in reqs:
if r.name.lower() == name.lower():
pip_requirements.append(r)
reqs.remove(r)
solved = True
break
# we couldn't remove even one package,
# nothing we can do but try pip
if not solved:
pip_requirements.extend(reqs)
break
if pip_requirements:
try:
pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
PackageManager._selected_manager = self.pip
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
except Exception as e:
print(e)
raise e
finally:
PackageManager._selected_manager = self
self.requirements_manager.post_install(self.session)
return True
def _parse_conda_result_bad_packges(self, result_dict):
if not result_dict:
return None
if 'bad_deps' in result_dict and result_dict['bad_deps']:
return result_dict['bad_deps']
if result_dict.get('error'):
error_lines = result_dict['error'].split('\n')
if error_lines[0].strip().lower().startswith("unsatisfiableerror:"):
empty_lines = [i for i, l in enumerate(error_lines) if not l.strip()]
if len(empty_lines) >= 2:
deps = error_lines[empty_lines[0]+1:empty_lines[1]]
try:
return yaml.load('\n'.join(deps), Loader=yaml.SafeLoader)
except:
return None
return None
def _run_command(self, command, raw=False, **kwargs):
# type: (Iterable[Text], bool, Any) -> Union[Dict, Text]
"""
Run a conda command, returning JSON output.
The command is prepended with 'conda' and run with JSON output flags.
:param command: command to run
:param raw: return text output and don't change command
:param kwargs: kwargs for Argv.get_output()
:return: JSON output or text output
"""
def escape_ansi(line):
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
return ansi_escape.sub('', line)
command = Argv(*command) # type: Executable
if not raw:
command = (self.conda,) + command + ("--quiet", "--json")
try:
print('Executing Conda: {}'.format(command.serialize()))
result = command.get_output(stdin=DEVNULL, **kwargs)
if self.session.debug_mode:
print(result)
except Exception as e:
result = e.output if hasattr(e, 'output') else ''
if self.session.debug_mode:
print(result)
if raw:
raise
if raw:
return result
result = json.loads(escape_ansi(result)) if result else {}
if result.get('success', False):
print('Pass')
elif result.get('error'):
print('Conda error: {}'.format(result.get('error')))
return result
def get_python_command(self, extra=()):
return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
def _get_conda_sh(self):
# type () -> Path
base_conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
if base_conda_env.is_file():
return base_conda_env
for path in os.environ.get('PATH', '').split(select_for_platform(windows=';', linux=':')):
conda = find_executable("conda", path=path)
if not conda:
continue
conda_env = Path(conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
if conda_env.is_file():
return conda_env
return base_conda_env
# enable hashing with cmp=False because pdb fails on un-hashable exceptions
exception = attrs(str=True, cmp=False)
@exception
class CondaException(Exception, NonStrictAttrs):
command = attrib()
message = attrib(default=None)
@exception
class UnknownCondaError(CondaException):
data = attrib(default=Factory(dict))
@exception
class PackagesNotFoundError(CondaException):
"""
Conda 4.5 exception - this reports all missing packages.
"""
packages = attrib(default=())
@exception
class PackageNotFoundError(CondaException):
"""
Conda 4.3 exception - this reports one missing package at a time,
as a singleton YAML list.
"""
pkg = attrib(default="", converter=lambda val: yaml.load(val, Loader=yaml.SafeLoader)[0].replace(" ", ""))

View File

@@ -0,0 +1,106 @@
import re
from collections import OrderedDict
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
from ..base import safe_furl as furl
class ExternalRequirements(SimpleSubstitution):
name = "external_link"
def __init__(self, *args, **kwargs):
super(ExternalRequirements, self).__init__(*args, **kwargs)
self.post_install_req = []
self.post_install_req_lookup = OrderedDict()
def match(self, req):
# match both editable or code or unparsed
if not (not req.name or req.req and (req.req.editable or req.req.vcs)):
return False
if not req.req or not req.req.line or not req.req.line.strip() or req.req.line.strip().startswith('#'):
return False
if req.pip_new_version and not (req.req.editable or req.req.vcs):
return False
return True
def post_install(self, session):
post_install_req = self.post_install_req
self.post_install_req = []
for req in post_install_req:
try:
freeze_base = PackageManager.out_of_scope_freeze() or ''
except:
freeze_base = ''
req_line = req.tostr(markers=False)
if req_line.strip().startswith('-e ') or req_line.strip().startswith('--editable'):
req_line = re.sub(r'^(-e|--editable=?)\s*', '', req_line, count=1)
if req.req.vcs and req_line.startswith('git+'):
try:
url_no_frag = furl(req_line)
url_no_frag.set(fragment=None)
# reverse replace
fragment = req_line[::-1].replace(url_no_frag.url[::-1], '', 1)[::-1]
vcs_url = req_line[4:]
# reverse replace
vcs_url = vcs_url[::-1].replace(fragment[::-1], '', 1)[::-1]
from ..repo import Git
vcs = Git(session=session, url=vcs_url, location=None, revision=None)
vcs._set_ssh_url()
new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
if new_req_line != req_line:
furl_line = furl(new_req_line)
print('Replacing original pip vcs \'{}\' with \'{}\''.format(
req_line,
furl_line.set(password='xxxxxx').tostr() if furl_line.password else new_req_line))
req_line = new_req_line
except Exception:
print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))
# if we have older pip version we have to make sure we replace back the package name with the
# git repository link. In new versions this is supported and we get "package @ git+https://..."
if not req.pip_new_version:
PackageManager.out_of_scope_install_package(req_line, "--no-deps")
# noinspection PyBroadException
try:
freeze_post = PackageManager.out_of_scope_freeze() or ''
package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
if package_name and package_name[0] not in self.post_install_req_lookup:
self.post_install_req_lookup[package_name[0]] = req.req.line
except Exception:
pass
# no need to force reinstall, pip will always rebuilt if the package comes from git
# and make sure the required packages are installed (if they are not it will install them)
if not PackageManager.out_of_scope_install_package(req_line):
raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
# Store in post req install, and return nothing
self.post_install_req.append(req)
# mark skip package, we will install it in post install hook
return Text('')
def replace_back(self, list_of_requirements):
if not list_of_requirements:
return list_of_requirements
for k in list_of_requirements:
# k is either pip/conda
if k not in ('pip', 'conda'):
continue
original_requirements = list_of_requirements[k]
list_of_requirements[k] = [r for r in original_requirements
if r not in self.post_install_req_lookup]
list_of_requirements[k] += [self.post_install_req_lookup.get(r, '')
for r in self.post_install_req_lookup.keys() if r in original_requirements]
return list_of_requirements

View File

@@ -0,0 +1,94 @@
import sys
from itertools import chain
from typing import Text, Optional
from trains_agent.definitions import PIP_EXTRA_INDICES, PROGRAM_NAME
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, DEVNULL
from trains_agent.session import Session
class SystemPip(PackageManager):
indices_args = None
def __init__(self, interpreter=None, session=None):
# type: (Optional[Text], Optional[Session]) -> ()
"""
Program interface to the system pip.
"""
self._bin = interpreter or sys.executable
self.session = session
@property
def bin(self):
return self._bin
def create(self):
pass
def remove(self):
pass
def install_from_file(self, path):
self.run_with_env(('install', '-r', path) + self.install_flags(), cwd=self.cwd)
def install_packages(self, *packages):
self._install(*(packages + self.install_flags()))
def _install(self, *args):
self.run_with_env(('install',) + args, cwd=self.cwd)
def uninstall_packages(self, *packages):
self.run_with_env(('uninstall', '-y') + packages)
def download_package(self, package, cache_dir):
self.run_with_env(
(
'download',
package,
'--dest', cache_dir,
'--no-deps',
) + self.install_flags()
)
def load_requirements(self, requirements):
requirements = requirements.get('pip') if isinstance(requirements, dict) else requirements
if not requirements:
return
with self.temp_file('cached-reqs', requirements) as path:
self.install_from_file(path)
def uninstall(self, package):
self.run_with_env(('uninstall', '-y', package))
def freeze(self):
"""
pip freeze to all install packages except the running program
:return: Dict contains pip as key and pip's packages to install
:rtype: Dict[str: List[str]]
"""
packages = self.run_with_env(('freeze',), output=True).splitlines()
packages_without_program = [package for package in packages if PROGRAM_NAME not in package]
return {'pip': packages_without_program}
def run_with_env(self, command, output=False, **kwargs):
"""
Run a shell command using environment from a virtualenv script
:param command: command to run
:type command: Iterable[Text]
:param output: return output
:param kwargs: kwargs for get_output/check_output command
"""
command = self._make_command(command)
return (command.get_output if output else command.check_call)(stdin=DEVNULL, **kwargs)
def _make_command(self, command):
return Argv(self.bin, '-m', 'pip', '--disable-pip-version-check', *command)
def install_flags(self):
if self.indices_args is None:
self.indices_args = tuple(
chain.from_iterable(('--extra-index-url', x) for x in PIP_EXTRA_INDICES)
)
return self.indices_args

View File

@@ -0,0 +1,77 @@
from typing import Any
from pathlib2 import Path
from trains_agent.helper.base import select_for_platform, rm_tree, ExecutionInfo
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session
from ..pip_api.system import SystemPip
from ..requirements import RequirementsManager
class VirtualenvPip(SystemPip, PackageManager):
def __init__(self, session, python, requirements_manager, path, interpreter=None, execution_info=None, **kwargs):
# type: (Session, float, RequirementsManager, PathLike, PathLike, ExecutionInfo, Any) -> ()
"""
Program interface to virtualenv pip.
Must be given either path to virtualenv or source command.
Either way, ``self.source`` is exposed.
:param session: a Session object for communication
:param python: interpreter path
:param path: path of virtual environment to create/manipulate
:param python: python version
:param interpreter: path of python interpreter
"""
super(VirtualenvPip, self).__init__(
session=session,
interpreter=interpreter or Path(
path, select_for_platform(linux="bin/python", windows="scripts/python.exe"))
)
self.path = path
self.requirements_manager = requirements_manager
self.python = python
def _make_command(self, command):
return self.session.command(self.bin, "-m", "pip", "--disable-pip-version-check", *command)
def load_requirements(self, requirements):
if isinstance(requirements, dict) and requirements.get("pip"):
requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
super(VirtualenvPip, self).load_requirements(requirements)
self.requirements_manager.post_install(self.session)
def create_flags(self):
"""
Configurable environment creation arguments
"""
return Argv.conditional_flag(
self.session.config["agent.package_manager.system_site_packages"],
"--system-site-packages",
)
def install_flags(self):
"""
Configurable package installation creation arguments
"""
return super(VirtualenvPip, self).install_flags() + Argv.conditional_flag(
self.session.config["agent.package_manager.force_upgrade"], "--upgrade"
)
def create(self):
"""
Create virtualenv.
Only valid if instantiated with path.
Use self.python as self.bin does not exist.
"""
self.session.command(
self.python, "-m", "virtualenv", self.path, *self.create_flags()
).check_call()
return self
def remove(self):
"""
Delete virtualenv.
Only valid if instantiated with path.
"""
rm_tree(self.path)

View File

@@ -0,0 +1,139 @@
from copy import deepcopy
from functools import wraps
import attr
import sys
import os
from pathlib2 import Path
from trains_agent.helper.process import Argv, DEVNULL, check_if_command_exists
from trains_agent.session import Session, POETRY
def prop_guard(prop, log_prop=None):
assert isinstance(prop, property)
assert not log_prop or isinstance(log_prop, property)
def decorator(func):
message = "%s:%s calling {}, {} = %s".format(
func.__name__, prop.fget.__name__
)
@wraps(func)
def new_func(self, *args, **kwargs):
prop_value = prop.fget(self)
if log_prop:
log_prop.fget(self).debug(
message,
type(self).__name__,
"" if prop_value else " not",
prop_value,
)
if prop_value:
return func(self, *args, **kwargs)
return new_func
return decorator
class PoetryConfig:
def __init__(self, session, interpreter=None):
# type: (Session, str) -> ()
self.session = session
self._log = session.get_logger(__name__)
self._python = interpreter or sys.executable
self._initialized = False
@property
def log(self):
return self._log
@property
def enabled(self):
return self.session.config["agent.package_manager.type"] == POETRY
_guard_enabled = prop_guard(enabled, log)
def run(self, *args, **kwargs):
func = kwargs.pop("func", Argv.get_output)
kwargs.setdefault("stdin", DEVNULL)
kwargs['env'] = deepcopy(os.environ)
if 'VIRTUAL_ENV' in kwargs['env'] or 'CONDA_PREFIX' in kwargs['env']:
kwargs['env'].pop('VIRTUAL_ENV', None)
kwargs['env'].pop('CONDA_PREFIX', None)
kwargs['env'].pop('PYTHONPATH', None)
if hasattr(sys, "real_prefix") and hasattr(sys, "base_prefix"):
path = ':'+kwargs['env']['PATH']
path = path.replace(':'+sys.base_prefix, ':'+sys.real_prefix, 1)
kwargs['env']['PATH'] = path
if check_if_command_exists("poetry"):
argv = Argv("poetry", *args)
else:
argv = Argv(self._python, "-m", "poetry", *args)
self.log.debug("running: %s", argv)
return func(argv, **kwargs)
def _config(self, *args, **kwargs):
return self.run("config", *args, **kwargs)
@_guard_enabled
def initialize(self, cwd=None):
if not self._initialized:
self._initialized = True
try:
self._config("--local", "virtualenvs.in-project", "true", cwd=cwd)
# self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
# self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)
except Exception as ex:
print("Exception: {}\nError: Failed configuring Poetry virtualenvs.in-project".format(ex))
raise
def get_api(self, path):
# type: (Path) -> PoetryAPI
return PoetryAPI(self, path)
@attr.s
class PoetryAPI(object):
config = attr.ib(type=PoetryConfig)
path = attr.ib(type=Path, converter=Path)
INDICATOR_FILES = "pyproject.toml", "poetry.lock"
def install(self):
# type: () -> bool
if self.enabled:
self.config.run("install", "-n", cwd=str(self.path), func=Argv.check_call)
return True
return False
@property
def enabled(self):
return self.config.enabled and (
any((self.path / indicator).exists() for indicator in self.INDICATOR_FILES)
)
def freeze(self):
lines = self.config.run("show", cwd=str(self.path)).splitlines()
lines = [[p for p in line.split(' ') if p] for line in lines]
return {"pip": [parts[0]+'=='+parts[1]+' # '+' '.join(parts[2:]) for parts in lines]}
def get_python_command(self, extra):
if check_if_command_exists("poetry"):
return Argv("poetry", "run", "python", *extra)
else:
return Argv(self.config._python, "-m", "poetry", "run", "python", *extra)
def upgrade_pip(self, *args, **kwargs):
pass
def set_selected_package_manager(self, *args, **kwargs):
pass
def out_of_scope_install_package(self, *args, **kwargs):
pass
def install_from_file(self, *args, **kwargs):
pass

View File

@@ -0,0 +1,48 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class PostRequirement(SimpleSubstitution):
name = ("horovod", )
optional_package_names = tuple()
def __init__(self, *args, **kwargs):
super(PostRequirement, self).__init__(*args, **kwargs)
self.post_install_req = []
# check if we need to replace the packages:
post_packages = self.config.get('agent.package_manager.post_packages', None)
if post_packages:
self.__class__.name = post_packages
post_optional_packages = self.config.get('agent.package_manager.post_optional_packages', None)
if post_optional_packages:
self.__class__.optional_package_names = post_optional_packages
def match(self, req):
# match both horovod
return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
def post_install(self, session):
for req in self.post_install_req:
if req.name in self.optional_package_names:
# noinspection PyBroadException
try:
PackageManager.out_of_scope_install_package(req.tostr(markers=False))
except Exception:
pass
else:
PackageManager.out_of_scope_install_package(req.tostr(markers=False))
self.post_install_req = []
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
# Store in post req install, and return nothing
self.post_install_req.append(req)
# mark skip package, we will install it in post install hook
return Text('')

View File

@@ -0,0 +1,75 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class PriorityPackageRequirement(SimpleSubstitution):
name = ("cython", "numpy", "setuptools", )
optional_package_names = tuple()
def __init__(self, *args, **kwargs):
super(PriorityPackageRequirement, self).__init__(*args, **kwargs)
# check if we need to replace the packages:
priority_packages = self.config.get('agent.package_manager.priority_packages', None)
if priority_packages:
self.__class__.name = priority_packages
priority_optional_packages = self.config.get('agent.package_manager.priority_optional_packages', None)
if priority_optional_packages:
self.__class__.optional_package_names = priority_optional_packages
def match(self, req):
# match both Cython & cython
return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
if req.name in self.optional_package_names:
# noinspection PyBroadException
try:
if PackageManager.out_of_scope_install_package(str(req)):
return Text(req)
except Exception:
pass
return Text('')
PackageManager.out_of_scope_install_package(str(req))
return Text(req)
class PackageCollectorRequirement(SimpleSubstitution):
"""
This RequirementSubstitution class will allow you to have multiple instances of the same
package, it will output the last one (by order) to be actually used.
"""
name = tuple()
def __init__(self, session, collect_package):
super(PackageCollectorRequirement, self).__init__(session)
self._collect_packages = collect_package or tuple()
self._last_req = None
def match(self, req):
# match package names
return req.name and req.name.lower() in self._collect_packages
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
self._last_req = req.clone()
return ''
def post_scan_add_req(self):
"""
Allows the RequirementSubstitution to add an extra line/requirements after
the initial requirements scan is completed.
Called only once per requirements.txt object
"""
last_req = self._last_req
self._last_req = None
return last_req

View File

@@ -0,0 +1,723 @@
from __future__ import unicode_literals
import re
import sys
from furl import furl
import urllib.parse
from operator import itemgetter
from html.parser import HTMLParser
from typing import Text
import attr
import requests
import six
from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
def os_to_wheel_name(x):
return OS_TO_WHEEL_NAME[x]
def fix_version(version):
def replace(nums, prerelease):
if prerelease:
return "{}-{}".format(nums, prerelease)
return nums
return re.sub(
r"(\d+(?:\.\d+){,2})(?:\.(.*))?",
lambda match: replace(*match.groups()),
version,
)
class LinksHTMLParser(HTMLParser):
def __init__(self):
super(LinksHTMLParser, self).__init__()
self.links = []
def handle_data(self, data):
if data and data.strip():
self.links += [data]
@attr.s
class PytorchWheel(object):
os_name = attr.ib(type=str, converter=os_to_wheel_name)
cuda_version = attr.ib(converter=lambda x: "cu{}".format(x) if x else "cpu")
python = attr.ib(type=str, converter=lambda x: str(x).replace(".", ""))
torch_version = attr.ib(type=str, converter=fix_version)
url_template = (
"http://download.pytorch.org/whl/"
"{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl"
)
def __attrs_post_init__(self):
self.unicode = "u" if self.python.startswith("2") else ""
def make_url(self):
# type: () -> Text
return self.url_template.format(self)
class PytorchResolutionError(FatalSpecsResolutionError):
pass
class SimplePytorchRequirement(SimpleSubstitution):
name = "torch"
packages = ("torch", "torchvision", "torchaudio")
page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html'
torch_page_lookup = {
0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
90: 'https://download.pytorch.org/whl/cu90/torch_stable.html',
92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
}
def __init__(self, *args, **kwargs):
super(SimplePytorchRequirement, self).__init__(*args, **kwargs)
self._matched = False
def match(self, req):
# match both any of out packages
return req.name in self.packages
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
# Get rid of +cpu +cu?? etc.
try:
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
except:
pass
self._matched = True
return Text(req)
def matching_done(self, reqs, package_manager):
# type: (Sequence[MarkerRequirement], object) -> ()
if not self._matched:
return
# TODO: add conda channel support
from .pip_api.system import SystemPip
if package_manager and isinstance(package_manager, SystemPip):
extra_url, _ = self.get_torch_page(self.cuda_version)
package_manager.add_extra_install_flags(('-f', extra_url))
@classmethod
def get_torch_page(cls, cuda_version, nightly=False):
# noinspection PyBroadException
try:
cuda = int(cuda_version)
except Exception:
cuda = 0
if nightly:
for c in range(cuda, max(-1, cuda-15), -1):
# then try the nightly builds, it might be there...
torch_url = cls.nightly_page_lookup_template.format(c)
# noinspection PyBroadException
try:
if requests.get(torch_url, timeout=10).ok:
print('Torch nightly CUDA {} download page found'.format(c))
cls.torch_page_lookup[c] = torch_url
return cls.torch_page_lookup[c], c
except Exception:
pass
return
# first check if key is valid
if cuda in cls.torch_page_lookup:
return cls.torch_page_lookup[cuda], cuda
# then try a new cuda version page
for c in range(cuda, max(-1, cuda-15), -1):
torch_url = cls.page_lookup_template.format(c)
# noinspection PyBroadException
try:
if requests.get(torch_url, timeout=10).ok:
print('Torch CUDA {} download page found'.format(c))
cls.torch_page_lookup[c] = torch_url
return cls.torch_page_lookup[c], c
except Exception:
pass
keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
for k in keys:
if k <= cuda:
return cls.torch_page_lookup[k], k
# return default - zero
return cls.torch_page_lookup[0], 0
class PytorchRequirement(SimpleSubstitution):
name = "torch"
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
def __init__(self, *args, **kwargs):
os_name = kwargs.pop("os_override", None)
super(PytorchRequirement, self).__init__(*args, **kwargs)
self.log = self._session.get_logger(__name__)
self.package_manager = self.config["agent.package_manager.type"].lower()
self.os = os_name or self.get_platform()
self.cuda = "cuda{}".format(self.cuda_version).lower()
self.python_version_string = str(self.config["agent.default_python"])
self.python_major_minor_str = '.'.join(self.python_version_string.split('.')[:2])
if '.' not in self.python_major_minor_str:
raise PytorchResolutionError(
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
"must have both major and minor parts of the version (for example: '3.7')".format(
self.python_version_string
)
)
self.python = "python{}".format(self.python_major_minor_str)
self.exceptions = [
PytorchResolutionError(message)
for message in (
None,
'cuda version "{}" is not supported'.format(self.cuda),
'python version "{}" is not supported'.format(
self.python_version_string
),
)
]
try:
self.validate_python_version()
except PytorchResolutionError as e:
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
self._original_req = []
@property
def is_conda(self):
return self.package_manager == "conda"
@property
def is_pip(self):
return not self.is_conda
def validate_python_version(self):
"""
Make sure python version has both major and minor versions as required for choosing pytorch wheel
"""
if self.is_pip and not self.python_major_minor_str:
raise PytorchResolutionError(
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
"must have both major and minor parts of the version (for example: '3.7')".format(
self.python_version_string
)
)
def match(self, req):
return req.name in self.packages
@staticmethod
def get_platform():
if sys.platform == "linux":
return "linux"
if sys.platform == "win32" or sys.platform == "cygwin":
return "windows"
if sys.platform == "darwin":
return "macos"
raise RuntimeError("unrecognized OS")
def _get_link_from_torch_page(self, req, torch_url):
links_parser = LinksHTMLParser()
links_parser.feed(requests.get(torch_url, timeout=10).text)
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
py_ver = self.python_major_minor_str.replace('.', '')
url = None
last_v = None
closest_v = None
# search for our package
for l in links_parser.links:
parts = l.split('/')[-1].split('-')
if len(parts) < 5:
continue
if parts[0] != req.name:
continue
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
# version ignore .postX suffix (treat as regular version)
# noinspection PyBroadException
try:
v = str(parts[1].split('%')[0].split('+')[0])
except Exception:
continue
if len(parts) < 3 or not parts[2].endswith(py_ver):
continue
if len(parts) < 5 or platform_wheel not in parts[4]:
continue
# update the closest matched version (from above)
if not closest_v:
closest_v = v
elif SimpleVersion.compare_versions(
version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
SimpleVersion.compare_versions(
version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
closest_v = v
# check if this an actual match
if not req.compare_version(v) or \
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
continue
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
last_v = v
# if we found an exact match, use it
# noinspection PyBroadException
try:
if req.specs[0][0] == '==' and \
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
break
except Exception:
pass
return url, last_v or closest_v
def get_url_for_platform(self, req):
# check if package is already installed with system packages
# noinspection PyBroadException
try:
if self.config.get("agent.package_manager.system_site_packages", None):
from pip._internal.commands.show import search_packages_info
installed_torch = list(search_packages_info([req.name]))
# notice the comparison order, the first part will make sure we have a valid installed package
if installed_torch and installed_torch[0]['version'] and \
req.compare_version(installed_torch[0]['version']):
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
# package already installed, do nothing
req.specs = [('==', str(installed_torch[0]['version']))]
return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
except Exception:
pass
# make sure we have a specific version to retrieve
if not req.specs:
req.specs = [('>', '0')]
# noinspection PyBroadException
try:
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
except Exception:
pass
op, version = req.specs[0]
# assert op == "=="
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
if not url and self.config.get("agent.package_manager.torch_nightly", None):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version (never fallback to CPU):
while not url and torch_url_key > 0:
previous_cuda_key = torch_url_key
print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
req, previous_cuda_key, closest_matched_version))
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
if url:
break
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
# never fallback to CPU
if torch_url_key < 1:
print(
'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
req, previous_cuda_key))
raise ValueError(
'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
else:
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
if not url:
url = PytorchWheel(
torch_version=fix_version(version),
python=self.python_major_minor_str.replace('.', ''),
os_name=self.os,
cuda_version=self.cuda_version,
).make_url()
if url:
# normalize url (sometimes we will get ../ which we should not...
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
# print found
print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))
self.log.debug("checking url: %s", url)
return url, requests.head(url, timeout=10).ok
@staticmethod
def match_version(req, options):
versioned_options = sorted(
((fix_version(key), value) for key, value in options.items()),
key=itemgetter(0),
reverse=True,
)
req.specs = [(op, fix_version(version)) for op, version in req.specs]
try:
return next(
replacement
for version, replacement in versioned_options
if req.compare_version(version)
)
except StopIteration:
raise PytorchResolutionError(
'Could not find wheel for "{}", '
"Available versions: {}".format(req, list(options))
)
def replace_conda(self, req):
spec = "".join(req.specs[0]) if req.specs else ""
if not self.cuda_version:
return "pytorch-cpu{spec}\ntorchvision-cpu".format(spec=spec)
return "pytorch{spec}\ntorchvision\ncuda{self.cuda_version}".format(
self=self, spec=spec
)
def _table_lookup(self, req):
"""
Look for pytorch wheel matching `req` in table
:param req: python requirement
"""
def check(base_, key_, exception_):
result = base_.get(key_)
if not result:
if key_.startswith('cuda'):
print('Could not locate, {}'.format(exception_))
ver = sorted([float(a.replace('cuda', '').replace('none', '0')) for a in base_.keys()], reverse=True)[0]
key_ = 'cuda'+str(int(ver))
result = base_.get(key_)
print('Reverting to \"{}\"'.format(key_))
if not result:
raise exception_
return result
raise exception_
if isinstance(result, Exception):
raise result
return result
if self.is_conda:
return self.replace_conda(req)
base = self.MAP
for key, exception in zip((self.os, self.cuda, self.python), self.exceptions):
base = check(base, key, exception)
return self.match_version(req, base).replace(" ", "\n")
def replace(self, req):
try:
new_req = self._replace(req)
if new_req:
self._original_req.append((req, new_req))
return new_req
except Exception as e:
message = "Exception when trying to resolve python wheel"
self.log.debug(message, exc_info=True)
raise PytorchResolutionError("{}: {}".format(message, e))
def _replace(self, req):
self.validate_python_version()
try:
result, ok = self.get_url_for_platform(req)
self.log.debug('Replacing requirement "%s" with %r', req, result)
return result
except:
pass
# try:
# result = self._table_lookup(req)
# except Exception as e:
# exc = e
# else:
# self.log.debug('Replacing requirement "%s" with %r', req, result)
# return result
# self.log.debug(
# "Could not find Pytorch wheel in table, trying manually constructing URL"
# )
result = ok = None
# try:
# result, ok = self.get_url_for_platform(req)
# except Exception:
# pass
if not ok:
if result:
self.log.debug("URL not found: {}".format(result))
exc = PytorchResolutionError(
"Could not find pytorch wheel URL for: {} with cuda {} support".format(req, self.cuda_version)
)
# cancel exception chaining
six.raise_from(exc, None)
self.log.debug('Replacing requirement "%s" with %r', req, result)
return result
def replace_back(self, list_of_requirements): # type: (Dict) -> Dict
"""
:param list_of_requirements: {'pip': ['a==1.0', ]}
:return: {'pip': ['a==1.0', ]}
"""
if not self._original_req:
return list_of_requirements
try:
for k, lines in list_of_requirements.items():
# k is either pip/conda
if k not in ('pip', 'conda'):
continue
for i, line in enumerate(lines):
if not line or line.lstrip().startswith('#'):
continue
parts = [p for p in re.split('\s|=|\.|<|>|~|!|@|#', line) if p]
if not parts:
continue
for req, new_req in self._original_req:
if req.req.name == parts[0]:
# support for pip >= 20.1
if '@' in line:
# skip if we have nothing to add
if str(req).strip() != str(new_req).strip():
# if this is local file and use the version detection
if req.local_file:
lines[i] = '{}'.format(str(new_req))
else:
lines[i] = '{} # {}'.format(str(req), str(new_req))
else:
lines[i] = '{} # {}'.format(line, str(new_req))
break
except:
pass
return list_of_requirements
MAP = {
"windows": {
"cuda100": {
"python3.7": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-win_amd64.whl"
},
"python3.6": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-win_amd64.whl"
},
"python3.5": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-win_amd64.whl"
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda92": {
"python3.7": {
"0.4.1",
"http://download.pytorch.org/whl/cu92/torch-0.4.1-cp37-cp37m-win_amd64.whl",
},
"python3.6": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-win_amd64.whl"
},
"python3.5": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-win_amd64.whl"
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda91": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-win_amd64.whl"
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-win_amd64.whl"
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda90": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-win_amd64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-win_amd64.whl",
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda80": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-win_amd64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp35-cp35m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-win_amd64.whl",
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cudanone": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-win_amd64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-win_amd64.whl",
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
},
"macos": {
"cuda100": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda92": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda91": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda90": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda80": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cudanone": {
"python3.6": {"0.4.0": "torch"},
"python3.5": {"0.4.0": "torch"},
"python2.7": {"0.4.0": "torch"},
},
},
"linux": {
"cuda100": {
"python3.7": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp37-cp37m-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl",
},
"python3.6": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp36-cp36m-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl",
},
"python3.5": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp35-cp35m-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp35-cp35m-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp35-cp35m-manylinux1_x86_64.whl",
},
"python2.7": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp27-cp27mu-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp27-cp27mu-manylinux1_x86_64.whl",
},
},
"cuda92": {
"python3.7": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp37-cp37m-manylinux1_x86_64.whl"
},
"python3.6": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl"
},
"python3.5": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp35-cp35m-manylinux1_x86_64.whl"
},
"python2.7": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp27-cp27mu-manylinux1_x86_64.whl"
},
},
"cuda91": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-linux_x86_64.whl"
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl"
},
"python2.7": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl"
},
},
"cuda90": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
},
"python2.7": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
},
},
"cuda80": {
"python3.6": {
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
"0.3.1": "torch==0.3.1",
"0.3.0.post4": "torch==0.3.0.post4",
"0.1.2.post1": "torch==0.1.2.post1",
"0.1.2": "torch==0.1.2",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
},
"python3.5": {
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
"0.3.1": "torch==0.3.1",
"0.3.0.post4": "torch==0.3.0.post4",
"0.1.2.post1": "torch==0.1.2.post1",
"0.1.2": "torch==0.1.2",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
},
"python2.7": {
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
"0.3.1": "torch==0.3.1",
"0.3.0.post4": "torch==0.3.0.post4",
"0.1.2.post1": "torch==0.1.2.post1",
"0.1.2": "torch==0.1.2",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
},
},
"cudanone": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
},
"python2.7": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
},
},
},
}

View File

@@ -0,0 +1,604 @@
from __future__ import absolute_import, unicode_literals
import operator
import os
import re
from abc import ABCMeta, abstractmethod
from copy import deepcopy, copy
from itertools import chain, starmap
from operator import itemgetter
from os import path
from typing import Text, List, Type, Optional, Tuple, Dict
from pathlib2 import Path
from pyhocon import ConfigTree
import six
from trains_agent.definitions import PIP_EXTRA_INDICES
from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session, normalize_cuda_version
from trains_agent.external.requirements_parser import parse
from trains_agent.external.requirements_parser.requirement import Requirement
from .translator import RequirementsTranslator
class SpecsResolutionError(Exception):
pass
class FatalSpecsResolutionError(Exception):
pass
@six.python_2_unicode_compatible
class MarkerRequirement(object):
# if True pip version above 20.x and with support for "package @ scheme://link"
# default is True
pip_new_version = True
def __init__(self, req): # type: (Requirement) -> None
self.req = req
@property
def marker(self):
match = re.search(r';\s*(.*)', self.req.line)
if match:
return match.group(1)
return None
def tostr(self, markers=True):
if not self.uri:
parts = [self.name or self.line]
if self.extras:
parts.append('[{0}]'.format(','.join(sorted(self.extras))))
if self.specifier:
parts.append(self.format_specs())
elif self.vcs:
# leave the line as is, let pip handle it
if self.line:
return self.line
else:
# let's build the line manually
parts = [
self.uri,
'@{}'.format(self.revision) if self.revision else '',
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
]
elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
# package @ file:///example.com/somewheel.whl
# leave the line as is, let pip handle it
return self.line
else:
parts = [self.uri]
if markers and self.marker:
parts.append('; {0}'.format(self.marker))
return ''.join(parts)
def clone(self):
return MarkerRequirement(copy(self.req))
__str__ = tostr
def __repr__(self):
return '{self.__class__.__name__}[{self}]'.format(self=self)
def format_specs(self, num_parts=None, max_num_parts=None):
max_num_parts = max_num_parts or num_parts
if max_num_parts is None or not self.specs:
return ','.join(starmap(operator.add, self.specs))
op, version = self.specs[0]
for v in self._sub_versions_pep440:
version = version.replace(v, '.')
if num_parts:
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
else:
version = version.strip('.').split('.')[:max_num_parts]
return op+'.'.join(version)
def __getattr__(self, item):
return getattr(self.req, item)
@property
def specs(self): # type: () -> List[Tuple[Text, Text]]
return self.req.specs
@specs.setter
def specs(self, value): # type: (List[Tuple[Text, Text]]) -> None
self.req.specs = value
def fix_specs(self):
def solve_by(func, op_is, specs):
return func([(op, version) for op, version in specs if op == op_is])
def solve_equal(specs):
if len(set(version for _, version in self.specs)) > 1:
raise SpecsResolutionError('more than one "==" spec: {}'.format(specs))
return specs
greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs)
smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs)
equal = solve_by(solve_equal, '==', self.specs)
if equal:
self.specs = equal
else:
self.specs = greater + smaller
def compare_version(self, requested_version, op=None, num_parts=3):
"""
compare the requested version with the one we have in the spec,
If the requested version is 1.2.3 the self.spec should be 1.2.3*
If the requested version is 1.2 the self.spec should be 1.2*
etc.
:param str requested_version:
:param str op: '==', '>', '>=', '<=', '<', '~='
:param int num_parts: number of parts to compare
:return: True if we answer the requested version
"""
# if we have no specific version, we cannot compare, so assume it's okay
if not self.specs:
return True
version = self.specs[0][1]
op = (op or self.specs[0][0]).strip()
return SimpleVersion.compare_versions(
version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
class SimpleVersion:
_sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
_local_version_separators = re.compile(r"[\._-]")
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
@classmethod
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
"""
Compare two versions based on the op operator
returns bool(version_a op version_b)
Notice: Ignores a/b/rc/post/dev markers on the version
:param str version_a:
:param str op: '==', '===', '>', '>=', '<=', '<', '~='
:param str version_b:
:param bool ignore_sub_versions: if true compare only major.minor.patch
(ignore a/b/rc/post/dev in the comparison)
:param int num_parts: number of parts to compare, split by . (dot)
:return bool: version_a op version_b
"""
if not version_b:
return True
if op == '~=':
num_parts = max(num_parts, 2)
op = '=='
ignore_sub_versions = True
elif op == '===':
op = '=='
try:
version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
except:
# revert to string based
for v in cls._sub_versions_pep440:
version_a = version_a.replace(v, '.')
version_b = version_b.replace(v, '.')
version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
version_a_key = ''
version_b_key = ''
for i in range(num_parts):
pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
version_a_key += pad.format(version_a[i])
version_b_key += pad.format(version_b[i])
if op == '==':
return version_a_key == version_b_key
if op == '<=':
return version_a_key <= version_b_key
if op == '>=':
return version_a_key >= version_b_key
if op == '>':
return version_a_key > version_b_key
if op == '<':
return version_a_key < version_b_key
raise ValueError('Unrecognized comparison operator [{}]'.format(op))
@staticmethod
def _parse_letter_version(
letter, # type: str
number, # type: Union[str, bytes, SupportsInt]
):
# type: (...) -> Optional[Tuple[str, int]]
if letter:
# We consider there to be an implicit 0 in a pre-release if there is
# not a numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
elif letter in ["rev", "r"]:
letter = "post"
return letter, int(number)
if not letter and number:
# We assume if we are given a number, but we are not given a letter
# then this is using the implicit post release syntax (e.g. 1.0-1)
letter = "post"
return letter, int(number)
return ()
@staticmethod
def _get_match_key(match, num_parts, ignore_sub_versions):
if ignore_sub_versions:
return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
(), (), (), (),)
return (
int(match.group("epoch")) if match.group("epoch") else 0,
tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
SimpleVersion._parse_letter_version(
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
),
SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
SimpleVersion._parse_local_version(match.group("local")),
)
@staticmethod
def _parse_local_version(local):
# type: (str) -> Optional[LocalType]
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in SimpleVersion._local_version_separators.split(local)
)
return ()
@six.add_metaclass(ABCMeta)
class RequirementSubstitution(object):
_pip_extra_index_url = PIP_EXTRA_INDICES
def __init__(self, session):
# type: (Session) -> ()
self._session = session
self.config = session.config # type: ConfigTree
self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
self.package_manager = self.config['agent.package_manager.type']
@abstractmethod
def match(self, req): # type: (MarkerRequirement) -> bool
"""
Returns whether a requirement needs to be modified by this substitution.
"""
pass
@abstractmethod
def replace(self, req): # type: (MarkerRequirement) -> Text
"""
Replace a requirement
"""
pass
def post_scan_add_req(self): # type: () -> Optional[MarkerRequirement]
"""
Allows the RequirementSubstitution to add an extra line/requirements after
the initial requirements scan is completed.
Called only once per requirements.txt object
"""
return None
def post_install(self, session):
pass
@classmethod
def get_pip_version(cls, package):
output = Argv(
'pip',
'search',
package,
*(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
).get_output()
# ad-hoc pattern to duplicate the behavior of the old code
return re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(package), output).group(1)
@property
def cuda_version(self):
return self.config['agent.cuda_version']
@property
def cudnn_version(self):
return self.config['agent.cudnn_version']
class SimpleSubstitution(RequirementSubstitution):
@property
@abstractmethod
def name(self):
pass
def match(self, req): # type: (MarkerRequirement) -> bool
return (self.name == req.name or (
req.uri and
re.match(r'https?://', req.uri) and
self.name in req.uri
))
def replace(self, req): # type: (MarkerRequirement) -> Text
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
if req.uri:
return re.sub(
r'({})(.*?)(-cp)'.format(self.name),
r'\1\2{}\3'.format(self.suffix),
req.uri,
count=1)
if req.specs:
_, version_number = req.specs[0]
# assert packaging_version.parse(version_number)
else:
version_number = self.get_pip_version(self.name)
req.specs = [('==', version_number + self.suffix)]
return Text(req)
def replace_back(self, list_of_requirements): # type: (Dict) -> Dict
"""
:param list_of_requirements: {'pip': ['a==1.0', ]}
:return: {'pip': ['a==1.0', ]}
"""
return list_of_requirements
@six.add_metaclass(ABCMeta)
class CudaSensitiveSubstitution(SimpleSubstitution):
def match(self, req): # type: (MarkerRequirement) -> bool
return self.cuda_version and self.cudnn_version and \
super(CudaSensitiveSubstitution, self).match(req)
class CudaNotFound(Exception):
pass
class RequirementsManager(object):
def __init__(self, session, base_interpreter=None):
# type: (Session, PathLike) -> ()
self._session = session
self.config = deepcopy(session.config) # type: ConfigTree
self.handlers = [] # type: List[RequirementSubstitution]
agent = self.config['agent']
self.active = not agent.get('cpu_only', False)
self.found_cuda = False
if self.active:
try:
agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
self.found_cuda = True
except Exception:
# if we have a cuda version, it is good enough (we dont have to have cudnn version)
if agent.get('cuda_version'):
self.found_cuda = True
pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
cache_dir=pip_cache_dir.as_posix())
def register(self, cls): # type: (Type[RequirementSubstitution]) -> None
self.handlers.append(cls(self._session))
def _replace_one(self, req): # type: (MarkerRequirement) -> Optional[Text]
match = re.search(r';\s*(.*)', Text(req))
if match:
req.markers = match.group(1).split(',')
if not self.active:
return None
for handler in self.handlers:
if handler.match(req):
return handler.replace(req)
return None
def replace(self, requirements): # type: (Text) -> Text
def safe_parse(req_str):
try:
return next(parse(req_str))
except Exception as ex:
return Requirement(req_str)
parsed_requirements = tuple(
map(
MarkerRequirement,
[safe_parse(line) for line in (requirements.splitlines()
if isinstance(requirements, six.text_type) else requirements)]
)
)
if not parsed_requirements:
# return the original requirements just in case
return requirements
def replace_one(i, req):
# type: (int, MarkerRequirement) -> Optional[Text]
try:
return self._replace_one(req)
except FatalSpecsResolutionError:
warning('could not resolve python wheel replacement for {}'.format(req))
raise
except Exception:
warning('could not resolve python wheel replacement for \"{}\", '
'using original requirements line: {}'.format(req, i))
return None
new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
conda = is_conda(self.config)
result = map(
lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
new_requirements,
parsed_requirements
)
if not conda:
result = map(self.translator.translate, result)
result = list(result)
# add post scan add requirements call back
for h in self.handlers:
req = h.post_scan_add_req()
if req:
result.append(req.tostr())
return join_lines(result)
def post_install(self, session):
for h in self.handlers:
try:
h.post_install(session)
except Exception as ex:
print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
raise
def replace_back(self, requirements):
if self.translator:
requirements = self.translator.replace_back(requirements)
for h in self.handlers:
try:
requirements = h.replace_back(requirements)
except Exception:
pass
return requirements
@staticmethod
def get_cuda_version(config): # type: (ConfigTree) -> (Text, Text)
# we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
cuda_version = config['agent.cuda_version']
cudnn_version = config['agent.cudnn_version']
if cuda_version and cudnn_version:
return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
if not cuda_version and is_windows_platform():
try:
cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
if k.startswith('CUDA_PATH_V')]
cuda_vers = max(cuda_vers)
if cuda_vers > 40:
cuda_version = cuda_vers
except:
pass
if not cuda_version:
try:
try:
nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
if is_windows_platform() and 'CUDA_PATH' in os.environ:
nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)
output = Argv(nvcc, '--version').get_output()
except OSError:
raise CudaNotFound('nvcc not found')
match = re.search(r'release (.{3})', output).group(1)
cuda_version = Text(int(float(match) * 10))
except:
pass
if not cuda_version:
try:
try:
output = Argv('nvidia-smi',).get_output()
except OSError:
raise CudaNotFound('nvcc not found')
match = re.search(r'CUDA Version: ([0-9]+).([0-9]+)', output)
match = match.group(1)+'.'+match.group(2)
cuda_version = Text(int(float(match) * 10))
except:
pass
if not cudnn_version:
try:
cuda_lib = which('nvcc')
if is_windows_platform:
cudnn_h = path.sep.join(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h'])
else:
cudnn_h = path.join(path.sep, *(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h']))
cudnn_major, cudnn_minor = None, None
try:
include_file = open(cudnn_h)
except OSError:
raise CudaNotFound('Could not read cudnn.h')
with include_file:
for line in include_file:
if 'CUDNN_MAJOR' in line:
cudnn_major = line.split()[-1]
if 'CUDNN_MINOR' in line:
cudnn_minor = line.split()[-1]
if cudnn_major and cudnn_minor:
break
cudnn_version = cudnn_major + (cudnn_minor or '0')
except:
pass
return (normalize_cuda_version(cuda_version or 0),
normalize_cuda_version(cudnn_version or 0))

View File

@@ -0,0 +1,90 @@
from typing import Text
from furl import furl
from pathlib2 import Path
from trains_agent.config import Config
from .pip_api.system import SystemPip
class RequirementsTranslator(object):
"""
Translate explicit URLs to local URLs after downloading them to cache
"""
SUPPORTED_SCHEMES = ["http", "https", "ftp"]
def __init__(self, session, interpreter=None, cache_dir=None):
self._session = session
config = session.config
self.cache_dir = cache_dir or Path(config["agent.pip_download_cache.path"]).expanduser().as_posix()
self.enabled = config["agent.pip_download_cache.enabled"]
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
self.config = Config()
self.pip = SystemPip(interpreter=interpreter, session=self._session)
self._translate_back = {}
def download(self, url):
self.pip.download_package(url, cache_dir=self.cache_dir)
@classmethod
def is_supported_link(cls, line):
# type: (Text) -> bool
"""
Return whether requirement is a link that should be downloaded to cache
"""
url = furl(line)
return (
url.scheme
and url.scheme.lower() in cls.SUPPORTED_SCHEMES
and line.lstrip().lower().startswith(url.scheme.lower())
)
def translate(self, line):
"""
If requirement is supported, download it to cache and return the download path
"""
if not (self.enabled and self.is_supported_link(line)):
return line
command = self.config.command
command.log('Downloading "{}" to pip cache'.format(line))
url = furl(line)
try:
wheel_name = url.path.segments[-1]
except IndexError:
command.error('Could not parse wheel name of "{}"'.format(line))
return line
try:
self.download(line)
downloaded = Path(self.cache_dir, wheel_name).expanduser().as_uri()
except Exception:
command.error('Could not download wheel name of "{}"'.format(line))
return line
self._translate_back[str(downloaded)] = line
return downloaded
def replace_back(self, requirements):
if not requirements:
return requirements
for k in requirements:
# k is either pip/conda
if k not in ('pip', 'conda'):
continue
original_requirements = requirements[k]
new_requirements = []
for line in original_requirements:
local_file = [d for d in self._translate_back.keys() if d in line]
if local_file:
local_file = local_file[0]
new_requirements.append(line.replace(local_file, self._translate_back[local_file]))
else:
new_requirements.append(line)
requirements[k] = new_requirements
return requirements

View File

@@ -0,0 +1,115 @@
from typing import Optional, Text
import requests
from pathlib2 import Path
import six
from trains_agent.definitions import CONFIG_DIR
from trains_agent.helper.process import Argv, DEVNULL
from .pip_api.venv import VirtualenvPip
class VenvUpdateAPI(VirtualenvPip):
URL_FILE_PATH = Path(CONFIG_DIR, "venv-update-url.txt")
SCRIPT_PATH = Path(CONFIG_DIR, "venv-update")
def __init__(self, url, *args, **kwargs):
super(VenvUpdateAPI, self).__init__(*args, **kwargs)
self.url = url
self._script_path = None
self._first_install = True
@property
def downloaded_venv_url(self):
# type: () -> Optional[Text]
try:
return self.URL_FILE_PATH.read_text()
except OSError:
return None
@downloaded_venv_url.setter
def downloaded_venv_url(self, value):
self.URL_FILE_PATH.write_text(value)
def _check_script_validity(self, path):
"""
Make sure script in ``path`` is a valid python script
:param path:
:return:
"""
result = Argv(self.bin, path, "--version").call(
stdout=DEVNULL, stderr=DEVNULL, stdin=DEVNULL
)
return result == 0
@property
def script_path(self):
# type: () -> Text
if not self._script_path:
self._script_path = self.SCRIPT_PATH
if not (
self._script_path.exists()
and self.downloaded_venv_url
and self.downloaded_venv_url == self.url
and self._check_script_validity(self._script_path)
):
with self._script_path.open("wb") as f:
for data in requests.get(self.url, stream=True):
f.write(data)
self.downloaded_venv_url = self.url
return self._script_path
def install_from_file(self, path):
first_install = (
Argv(
self.python,
six.text_type(self.script_path),
"venv=",
"-p",
self.python,
self.path,
)
+ self.create_flags()
+ ("install=", "-r", path)
+ self.install_flags()
)
later_install = first_install + (
"pip-command=",
"pip-faster",
"install",
"--upgrade", # no --prune
)
self._choose_install(first_install, later_install)
def install_packages(self, *packages):
first_install = (
Argv(
self.python,
six.text_type(self.script_path),
"venv=",
self.path,
"install=",
)
+ packages
)
later_install = first_install + (
"pip-command=",
"pip-faster",
"install",
"--upgrade", # no --prune
)
self._choose_install(first_install, later_install)
def _choose_install(self, first, rest):
if self._first_install:
command = first
self._first_install = False
else:
command = rest
command.check_call(stdin=DEVNULL)
def upgrade_pip(self):
"""
pip and venv-update versions are coupled, venv-update installs the latest compatible pip
"""
pass