mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Rename trains-agent -> clearml-agent
This commit is contained in:
0
clearml_agent/helper/package/__init__.py
Normal file
0
clearml_agent/helper/package/__init__.py
Normal file
155
clearml_agent/helper/package/base.py
Normal file
155
clearml_agent/helper/package/base.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import abc
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Iterable, Union
|
||||
|
||||
import six
|
||||
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines, select_for_platform
|
||||
from trains_agent.helper.process import Executable, Argv, PathLike
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class PackageManager(object):
|
||||
"""
|
||||
ABC for classes providing python package management interface
|
||||
"""
|
||||
|
||||
_selected_manager = None
|
||||
_cwd = None
|
||||
_pip_version = None
|
||||
|
||||
@abc.abstractproperty
|
||||
def bin(self):
|
||||
# type: () -> PathLike
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def create(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def remove(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def install_from_file(self, path):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def freeze(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_requirements(self, requirements):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def install_packages(self, *packages):
|
||||
# type: (Iterable[Text]) -> None
|
||||
"""
|
||||
Install packages, upgrading depends on config
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _install(self, *packages):
|
||||
# type: (Iterable[Text]) -> None
|
||||
"""
|
||||
Run install command
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def uninstall_packages(self, *packages):
|
||||
# type: (Iterable[Text]) -> None
|
||||
pass
|
||||
|
||||
def upgrade_pip(self):
|
||||
result = self._install(
|
||||
select_for_platform(windows='"pip{}"', linux='pip{}').format(self.get_pip_version()), "--upgrade")
|
||||
packages = self.run_with_env(('list',), output=True).splitlines()
|
||||
# p.split is ('pip', 'x.y.z')
|
||||
pip = [p.split() for p in packages if len(p.split()) == 2 and p.split()[0] == 'pip']
|
||||
if pip:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from .requirements import MarkerRequirement
|
||||
pip = pip[0][1].split('.')
|
||||
MarkerRequirement.pip_new_version = bool(int(pip[0]) >= 20)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
def get_python_command(self, extra=()):
|
||||
# type: (...) -> Executable
|
||||
return Argv(self.bin, *extra)
|
||||
|
||||
@contextmanager
|
||||
def temp_file(self, prefix, contents, suffix=".txt"):
|
||||
# type: (Union[Text, Iterable[Text]], Iterable[Text], Text) -> Text
|
||||
"""
|
||||
Write contents to a temporary file, yielding its path. Finally, delete it.
|
||||
:param prefix: file name prefix
|
||||
:param contents: text lines to write
|
||||
:param suffix: file name suffix
|
||||
"""
|
||||
f, temp_path = mkstemp(suffix=suffix, prefix=prefix)
|
||||
with f:
|
||||
f.write(
|
||||
contents
|
||||
if isinstance(contents, six.text_type)
|
||||
else join_lines(contents)
|
||||
)
|
||||
try:
|
||||
yield temp_path
|
||||
finally:
|
||||
if not self.session.debug_mode:
|
||||
safe_remove_file(temp_path)
|
||||
|
||||
def set_selected_package_manager(self):
|
||||
# set this instance as the selected package manager
|
||||
# this is helpful when we want out of context requirement installations
|
||||
PackageManager._selected_manager = self
|
||||
|
||||
@property
|
||||
def cwd(self):
|
||||
return self._cwd
|
||||
|
||||
@cwd.setter
|
||||
def cwd(self, value):
|
||||
self._cwd = value
|
||||
|
||||
@classmethod
|
||||
def out_of_scope_install_package(cls, package_name, *args):
|
||||
if PackageManager._selected_manager is not None:
|
||||
try:
|
||||
result = PackageManager._selected_manager._install(package_name, *args)
|
||||
if result not in (0, None, True):
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def out_of_scope_freeze(cls):
|
||||
if PackageManager._selected_manager is not None:
|
||||
try:
|
||||
return PackageManager._selected_manager.freeze()
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def set_pip_version(cls, version):
|
||||
if not version:
|
||||
return
|
||||
version = version.replace(' ', '')
|
||||
if ('=' in version) or ('~' in version) or ('<' in version) or ('>' in version):
|
||||
cls._pip_version = version
|
||||
else:
|
||||
cls._pip_version = "=="+version
|
||||
|
||||
@classmethod
|
||||
def get_pip_version(cls):
|
||||
return cls._pip_version or ''
|
||||
722
clearml_agent/helper/package/conda_api.py
Normal file
722
clearml_agent/helper/package/conda_api.py
Normal file
@@ -0,0 +1,722 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import subprocess
|
||||
from collections import OrderedDict
|
||||
from distutils.spawn import find_executable
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
|
||||
|
||||
import six
|
||||
import yaml
|
||||
from time import time
|
||||
from attr import attrs, attrib, Factory
|
||||
from pathlib2 import Path
|
||||
from trains_agent.external.requirements_parser import parse
|
||||
from trains_agent.external.requirements_parser.requirement import Requirement
|
||||
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform, ExecutionInfo
|
||||
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
|
||||
from trains_agent.helper.package.requirements import SimpleVersion
|
||||
from trains_agent.session import Session
|
||||
from .base import PackageManager
|
||||
from .pip_api.venv import VirtualenvPip
|
||||
from .requirements import RequirementsManager, MarkerRequirement
|
||||
from ...backend_api.session.defs import ENV_CONDA_ENV_PACKAGE
|
||||
|
||||
package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")
|
||||
|
||||
|
||||
def package_set(packages):
|
||||
return set(map(package_normalize, packages))
|
||||
|
||||
|
||||
def _package_diff(path, packages):
|
||||
# type: (Union[Path, Text], Iterable[Text]) -> Set[Text]
|
||||
return package_set(Path(path).read_text().splitlines()) - package_set(packages)
|
||||
|
||||
|
||||
class CondaPip(VirtualenvPip):
|
||||
def __init__(self, source=None, *args, **kwargs):
|
||||
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe")
|
||||
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
|
||||
self.source = source
|
||||
|
||||
def run_with_env(self, command, output=False, **kwargs):
|
||||
if not self.source:
|
||||
return super(CondaPip, self).run_with_env(command, output=output, **kwargs)
|
||||
command = CommandSequence(self.source, Argv("pip", *command))
|
||||
return (command.get_output if output else command.check_call)(
|
||||
stdin=DEVNULL, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class CondaAPI(PackageManager):
|
||||
|
||||
"""
|
||||
A programmatic interface for controlling conda
|
||||
"""
|
||||
|
||||
MINIMUM_VERSION = "4.3.30"
|
||||
|
||||
def __init__(self, session, path, python, requirements_manager, execution_info=None, **kwargs):
|
||||
# type: (Session, PathLike, float, RequirementsManager, ExecutionInfo, Any) -> None
|
||||
"""
|
||||
:param python: base python version to use (e.g python3.6)
|
||||
:param path: path of env
|
||||
"""
|
||||
self.session = session
|
||||
self.python = python
|
||||
self.source = None
|
||||
self.requirements_manager = requirements_manager
|
||||
self.path = path
|
||||
self.env_read_only = False
|
||||
self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
|
||||
self.conda_env_as_base_docker = \
|
||||
self.session.config.get('agent.package_manager.conda_env_as_base_docker', None) or \
|
||||
bool(ENV_CONDA_ENV_PACKAGE.get())
|
||||
if ENV_CONDA_ENV_PACKAGE.get():
|
||||
self.conda_pre_build_env_path = ENV_CONDA_ENV_PACKAGE.get()
|
||||
else:
|
||||
self.conda_pre_build_env_path = execution_info.docker_cmd if execution_info else None
|
||||
self.pip = CondaPip(
|
||||
session=self.session,
|
||||
source=self.source,
|
||||
python=self.python,
|
||||
requirements_manager=self.requirements_manager,
|
||||
path=self.path,
|
||||
)
|
||||
try:
|
||||
self.conda = (
|
||||
find_executable("conda") or
|
||||
Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(
|
||||
shell=select_for_platform(windows=True, linux=False)).strip()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("ERROR: package manager \"conda\" selected, "
|
||||
"but \'conda\' executable could not be located")
|
||||
try:
|
||||
output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
|
||||
except subprocess.CalledProcessError as ex:
|
||||
raise CommandFailedError(
|
||||
"Unable to determine conda version: {ex}, output={ex.output}".format(
|
||||
ex=ex
|
||||
)
|
||||
)
|
||||
self.conda_version = self.get_conda_version(output)
|
||||
if SimpleVersion.compare_versions(self.conda_version, '<', self.MINIMUM_VERSION):
|
||||
raise CommandFailedError(
|
||||
"conda version '{}' is smaller than minimum supported conda version '{}'".format(
|
||||
self.conda_version, self.MINIMUM_VERSION
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_conda_version(output):
|
||||
match = re.search(r"(\d+\.){0,2}\d+", output)
|
||||
if not match:
|
||||
raise CommandFailedError("Unidentified conda version string:", output)
|
||||
return match.group(0)
|
||||
|
||||
@property
|
||||
def bin(self):
|
||||
return self.pip.bin
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
def upgrade_pip(self):
|
||||
# do not change pip version if pre built environement is used
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping pip upgrade.')
|
||||
return ''
|
||||
return self._install(select_for_platform(windows='"pip{}"', linux='pip{}').format(self.pip.get_pip_version()))
|
||||
|
||||
def create(self):
|
||||
"""
|
||||
Create a new environment
|
||||
"""
|
||||
if self.conda_env_as_base_docker and self.conda_pre_build_env_path:
|
||||
if Path(self.conda_pre_build_env_path).is_dir():
|
||||
print("Using pre-existing Conda environment from {}".format(self.conda_pre_build_env_path))
|
||||
self.path = Path(self.conda_pre_build_env_path)
|
||||
self.source = ("conda", "activate", self.path.as_posix())
|
||||
self.pip = CondaPip(
|
||||
session=self.session,
|
||||
source=self.source,
|
||||
python=self.python,
|
||||
requirements_manager=self.requirements_manager,
|
||||
path=self.path,
|
||||
)
|
||||
conda_env = self._get_conda_sh()
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
self.env_read_only = True
|
||||
return self
|
||||
elif Path(self.conda_pre_build_env_path).is_file():
|
||||
print("Restoring Conda environment from {}".format(self.conda_pre_build_env_path))
|
||||
tar_path = find_executable("tar")
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
output = Argv(
|
||||
tar_path,
|
||||
"-xzf",
|
||||
self.conda_pre_build_env_path,
|
||||
"-C",
|
||||
self.path,
|
||||
).get_output()
|
||||
|
||||
self.source = self.pip.source = ("conda", "activate", self.path.as_posix())
|
||||
conda_env = self._get_conda_sh()
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
# unpack cleanup
|
||||
print("Fixing prefix in Conda environment {}".format(self.path))
|
||||
CommandSequence(('source', conda_env.as_posix()),
|
||||
((self.path / 'bin' / 'conda-unpack').as_posix(), )).get_output()
|
||||
return self
|
||||
else:
|
||||
raise ValueError("Could not restore Conda environment, cannot find {}".format(
|
||||
self.conda_pre_build_env_path))
|
||||
|
||||
output = Argv(
|
||||
self.conda,
|
||||
"create",
|
||||
"--yes",
|
||||
"--mkdir",
|
||||
"--prefix",
|
||||
self.path,
|
||||
"python={}".format(self.python),
|
||||
).get_output(stderr=DEVNULL)
|
||||
match = re.search(
|
||||
r"\W*(.*activate) ({})".format(re.escape(str(self.path))), output
|
||||
)
|
||||
self.source = self.pip.source = (
|
||||
tuple(match.group(1).split()) + (match.group(2),)
|
||||
if match
|
||||
else ("conda", "activate", self.path.as_posix())
|
||||
)
|
||||
|
||||
conda_env = self._get_conda_sh()
|
||||
if conda_env.is_file() and not is_windows_platform():
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
|
||||
# install cuda toolkit
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
|
||||
if cuda_version > 0:
|
||||
self._install('cudatoolkit={:.1f}'.format(cuda_version))
|
||||
except Exception:
|
||||
pass
|
||||
return self
|
||||
|
||||
def remove(self):
|
||||
"""
|
||||
Delete a conda environment.
|
||||
Use 'conda env remove', then 'rm_tree' to be safe.
|
||||
|
||||
Conda seems to load "vcruntime140.dll" from all its environment on startup.
|
||||
This means environment have to be deleted using 'conda env remove'.
|
||||
If necessary, conda can be fooled into deleting a partially-deleted environment by creating an empty file
|
||||
in '<ENV>\conda-meta\history' (value found in 'conda.gateways.disk.test.PREFIX_MAGIC_FILE').
|
||||
Otherwise, it complains that said directory is not a conda environment.
|
||||
|
||||
See: https://github.com/conda/conda/issues/7682
|
||||
"""
|
||||
try:
|
||||
self._run_command(("env", "remove", "-p", self.path))
|
||||
except Exception:
|
||||
pass
|
||||
rm_tree(self.path)
|
||||
# if we failed removing the path, change it's name
|
||||
if is_windows_platform() and Path(self.path).exists():
|
||||
try:
|
||||
Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time()))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _install_from_file(self, path):
|
||||
"""
|
||||
Install packages from requirement file.
|
||||
"""
|
||||
self._install("--file", path)
|
||||
|
||||
def _install(self, *args):
|
||||
# type: (*PathLike) -> ()
|
||||
# if we are in read only mode, do not install anything
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping package installing: {}'.format(args))
|
||||
return
|
||||
channels_args = tuple(
|
||||
chain.from_iterable(("-c", channel) for channel in self.extra_channels)
|
||||
)
|
||||
self._run_command(("install", "-p", self.path) + channels_args + args)
|
||||
|
||||
def _get_pip_packages(self, packages):
|
||||
# type: (Iterable[Text]) -> Sequence[Text]
|
||||
"""
|
||||
Return subset of ``packages`` which are not available on conda
|
||||
"""
|
||||
pips = []
|
||||
while True:
|
||||
with self.temp_file("conda_reqs", packages) as path:
|
||||
try:
|
||||
self._install_from_file(path)
|
||||
except PackageNotFoundError as e:
|
||||
pips.append(e.pkg)
|
||||
packages = _package_diff(path, {e.pkg})
|
||||
else:
|
||||
break
|
||||
return pips
|
||||
|
||||
def install_packages(self, *packages):
|
||||
# type: (*Text) -> ()
|
||||
return self._install(*packages)
|
||||
|
||||
def uninstall_packages(self, *packages):
|
||||
# if we are in read only mode, do not uninstall anything
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping package uninstalling: {}'.format(packages))
|
||||
return ''
|
||||
return self._run_command(("uninstall", "-p", self.path))
|
||||
|
||||
def install_from_file(self, path):
|
||||
"""
|
||||
Try to install packages from conda. Install packages which are not available from conda with pip.
|
||||
"""
|
||||
try:
|
||||
self._install_from_file(path)
|
||||
return
|
||||
except PackageNotFoundError as e:
|
||||
pip_packages = [e.pkg]
|
||||
except PackagesNotFoundError as e:
|
||||
pip_packages = package_set(e.packages)
|
||||
with self.temp_file("conda_reqs", _package_diff(path, pip_packages)) as reqs:
|
||||
self.install_from_file(reqs)
|
||||
with self.temp_file("pip_reqs", pip_packages) as reqs:
|
||||
self.pip.install_from_file(reqs)
|
||||
|
||||
def freeze(self, freeze_full_environment=False):
|
||||
requirements = self.pip.freeze()
|
||||
req_lines = []
|
||||
conda_lines = []
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
pip_lines = requirements['pip']
|
||||
conda_packages_json = json.loads(
|
||||
self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
|
||||
for r in conda_packages_json:
|
||||
# check if this is a pypi package, if it is, leave it outside
|
||||
if not r.get('channel') or r.get('channel') == 'pypi':
|
||||
name = (r['name'].replace('-', '_'), r['name'])
|
||||
pip_req_line = [l for l in pip_lines
|
||||
if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
|
||||
if pip_req_line and \
|
||||
('@' not in pip_req_line[0] or
|
||||
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
|
||||
req_lines.append(pip_req_line[0])
|
||||
continue
|
||||
|
||||
req_lines.append(
|
||||
'{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
|
||||
continue
|
||||
|
||||
# check if we have it in our required packages
|
||||
name = r['name']
|
||||
# hack support pytorch/torch different naming convention
|
||||
if name == 'pytorch':
|
||||
name = 'torch'
|
||||
# skip over packages with _
|
||||
if name.startswith('_'):
|
||||
continue
|
||||
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
|
||||
# make sure we see the conda packages, put them into the pip as well
|
||||
if conda_lines:
|
||||
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
|
||||
|
||||
requirements['pip'] = req_lines
|
||||
requirements['conda'] = conda_lines
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if freeze_full_environment:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
conda_env_json = json.loads(
|
||||
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
|
||||
conda_env_json.pop('name', None)
|
||||
conda_env_json.pop('prefix', None)
|
||||
conda_env_json.pop('channels', None)
|
||||
requirements['conda_env_json'] = json.dumps(conda_env_json)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return requirements
|
||||
|
||||
def _load_conda_full_env(self, conda_env_dict, requirements):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
|
||||
except Exception:
|
||||
cuda_version = 0
|
||||
|
||||
conda_env_dict['channels'] = self.extra_channels
|
||||
if 'dependencies' not in conda_env_dict:
|
||||
conda_env_dict['dependencies'] = []
|
||||
new_dependencies = OrderedDict()
|
||||
pip_requirements = None
|
||||
for line in conda_env_dict['dependencies']:
|
||||
if isinstance(line, dict):
|
||||
pip_requirements = line.pop('pip', None)
|
||||
continue
|
||||
name = line.strip().split('=', 1)[0].lower()
|
||||
if name == 'pip':
|
||||
continue
|
||||
elif name == 'python':
|
||||
line = 'python={}'.format('.'.join(line.split('=')[1].split('.')[:2]))
|
||||
elif name == 'tensorflow-gpu' and cuda_version == 0:
|
||||
line = 'tensorflow={}'.format(line.split('=')[1])
|
||||
elif name == 'tensorflow' and cuda_version > 0:
|
||||
line = 'tensorflow-gpu={}'.format(line.split('=')[1])
|
||||
elif name in ('cupti', 'cudnn'):
|
||||
# cudatoolkit should pull them based on the cudatoolkit version
|
||||
continue
|
||||
elif name.startswith('_'):
|
||||
continue
|
||||
new_dependencies[line.split('=', 1)[0].strip()] = line
|
||||
|
||||
# fix packages:
|
||||
conda_env_dict['dependencies'] = list(new_dependencies.values())
|
||||
|
||||
with self.temp_file("conda_env", yaml.dump(conda_env_dict), suffix=".yml") as name:
|
||||
print('Conda: Trying to install requirements:\n{}'.format(conda_env_dict['dependencies']))
|
||||
result = self._run_command(
|
||||
("env", "update", "-p", self.path, "--file", name)
|
||||
)
|
||||
|
||||
# check if we need to remove specific packages
|
||||
bad_req = self._parse_conda_result_bad_packges(result)
|
||||
if bad_req:
|
||||
print('failed installing the following conda packages: {}'.format(bad_req))
|
||||
return False
|
||||
|
||||
if pip_requirements:
|
||||
# create a list of vcs packages that we need to replace in the pip section
|
||||
vcs_reqs = {}
|
||||
if 'pip' in requirements:
|
||||
pip_lines = requirements['pip'].splitlines() \
|
||||
if isinstance(requirements['pip'], six.string_types) else requirements['pip']
|
||||
for line in pip_lines:
|
||||
try:
|
||||
marker = list(parse(line))
|
||||
except Exception:
|
||||
marker = None
|
||||
if not marker:
|
||||
continue
|
||||
|
||||
m = MarkerRequirement(marker[0])
|
||||
if m.vcs:
|
||||
vcs_reqs[m.name] = m
|
||||
try:
|
||||
pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
|
||||
if not r.startswith('pip=') and not r.startswith('virtualenv=')]
|
||||
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
|
||||
PackageManager._selected_manager = self.pip
|
||||
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
finally:
|
||||
PackageManager._selected_manager = self
|
||||
|
||||
self.requirements_manager.post_install(self.session)
|
||||
|
||||
def load_requirements(self, requirements):
|
||||
# if we are in read only mode, do not uninstall anything
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping requirements installation.')
|
||||
return None
|
||||
|
||||
# if we have a full conda environment, use it and pass the pip to pip
|
||||
if requirements.get('conda_env_json'):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
conda_env_json = json.loads(requirements.get('conda_env_json'))
|
||||
print('Conda restoring full yaml environment')
|
||||
return self._load_conda_full_env(conda_env_json, requirements)
|
||||
except Exception:
|
||||
print('Could not load fully stored conda environment, falling back to requirements')
|
||||
|
||||
# create new environment file
|
||||
conda_env = dict()
|
||||
conda_env['channels'] = self.extra_channels
|
||||
reqs = []
|
||||
if isinstance(requirements['pip'], six.string_types):
|
||||
requirements['pip'] = requirements['pip'].split('\n')
|
||||
if isinstance(requirements.get('conda'), six.string_types):
|
||||
requirements['conda'] = requirements['conda'].split('\n')
|
||||
has_torch = False
|
||||
has_matplotlib = False
|
||||
try:
|
||||
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
|
||||
except:
|
||||
cuda_version = 0
|
||||
|
||||
# notice 'conda' entry with empty string is a valid conda requirements list, it means pip only
|
||||
# this should happen if experiment was executed on non-conda machine or old trains client
|
||||
conda_supported_req = requirements['pip'] if requirements.get('conda', None) is None else requirements['conda']
|
||||
conda_supported_req_names = []
|
||||
pip_requirements = []
|
||||
for r in conda_supported_req:
|
||||
try:
|
||||
marker = list(parse(r))
|
||||
except:
|
||||
marker = None
|
||||
if not marker:
|
||||
continue
|
||||
|
||||
m = MarkerRequirement(marker[0])
|
||||
# conda does not support version control links
|
||||
if m.vcs:
|
||||
pip_requirements.append(m)
|
||||
continue
|
||||
# Skip over pip
|
||||
if m.name in ('pip', 'virtualenv', ):
|
||||
continue
|
||||
# python version, only major.minor
|
||||
if m.name == 'python' and m.specs:
|
||||
m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
|
||||
if '.' not in m.specs[0][1]:
|
||||
continue
|
||||
|
||||
conda_supported_req_names.append(m.name.lower())
|
||||
if m.req.name.lower() == 'matplotlib':
|
||||
has_matplotlib = True
|
||||
elif m.req.name.lower().startswith('torch'):
|
||||
has_torch = True
|
||||
|
||||
if m.req.name.lower() in ('torch', 'pytorch'):
|
||||
has_torch = True
|
||||
m.req.name = 'pytorch'
|
||||
|
||||
if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'):
|
||||
has_torch = True
|
||||
m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'
|
||||
|
||||
reqs.append(m)
|
||||
|
||||
# if we have a conda list, the rest should be installed with pip,
|
||||
if requirements.get('conda', None) is not None:
|
||||
for r in requirements['pip']:
|
||||
try:
|
||||
marker = list(parse(r))
|
||||
except:
|
||||
marker = None
|
||||
if not marker:
|
||||
continue
|
||||
|
||||
m = MarkerRequirement(marker[0])
|
||||
# skip over local files (we cannot change the version to a local file)
|
||||
if m.local_file:
|
||||
continue
|
||||
m_name = m.name.lower()
|
||||
if m_name in conda_supported_req_names:
|
||||
# this package is in the conda list,
|
||||
# make sure that if we changed version and we match it in conda
|
||||
## conda_supported_req_names.remove(m_name)
|
||||
for cr in reqs:
|
||||
if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
|
||||
# match versions
|
||||
cr.specs = m.specs
|
||||
# # conda always likes "-" not "_" but only on pypi packages
|
||||
# cr.name = cr.name.lower().replace('_', '-')
|
||||
break
|
||||
else:
|
||||
# not in conda, it is a pip package
|
||||
pip_requirements.append(m)
|
||||
if m_name == 'matplotlib':
|
||||
has_matplotlib = True
|
||||
|
||||
# Conda requirements Hacks:
|
||||
if has_matplotlib:
|
||||
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
|
||||
reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
|
||||
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
|
||||
|
||||
# remove specific cudatoolkit, it should have being preinstalled.
|
||||
# allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them
|
||||
reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]
|
||||
|
||||
if has_torch and cuda_version == 0:
|
||||
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
|
||||
|
||||
# make sure we have no double entries
|
||||
reqs = list(OrderedDict((r.name, r) for r in reqs).values())
|
||||
|
||||
# conform conda packages (version/name)
|
||||
for r in reqs:
|
||||
# change _ to - in name but not the prefix _ (as this is conda prefix)
|
||||
if not r.name.startswith('_') and not requirements.get('conda', None):
|
||||
r.name = r.name.replace('_', '-')
|
||||
# remove .post from version numbers, it fails ~= version, and change == to ~=
|
||||
if r.specs and r.specs[0]:
|
||||
r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
|
||||
|
||||
while reqs:
|
||||
# notice, we give conda more freedom in version selection, to help it choose best combination
|
||||
def clean_ver(ar):
|
||||
if not ar.specs:
|
||||
return ar.tostr()
|
||||
ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
|
||||
return ar.tostr()
|
||||
conda_env['dependencies'] = [clean_ver(r) for r in reqs]
|
||||
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
|
||||
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
|
||||
result = self._run_command(
|
||||
("env", "update", "-p", self.path, "--file", name)
|
||||
)
|
||||
# check if we need to remove specific packages
|
||||
bad_req = self._parse_conda_result_bad_packges(result)
|
||||
if not bad_req:
|
||||
break
|
||||
|
||||
solved = False
|
||||
for bad_r in bad_req:
|
||||
name = bad_r.split('[')[0].split('=')[0].split('~')[0].split('<')[0].split('>')[0]
|
||||
# look for name in requirements
|
||||
for r in reqs:
|
||||
if r.name.lower() == name.lower():
|
||||
pip_requirements.append(r)
|
||||
reqs.remove(r)
|
||||
solved = True
|
||||
break
|
||||
|
||||
# we couldn't remove even one package,
|
||||
# nothing we can do but try pip
|
||||
if not solved:
|
||||
pip_requirements.extend(reqs)
|
||||
break
|
||||
|
||||
if pip_requirements:
|
||||
try:
|
||||
pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
|
||||
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
|
||||
PackageManager._selected_manager = self.pip
|
||||
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
finally:
|
||||
PackageManager._selected_manager = self
|
||||
|
||||
self.requirements_manager.post_install(self.session)
|
||||
return True
|
||||
|
||||
def _parse_conda_result_bad_packges(self, result_dict):
|
||||
if not result_dict:
|
||||
return None
|
||||
|
||||
if 'bad_deps' in result_dict and result_dict['bad_deps']:
|
||||
return result_dict['bad_deps']
|
||||
|
||||
if result_dict.get('error'):
|
||||
error_lines = result_dict['error'].split('\n')
|
||||
if error_lines[0].strip().lower().startswith("unsatisfiableerror:"):
|
||||
empty_lines = [i for i, l in enumerate(error_lines) if not l.strip()]
|
||||
if len(empty_lines) >= 2:
|
||||
deps = error_lines[empty_lines[0]+1:empty_lines[1]]
|
||||
try:
|
||||
return yaml.load('\n'.join(deps), Loader=yaml.SafeLoader)
|
||||
except:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _run_command(self, command, raw=False, **kwargs):
|
||||
# type: (Iterable[Text], bool, Any) -> Union[Dict, Text]
|
||||
"""
|
||||
Run a conda command, returning JSON output.
|
||||
The command is prepended with 'conda' and run with JSON output flags.
|
||||
:param command: command to run
|
||||
:param raw: return text output and don't change command
|
||||
:param kwargs: kwargs for Argv.get_output()
|
||||
:return: JSON output or text output
|
||||
"""
|
||||
def escape_ansi(line):
|
||||
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
|
||||
return ansi_escape.sub('', line)
|
||||
|
||||
command = Argv(*command) # type: Executable
|
||||
if not raw:
|
||||
command = (self.conda,) + command + ("--quiet", "--json")
|
||||
try:
|
||||
print('Executing Conda: {}'.format(command.serialize()))
|
||||
result = command.get_output(stdin=DEVNULL, **kwargs)
|
||||
if self.session.debug_mode:
|
||||
print(result)
|
||||
except Exception as e:
|
||||
result = e.output if hasattr(e, 'output') else ''
|
||||
if self.session.debug_mode:
|
||||
print(result)
|
||||
if raw:
|
||||
raise
|
||||
if raw:
|
||||
return result
|
||||
|
||||
result = json.loads(escape_ansi(result)) if result else {}
|
||||
if result.get('success', False):
|
||||
print('Pass')
|
||||
elif result.get('error'):
|
||||
print('Conda error: {}'.format(result.get('error')))
|
||||
return result
|
||||
|
||||
def get_python_command(self, extra=()):
|
||||
return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
|
||||
|
||||
def _get_conda_sh(self):
|
||||
# type () -> Path
|
||||
base_conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
|
||||
if base_conda_env.is_file():
|
||||
return base_conda_env
|
||||
for path in os.environ.get('PATH', '').split(select_for_platform(windows=';', linux=':')):
|
||||
conda = find_executable("conda", path=path)
|
||||
if not conda:
|
||||
continue
|
||||
conda_env = Path(conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
|
||||
if conda_env.is_file():
|
||||
return conda_env
|
||||
return base_conda_env
|
||||
|
||||
|
||||
# enable hashing with cmp=False because pdb fails on un-hashable exceptions
|
||||
exception = attrs(str=True, cmp=False)
|
||||
|
||||
|
||||
@exception
|
||||
class CondaException(Exception, NonStrictAttrs):
|
||||
command = attrib()
|
||||
message = attrib(default=None)
|
||||
|
||||
|
||||
@exception
|
||||
class UnknownCondaError(CondaException):
|
||||
data = attrib(default=Factory(dict))
|
||||
|
||||
|
||||
@exception
|
||||
class PackagesNotFoundError(CondaException):
|
||||
"""
|
||||
Conda 4.5 exception - this reports all missing packages.
|
||||
"""
|
||||
|
||||
packages = attrib(default=())
|
||||
|
||||
|
||||
@exception
|
||||
class PackageNotFoundError(CondaException):
|
||||
"""
|
||||
Conda 4.3 exception - this reports one missing package at a time,
|
||||
as a singleton YAML list.
|
||||
"""
|
||||
|
||||
pkg = attrib(default="", converter=lambda val: yaml.load(val, Loader=yaml.SafeLoader)[0].replace(" ", ""))
|
||||
106
clearml_agent/helper/package/external_req.py
Normal file
106
clearml_agent/helper/package/external_req.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import Text
|
||||
|
||||
from .base import PackageManager
|
||||
from .requirements import SimpleSubstitution
|
||||
from ..base import safe_furl as furl
|
||||
|
||||
|
||||
class ExternalRequirements(SimpleSubstitution):
|
||||
|
||||
name = "external_link"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ExternalRequirements, self).__init__(*args, **kwargs)
|
||||
self.post_install_req = []
|
||||
self.post_install_req_lookup = OrderedDict()
|
||||
|
||||
def match(self, req):
|
||||
# match both editable or code or unparsed
|
||||
if not (not req.name or req.req and (req.req.editable or req.req.vcs)):
|
||||
return False
|
||||
if not req.req or not req.req.line or not req.req.line.strip() or req.req.line.strip().startswith('#'):
|
||||
return False
|
||||
if req.pip_new_version and not (req.req.editable or req.req.vcs):
|
||||
return False
|
||||
return True
|
||||
|
||||
def post_install(self, session):
|
||||
post_install_req = self.post_install_req
|
||||
self.post_install_req = []
|
||||
for req in post_install_req:
|
||||
try:
|
||||
freeze_base = PackageManager.out_of_scope_freeze() or ''
|
||||
except:
|
||||
freeze_base = ''
|
||||
|
||||
req_line = req.tostr(markers=False)
|
||||
if req_line.strip().startswith('-e ') or req_line.strip().startswith('--editable'):
|
||||
req_line = re.sub(r'^(-e|--editable=?)\s*', '', req_line, count=1)
|
||||
|
||||
if req.req.vcs and req_line.startswith('git+'):
|
||||
try:
|
||||
url_no_frag = furl(req_line)
|
||||
url_no_frag.set(fragment=None)
|
||||
# reverse replace
|
||||
fragment = req_line[::-1].replace(url_no_frag.url[::-1], '', 1)[::-1]
|
||||
vcs_url = req_line[4:]
|
||||
# reverse replace
|
||||
vcs_url = vcs_url[::-1].replace(fragment[::-1], '', 1)[::-1]
|
||||
from ..repo import Git
|
||||
vcs = Git(session=session, url=vcs_url, location=None, revision=None)
|
||||
vcs._set_ssh_url()
|
||||
new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
|
||||
if new_req_line != req_line:
|
||||
furl_line = furl(new_req_line)
|
||||
print('Replacing original pip vcs \'{}\' with \'{}\''.format(
|
||||
req_line,
|
||||
furl_line.set(password='xxxxxx').tostr() if furl_line.password else new_req_line))
|
||||
req_line = new_req_line
|
||||
except Exception:
|
||||
print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))
|
||||
|
||||
# if we have older pip version we have to make sure we replace back the package name with the
|
||||
# git repository link. In new versions this is supported and we get "package @ git+https://..."
|
||||
if not req.pip_new_version:
|
||||
PackageManager.out_of_scope_install_package(req_line, "--no-deps")
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
freeze_post = PackageManager.out_of_scope_freeze() or ''
|
||||
package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
|
||||
if package_name and package_name[0] not in self.post_install_req_lookup:
|
||||
self.post_install_req_lookup[package_name[0]] = req.req.line
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# no need to force reinstall, pip will always rebuilt if the package comes from git
|
||||
# and make sure the required packages are installed (if they are not it will install them)
|
||||
if not PackageManager.out_of_scope_install_package(req_line):
|
||||
raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))
|
||||
|
||||
def replace(self, req):
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
# Store in post req install, and return nothing
|
||||
self.post_install_req.append(req)
|
||||
# mark skip package, we will install it in post install hook
|
||||
return Text('')
|
||||
|
||||
def replace_back(self, list_of_requirements):
|
||||
if not list_of_requirements:
|
||||
return list_of_requirements
|
||||
|
||||
for k in list_of_requirements:
|
||||
# k is either pip/conda
|
||||
if k not in ('pip', 'conda'):
|
||||
continue
|
||||
|
||||
original_requirements = list_of_requirements[k]
|
||||
list_of_requirements[k] = [r for r in original_requirements
|
||||
if r not in self.post_install_req_lookup]
|
||||
list_of_requirements[k] += [self.post_install_req_lookup.get(r, '')
|
||||
for r in self.post_install_req_lookup.keys() if r in original_requirements]
|
||||
return list_of_requirements
|
||||
0
clearml_agent/helper/package/pip_api/__init__.py
Normal file
0
clearml_agent/helper/package/pip_api/__init__.py
Normal file
94
clearml_agent/helper/package/pip_api/system.py
Normal file
94
clearml_agent/helper/package/pip_api/system.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import sys
|
||||
from itertools import chain
|
||||
from typing import Text, Optional
|
||||
|
||||
from trains_agent.definitions import PIP_EXTRA_INDICES, PROGRAM_NAME
|
||||
from trains_agent.helper.package.base import PackageManager
|
||||
from trains_agent.helper.process import Argv, DEVNULL
|
||||
from trains_agent.session import Session
|
||||
|
||||
|
||||
class SystemPip(PackageManager):
|
||||
|
||||
indices_args = None
|
||||
|
||||
def __init__(self, interpreter=None, session=None):
|
||||
# type: (Optional[Text], Optional[Session]) -> ()
|
||||
"""
|
||||
Program interface to the system pip.
|
||||
"""
|
||||
self._bin = interpreter or sys.executable
|
||||
self.session = session
|
||||
|
||||
@property
|
||||
def bin(self):
|
||||
return self._bin
|
||||
|
||||
def create(self):
|
||||
pass
|
||||
|
||||
def remove(self):
|
||||
pass
|
||||
|
||||
def install_from_file(self, path):
|
||||
self.run_with_env(('install', '-r', path) + self.install_flags(), cwd=self.cwd)
|
||||
|
||||
def install_packages(self, *packages):
|
||||
self._install(*(packages + self.install_flags()))
|
||||
|
||||
def _install(self, *args):
|
||||
self.run_with_env(('install',) + args, cwd=self.cwd)
|
||||
|
||||
def uninstall_packages(self, *packages):
|
||||
self.run_with_env(('uninstall', '-y') + packages)
|
||||
|
||||
def download_package(self, package, cache_dir):
|
||||
self.run_with_env(
|
||||
(
|
||||
'download',
|
||||
package,
|
||||
'--dest', cache_dir,
|
||||
'--no-deps',
|
||||
) + self.install_flags()
|
||||
)
|
||||
|
||||
def load_requirements(self, requirements):
|
||||
requirements = requirements.get('pip') if isinstance(requirements, dict) else requirements
|
||||
if not requirements:
|
||||
return
|
||||
with self.temp_file('cached-reqs', requirements) as path:
|
||||
self.install_from_file(path)
|
||||
|
||||
def uninstall(self, package):
|
||||
self.run_with_env(('uninstall', '-y', package))
|
||||
|
||||
def freeze(self):
|
||||
"""
|
||||
pip freeze to all install packages except the running program
|
||||
:return: Dict contains pip as key and pip's packages to install
|
||||
:rtype: Dict[str: List[str]]
|
||||
"""
|
||||
packages = self.run_with_env(('freeze',), output=True).splitlines()
|
||||
packages_without_program = [package for package in packages if PROGRAM_NAME not in package]
|
||||
return {'pip': packages_without_program}
|
||||
|
||||
def run_with_env(self, command, output=False, **kwargs):
|
||||
"""
|
||||
Run a shell command using environment from a virtualenv script
|
||||
:param command: command to run
|
||||
:type command: Iterable[Text]
|
||||
:param output: return output
|
||||
:param kwargs: kwargs for get_output/check_output command
|
||||
"""
|
||||
command = self._make_command(command)
|
||||
return (command.get_output if output else command.check_call)(stdin=DEVNULL, **kwargs)
|
||||
|
||||
def _make_command(self, command):
|
||||
return Argv(self.bin, '-m', 'pip', '--disable-pip-version-check', *command)
|
||||
|
||||
def install_flags(self):
|
||||
if self.indices_args is None:
|
||||
self.indices_args = tuple(
|
||||
chain.from_iterable(('--extra-index-url', x) for x in PIP_EXTRA_INDICES)
|
||||
)
|
||||
return self.indices_args
|
||||
77
clearml_agent/helper/package/pip_api/venv.py
Normal file
77
clearml_agent/helper/package/pip_api/venv.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Any
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains_agent.helper.base import select_for_platform, rm_tree, ExecutionInfo
|
||||
from trains_agent.helper.package.base import PackageManager
|
||||
from trains_agent.helper.process import Argv, PathLike
|
||||
from trains_agent.session import Session
|
||||
from ..pip_api.system import SystemPip
|
||||
from ..requirements import RequirementsManager
|
||||
|
||||
|
||||
class VirtualenvPip(SystemPip, PackageManager):
|
||||
def __init__(self, session, python, requirements_manager, path, interpreter=None, execution_info=None, **kwargs):
|
||||
# type: (Session, float, RequirementsManager, PathLike, PathLike, ExecutionInfo, Any) -> ()
|
||||
"""
|
||||
Program interface to virtualenv pip.
|
||||
Must be given either path to virtualenv or source command.
|
||||
Either way, ``self.source`` is exposed.
|
||||
:param session: a Session object for communication
|
||||
:param python: interpreter path
|
||||
:param path: path of virtual environment to create/manipulate
|
||||
:param python: python version
|
||||
:param interpreter: path of python interpreter
|
||||
"""
|
||||
super(VirtualenvPip, self).__init__(
|
||||
session=session,
|
||||
interpreter=interpreter or Path(
|
||||
path, select_for_platform(linux="bin/python", windows="scripts/python.exe"))
|
||||
)
|
||||
self.path = path
|
||||
self.requirements_manager = requirements_manager
|
||||
self.python = python
|
||||
|
||||
def _make_command(self, command):
|
||||
return self.session.command(self.bin, "-m", "pip", "--disable-pip-version-check", *command)
|
||||
|
||||
def load_requirements(self, requirements):
|
||||
if isinstance(requirements, dict) and requirements.get("pip"):
|
||||
requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
|
||||
super(VirtualenvPip, self).load_requirements(requirements)
|
||||
self.requirements_manager.post_install(self.session)
|
||||
|
||||
def create_flags(self):
|
||||
"""
|
||||
Configurable environment creation arguments
|
||||
"""
|
||||
return Argv.conditional_flag(
|
||||
self.session.config["agent.package_manager.system_site_packages"],
|
||||
"--system-site-packages",
|
||||
)
|
||||
|
||||
def install_flags(self):
|
||||
"""
|
||||
Configurable package installation creation arguments
|
||||
"""
|
||||
return super(VirtualenvPip, self).install_flags() + Argv.conditional_flag(
|
||||
self.session.config["agent.package_manager.force_upgrade"], "--upgrade"
|
||||
)
|
||||
|
||||
def create(self):
|
||||
"""
|
||||
Create virtualenv.
|
||||
Only valid if instantiated with path.
|
||||
Use self.python as self.bin does not exist.
|
||||
"""
|
||||
self.session.command(
|
||||
self.python, "-m", "virtualenv", self.path, *self.create_flags()
|
||||
).check_call()
|
||||
return self
|
||||
|
||||
def remove(self):
|
||||
"""
|
||||
Delete virtualenv.
|
||||
Only valid if instantiated with path.
|
||||
"""
|
||||
rm_tree(self.path)
|
||||
139
clearml_agent/helper/package/poetry_api.py
Normal file
139
clearml_agent/helper/package/poetry_api.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
|
||||
import attr
|
||||
import sys
|
||||
import os
|
||||
from pathlib2 import Path
|
||||
from trains_agent.helper.process import Argv, DEVNULL, check_if_command_exists
|
||||
from trains_agent.session import Session, POETRY
|
||||
|
||||
|
||||
def prop_guard(prop, log_prop=None):
|
||||
assert isinstance(prop, property)
|
||||
assert not log_prop or isinstance(log_prop, property)
|
||||
|
||||
def decorator(func):
|
||||
message = "%s:%s calling {}, {} = %s".format(
|
||||
func.__name__, prop.fget.__name__
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
def new_func(self, *args, **kwargs):
|
||||
prop_value = prop.fget(self)
|
||||
if log_prop:
|
||||
log_prop.fget(self).debug(
|
||||
message,
|
||||
type(self).__name__,
|
||||
"" if prop_value else " not",
|
||||
prop_value,
|
||||
)
|
||||
if prop_value:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PoetryConfig:
|
||||
|
||||
def __init__(self, session, interpreter=None):
|
||||
# type: (Session, str) -> ()
|
||||
self.session = session
|
||||
self._log = session.get_logger(__name__)
|
||||
self._python = interpreter or sys.executable
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def log(self):
|
||||
return self._log
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.session.config["agent.package_manager.type"] == POETRY
|
||||
|
||||
_guard_enabled = prop_guard(enabled, log)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
func = kwargs.pop("func", Argv.get_output)
|
||||
kwargs.setdefault("stdin", DEVNULL)
|
||||
kwargs['env'] = deepcopy(os.environ)
|
||||
if 'VIRTUAL_ENV' in kwargs['env'] or 'CONDA_PREFIX' in kwargs['env']:
|
||||
kwargs['env'].pop('VIRTUAL_ENV', None)
|
||||
kwargs['env'].pop('CONDA_PREFIX', None)
|
||||
kwargs['env'].pop('PYTHONPATH', None)
|
||||
if hasattr(sys, "real_prefix") and hasattr(sys, "base_prefix"):
|
||||
path = ':'+kwargs['env']['PATH']
|
||||
path = path.replace(':'+sys.base_prefix, ':'+sys.real_prefix, 1)
|
||||
kwargs['env']['PATH'] = path
|
||||
|
||||
if check_if_command_exists("poetry"):
|
||||
argv = Argv("poetry", *args)
|
||||
else:
|
||||
argv = Argv(self._python, "-m", "poetry", *args)
|
||||
self.log.debug("running: %s", argv)
|
||||
return func(argv, **kwargs)
|
||||
|
||||
def _config(self, *args, **kwargs):
|
||||
return self.run("config", *args, **kwargs)
|
||||
|
||||
@_guard_enabled
|
||||
def initialize(self, cwd=None):
|
||||
if not self._initialized:
|
||||
self._initialized = True
|
||||
try:
|
||||
self._config("--local", "virtualenvs.in-project", "true", cwd=cwd)
|
||||
# self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
|
||||
# self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)
|
||||
except Exception as ex:
|
||||
print("Exception: {}\nError: Failed configuring Poetry virtualenvs.in-project".format(ex))
|
||||
raise
|
||||
|
||||
def get_api(self, path):
|
||||
# type: (Path) -> PoetryAPI
|
||||
return PoetryAPI(self, path)
|
||||
|
||||
|
||||
@attr.s
|
||||
class PoetryAPI(object):
|
||||
config = attr.ib(type=PoetryConfig)
|
||||
path = attr.ib(type=Path, converter=Path)
|
||||
|
||||
INDICATOR_FILES = "pyproject.toml", "poetry.lock"
|
||||
|
||||
def install(self):
|
||||
# type: () -> bool
|
||||
if self.enabled:
|
||||
self.config.run("install", "-n", cwd=str(self.path), func=Argv.check_call)
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.config.enabled and (
|
||||
any((self.path / indicator).exists() for indicator in self.INDICATOR_FILES)
|
||||
)
|
||||
|
||||
def freeze(self):
|
||||
lines = self.config.run("show", cwd=str(self.path)).splitlines()
|
||||
lines = [[p for p in line.split(' ') if p] for line in lines]
|
||||
return {"pip": [parts[0]+'=='+parts[1]+' # '+' '.join(parts[2:]) for parts in lines]}
|
||||
|
||||
def get_python_command(self, extra):
|
||||
if check_if_command_exists("poetry"):
|
||||
return Argv("poetry", "run", "python", *extra)
|
||||
else:
|
||||
return Argv(self.config._python, "-m", "poetry", "run", "python", *extra)
|
||||
|
||||
def upgrade_pip(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def set_selected_package_manager(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def out_of_scope_install_package(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def install_from_file(self, *args, **kwargs):
|
||||
pass
|
||||
48
clearml_agent/helper/package/post_req.py
Normal file
48
clearml_agent/helper/package/post_req.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Text
|
||||
|
||||
from .base import PackageManager
|
||||
from .requirements import SimpleSubstitution
|
||||
|
||||
|
||||
class PostRequirement(SimpleSubstitution):
|
||||
|
||||
name = ("horovod", )
|
||||
optional_package_names = tuple()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(PostRequirement, self).__init__(*args, **kwargs)
|
||||
self.post_install_req = []
|
||||
# check if we need to replace the packages:
|
||||
post_packages = self.config.get('agent.package_manager.post_packages', None)
|
||||
if post_packages:
|
||||
self.__class__.name = post_packages
|
||||
post_optional_packages = self.config.get('agent.package_manager.post_optional_packages', None)
|
||||
if post_optional_packages:
|
||||
self.__class__.optional_package_names = post_optional_packages
|
||||
|
||||
def match(self, req):
|
||||
# match both horovod
|
||||
return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
|
||||
|
||||
def post_install(self, session):
|
||||
for req in self.post_install_req:
|
||||
if req.name in self.optional_package_names:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
PackageManager.out_of_scope_install_package(req.tostr(markers=False))
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
PackageManager.out_of_scope_install_package(req.tostr(markers=False))
|
||||
|
||||
self.post_install_req = []
|
||||
|
||||
def replace(self, req):
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
# Store in post req install, and return nothing
|
||||
self.post_install_req.append(req)
|
||||
# mark skip package, we will install it in post install hook
|
||||
return Text('')
|
||||
75
clearml_agent/helper/package/priority_req.py
Normal file
75
clearml_agent/helper/package/priority_req.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Text
|
||||
|
||||
from .base import PackageManager
|
||||
from .requirements import SimpleSubstitution
|
||||
|
||||
|
||||
class PriorityPackageRequirement(SimpleSubstitution):
|
||||
|
||||
name = ("cython", "numpy", "setuptools", )
|
||||
optional_package_names = tuple()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(PriorityPackageRequirement, self).__init__(*args, **kwargs)
|
||||
# check if we need to replace the packages:
|
||||
priority_packages = self.config.get('agent.package_manager.priority_packages', None)
|
||||
if priority_packages:
|
||||
self.__class__.name = priority_packages
|
||||
priority_optional_packages = self.config.get('agent.package_manager.priority_optional_packages', None)
|
||||
if priority_optional_packages:
|
||||
self.__class__.optional_package_names = priority_optional_packages
|
||||
|
||||
def match(self, req):
|
||||
# match both Cython & cython
|
||||
return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
|
||||
|
||||
def replace(self, req):
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
if req.name in self.optional_package_names:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if PackageManager.out_of_scope_install_package(str(req)):
|
||||
return Text(req)
|
||||
except Exception:
|
||||
pass
|
||||
return Text('')
|
||||
PackageManager.out_of_scope_install_package(str(req))
|
||||
return Text(req)
|
||||
|
||||
|
||||
class PackageCollectorRequirement(SimpleSubstitution):
|
||||
"""
|
||||
This RequirementSubstitution class will allow you to have multiple instances of the same
|
||||
package, it will output the last one (by order) to be actually used.
|
||||
"""
|
||||
name = tuple()
|
||||
|
||||
def __init__(self, session, collect_package):
|
||||
super(PackageCollectorRequirement, self).__init__(session)
|
||||
self._collect_packages = collect_package or tuple()
|
||||
self._last_req = None
|
||||
|
||||
def match(self, req):
|
||||
# match package names
|
||||
return req.name and req.name.lower() in self._collect_packages
|
||||
|
||||
def replace(self, req):
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
self._last_req = req.clone()
|
||||
return ''
|
||||
|
||||
def post_scan_add_req(self):
|
||||
"""
|
||||
Allows the RequirementSubstitution to add an extra line/requirements after
|
||||
the initial requirements scan is completed.
|
||||
Called only once per requirements.txt object
|
||||
"""
|
||||
last_req = self._last_req
|
||||
self._last_req = None
|
||||
return last_req
|
||||
723
clearml_agent/helper/package/pytorch.py
Normal file
723
clearml_agent/helper/package/pytorch.py
Normal file
@@ -0,0 +1,723 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import re
|
||||
import sys
|
||||
from furl import furl
|
||||
import urllib.parse
|
||||
from operator import itemgetter
|
||||
from html.parser import HTMLParser
|
||||
from typing import Text
|
||||
|
||||
import attr
|
||||
import requests
|
||||
|
||||
import six
|
||||
from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion
|
||||
|
||||
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
||||
|
||||
|
||||
def os_to_wheel_name(x):
|
||||
return OS_TO_WHEEL_NAME[x]
|
||||
|
||||
|
||||
def fix_version(version):
|
||||
def replace(nums, prerelease):
|
||||
if prerelease:
|
||||
return "{}-{}".format(nums, prerelease)
|
||||
return nums
|
||||
|
||||
return re.sub(
|
||||
r"(\d+(?:\.\d+){,2})(?:\.(.*))?",
|
||||
lambda match: replace(*match.groups()),
|
||||
version,
|
||||
)
|
||||
|
||||
|
||||
class LinksHTMLParser(HTMLParser):
|
||||
def __init__(self):
|
||||
super(LinksHTMLParser, self).__init__()
|
||||
self.links = []
|
||||
|
||||
def handle_data(self, data):
|
||||
if data and data.strip():
|
||||
self.links += [data]
|
||||
|
||||
|
||||
@attr.s
|
||||
class PytorchWheel(object):
|
||||
os_name = attr.ib(type=str, converter=os_to_wheel_name)
|
||||
cuda_version = attr.ib(converter=lambda x: "cu{}".format(x) if x else "cpu")
|
||||
python = attr.ib(type=str, converter=lambda x: str(x).replace(".", ""))
|
||||
torch_version = attr.ib(type=str, converter=fix_version)
|
||||
|
||||
url_template = (
|
||||
"http://download.pytorch.org/whl/"
|
||||
"{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl"
|
||||
)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.unicode = "u" if self.python.startswith("2") else ""
|
||||
|
||||
def make_url(self):
|
||||
# type: () -> Text
|
||||
return self.url_template.format(self)
|
||||
|
||||
|
||||
class PytorchResolutionError(FatalSpecsResolutionError):
|
||||
pass
|
||||
|
||||
|
||||
class SimplePytorchRequirement(SimpleSubstitution):
|
||||
name = "torch"
|
||||
|
||||
packages = ("torch", "torchvision", "torchaudio")
|
||||
|
||||
page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
|
||||
nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html'
|
||||
torch_page_lookup = {
|
||||
0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
|
||||
80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
|
||||
90: 'https://download.pytorch.org/whl/cu90/torch_stable.html',
|
||||
92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
|
||||
100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
|
||||
101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
|
||||
102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
|
||||
110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SimplePytorchRequirement, self).__init__(*args, **kwargs)
|
||||
self._matched = False
|
||||
|
||||
def match(self, req):
|
||||
# match both any of out packages
|
||||
return req.name in self.packages
|
||||
|
||||
def replace(self, req):
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
# Get rid of +cpu +cu?? etc.
|
||||
try:
|
||||
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
|
||||
except:
|
||||
pass
|
||||
self._matched = True
|
||||
return Text(req)
|
||||
|
||||
def matching_done(self, reqs, package_manager):
|
||||
# type: (Sequence[MarkerRequirement], object) -> ()
|
||||
if not self._matched:
|
||||
return
|
||||
# TODO: add conda channel support
|
||||
from .pip_api.system import SystemPip
|
||||
if package_manager and isinstance(package_manager, SystemPip):
|
||||
extra_url, _ = self.get_torch_page(self.cuda_version)
|
||||
package_manager.add_extra_install_flags(('-f', extra_url))
|
||||
|
||||
@classmethod
|
||||
def get_torch_page(cls, cuda_version, nightly=False):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cuda = int(cuda_version)
|
||||
except Exception:
|
||||
cuda = 0
|
||||
|
||||
if nightly:
|
||||
for c in range(cuda, max(-1, cuda-15), -1):
|
||||
# then try the nightly builds, it might be there...
|
||||
torch_url = cls.nightly_page_lookup_template.format(c)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
print('Torch nightly CUDA {} download page found'.format(c))
|
||||
cls.torch_page_lookup[c] = torch_url
|
||||
return cls.torch_page_lookup[c], c
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# first check if key is valid
|
||||
if cuda in cls.torch_page_lookup:
|
||||
return cls.torch_page_lookup[cuda], cuda
|
||||
|
||||
# then try a new cuda version page
|
||||
for c in range(cuda, max(-1, cuda-15), -1):
|
||||
torch_url = cls.page_lookup_template.format(c)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
print('Torch CUDA {} download page found'.format(c))
|
||||
cls.torch_page_lookup[c] = torch_url
|
||||
return cls.torch_page_lookup[c], c
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
|
||||
for k in keys:
|
||||
if k <= cuda:
|
||||
return cls.torch_page_lookup[k], k
|
||||
# return default - zero
|
||||
return cls.torch_page_lookup[0], 0
|
||||
|
||||
|
||||
class PytorchRequirement(SimpleSubstitution):
|
||||
|
||||
name = "torch"
|
||||
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
os_name = kwargs.pop("os_override", None)
|
||||
super(PytorchRequirement, self).__init__(*args, **kwargs)
|
||||
self.log = self._session.get_logger(__name__)
|
||||
self.package_manager = self.config["agent.package_manager.type"].lower()
|
||||
self.os = os_name or self.get_platform()
|
||||
self.cuda = "cuda{}".format(self.cuda_version).lower()
|
||||
self.python_version_string = str(self.config["agent.default_python"])
|
||||
self.python_major_minor_str = '.'.join(self.python_version_string.split('.')[:2])
|
||||
if '.' not in self.python_major_minor_str:
|
||||
raise PytorchResolutionError(
|
||||
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
|
||||
"must have both major and minor parts of the version (for example: '3.7')".format(
|
||||
self.python_version_string
|
||||
)
|
||||
)
|
||||
self.python = "python{}".format(self.python_major_minor_str)
|
||||
|
||||
self.exceptions = [
|
||||
PytorchResolutionError(message)
|
||||
for message in (
|
||||
None,
|
||||
'cuda version "{}" is not supported'.format(self.cuda),
|
||||
'python version "{}" is not supported'.format(
|
||||
self.python_version_string
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
self.validate_python_version()
|
||||
except PytorchResolutionError as e:
|
||||
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
|
||||
|
||||
self._original_req = []
|
||||
|
||||
@property
|
||||
def is_conda(self):
|
||||
return self.package_manager == "conda"
|
||||
|
||||
@property
|
||||
def is_pip(self):
|
||||
return not self.is_conda
|
||||
|
||||
def validate_python_version(self):
|
||||
"""
|
||||
Make sure python version has both major and minor versions as required for choosing pytorch wheel
|
||||
"""
|
||||
if self.is_pip and not self.python_major_minor_str:
|
||||
raise PytorchResolutionError(
|
||||
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
|
||||
"must have both major and minor parts of the version (for example: '3.7')".format(
|
||||
self.python_version_string
|
||||
)
|
||||
)
|
||||
|
||||
def match(self, req):
|
||||
return req.name in self.packages
|
||||
|
||||
@staticmethod
|
||||
def get_platform():
|
||||
if sys.platform == "linux":
|
||||
return "linux"
|
||||
if sys.platform == "win32" or sys.platform == "cygwin":
|
||||
return "windows"
|
||||
if sys.platform == "darwin":
|
||||
return "macos"
|
||||
raise RuntimeError("unrecognized OS")
|
||||
|
||||
def _get_link_from_torch_page(self, req, torch_url):
|
||||
links_parser = LinksHTMLParser()
|
||||
links_parser.feed(requests.get(torch_url, timeout=10).text)
|
||||
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
|
||||
py_ver = self.python_major_minor_str.replace('.', '')
|
||||
url = None
|
||||
last_v = None
|
||||
closest_v = None
|
||||
# search for our package
|
||||
for l in links_parser.links:
|
||||
parts = l.split('/')[-1].split('-')
|
||||
if len(parts) < 5:
|
||||
continue
|
||||
if parts[0] != req.name:
|
||||
continue
|
||||
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
||||
# version ignore .postX suffix (treat as regular version)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
v = str(parts[1].split('%')[0].split('+')[0])
|
||||
except Exception:
|
||||
continue
|
||||
if len(parts) < 3 or not parts[2].endswith(py_ver):
|
||||
continue
|
||||
if len(parts) < 5 or platform_wheel not in parts[4]:
|
||||
continue
|
||||
# update the closest matched version (from above)
|
||||
if not closest_v:
|
||||
closest_v = v
|
||||
elif SimpleVersion.compare_versions(
|
||||
version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
|
||||
SimpleVersion.compare_versions(
|
||||
version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
|
||||
closest_v = v
|
||||
# check if this an actual match
|
||||
if not req.compare_version(v) or \
|
||||
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
|
||||
continue
|
||||
|
||||
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
|
||||
last_v = v
|
||||
# if we found an exact match, use it
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if req.specs[0][0] == '==' and \
|
||||
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return url, last_v or closest_v
|
||||
|
||||
def get_url_for_platform(self, req):
|
||||
# check if package is already installed with system packages
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if self.config.get("agent.package_manager.system_site_packages", None):
|
||||
from pip._internal.commands.show import search_packages_info
|
||||
installed_torch = list(search_packages_info([req.name]))
|
||||
# notice the comparison order, the first part will make sure we have a valid installed package
|
||||
if installed_torch and installed_torch[0]['version'] and \
|
||||
req.compare_version(installed_torch[0]['version']):
|
||||
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
|
||||
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
|
||||
# package already installed, do nothing
|
||||
req.specs = [('==', str(installed_torch[0]['version']))]
|
||||
return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# make sure we have a specific version to retrieve
|
||||
if not req.specs:
|
||||
req.specs = [('>', '0')]
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
|
||||
except Exception:
|
||||
pass
|
||||
op, version = req.specs[0]
|
||||
# assert op == "=="
|
||||
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
|
||||
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||
if not url and self.config.get("agent.package_manager.torch_nightly", None):
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
|
||||
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||
# try one more time, with a lower cuda version (never fallback to CPU):
|
||||
while not url and torch_url_key > 0:
|
||||
previous_cuda_key = torch_url_key
|
||||
print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
|
||||
req, previous_cuda_key, closest_matched_version))
|
||||
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||
if url:
|
||||
break
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
|
||||
# never fallback to CPU
|
||||
if torch_url_key < 1:
|
||||
print(
|
||||
'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
|
||||
req, previous_cuda_key))
|
||||
raise ValueError(
|
||||
'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
|
||||
else:
|
||||
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
|
||||
|
||||
if not url:
|
||||
url = PytorchWheel(
|
||||
torch_version=fix_version(version),
|
||||
python=self.python_major_minor_str.replace('.', ''),
|
||||
os_name=self.os,
|
||||
cuda_version=self.cuda_version,
|
||||
).make_url()
|
||||
if url:
|
||||
# normalize url (sometimes we will get ../ which we should not...
|
||||
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
|
||||
# print found
|
||||
print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))
|
||||
|
||||
self.log.debug("checking url: %s", url)
|
||||
return url, requests.head(url, timeout=10).ok
|
||||
|
||||
@staticmethod
|
||||
def match_version(req, options):
|
||||
versioned_options = sorted(
|
||||
((fix_version(key), value) for key, value in options.items()),
|
||||
key=itemgetter(0),
|
||||
reverse=True,
|
||||
)
|
||||
req.specs = [(op, fix_version(version)) for op, version in req.specs]
|
||||
|
||||
try:
|
||||
return next(
|
||||
replacement
|
||||
for version, replacement in versioned_options
|
||||
if req.compare_version(version)
|
||||
)
|
||||
except StopIteration:
|
||||
raise PytorchResolutionError(
|
||||
'Could not find wheel for "{}", '
|
||||
"Available versions: {}".format(req, list(options))
|
||||
)
|
||||
|
||||
def replace_conda(self, req):
|
||||
spec = "".join(req.specs[0]) if req.specs else ""
|
||||
if not self.cuda_version:
|
||||
return "pytorch-cpu{spec}\ntorchvision-cpu".format(spec=spec)
|
||||
return "pytorch{spec}\ntorchvision\ncuda{self.cuda_version}".format(
|
||||
self=self, spec=spec
|
||||
)
|
||||
|
||||
def _table_lookup(self, req):
|
||||
"""
|
||||
Look for pytorch wheel matching `req` in table
|
||||
:param req: python requirement
|
||||
"""
|
||||
def check(base_, key_, exception_):
|
||||
result = base_.get(key_)
|
||||
if not result:
|
||||
if key_.startswith('cuda'):
|
||||
print('Could not locate, {}'.format(exception_))
|
||||
ver = sorted([float(a.replace('cuda', '').replace('none', '0')) for a in base_.keys()], reverse=True)[0]
|
||||
key_ = 'cuda'+str(int(ver))
|
||||
result = base_.get(key_)
|
||||
print('Reverting to \"{}\"'.format(key_))
|
||||
if not result:
|
||||
raise exception_
|
||||
return result
|
||||
raise exception_
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
if self.is_conda:
|
||||
return self.replace_conda(req)
|
||||
|
||||
base = self.MAP
|
||||
for key, exception in zip((self.os, self.cuda, self.python), self.exceptions):
|
||||
base = check(base, key, exception)
|
||||
|
||||
return self.match_version(req, base).replace(" ", "\n")
|
||||
|
||||
def replace(self, req):
|
||||
try:
|
||||
new_req = self._replace(req)
|
||||
if new_req:
|
||||
self._original_req.append((req, new_req))
|
||||
return new_req
|
||||
except Exception as e:
|
||||
message = "Exception when trying to resolve python wheel"
|
||||
self.log.debug(message, exc_info=True)
|
||||
raise PytorchResolutionError("{}: {}".format(message, e))
|
||||
|
||||
def _replace(self, req):
|
||||
self.validate_python_version()
|
||||
try:
|
||||
result, ok = self.get_url_for_platform(req)
|
||||
self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
return result
|
||||
except:
|
||||
pass
|
||||
|
||||
# try:
|
||||
# result = self._table_lookup(req)
|
||||
# except Exception as e:
|
||||
# exc = e
|
||||
# else:
|
||||
# self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
# return result
|
||||
# self.log.debug(
|
||||
# "Could not find Pytorch wheel in table, trying manually constructing URL"
|
||||
# )
|
||||
|
||||
result = ok = None
|
||||
# try:
|
||||
# result, ok = self.get_url_for_platform(req)
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
if not ok:
|
||||
if result:
|
||||
self.log.debug("URL not found: {}".format(result))
|
||||
exc = PytorchResolutionError(
|
||||
"Could not find pytorch wheel URL for: {} with cuda {} support".format(req, self.cuda_version)
|
||||
)
|
||||
# cancel exception chaining
|
||||
six.raise_from(exc, None)
|
||||
|
||||
self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
return result
|
||||
|
||||
def replace_back(self, list_of_requirements): # type: (Dict) -> Dict
|
||||
"""
|
||||
:param list_of_requirements: {'pip': ['a==1.0', ]}
|
||||
:return: {'pip': ['a==1.0', ]}
|
||||
"""
|
||||
if not self._original_req:
|
||||
return list_of_requirements
|
||||
try:
|
||||
for k, lines in list_of_requirements.items():
|
||||
# k is either pip/conda
|
||||
if k not in ('pip', 'conda'):
|
||||
continue
|
||||
for i, line in enumerate(lines):
|
||||
if not line or line.lstrip().startswith('#'):
|
||||
continue
|
||||
parts = [p for p in re.split('\s|=|\.|<|>|~|!|@|#', line) if p]
|
||||
if not parts:
|
||||
continue
|
||||
for req, new_req in self._original_req:
|
||||
if req.req.name == parts[0]:
|
||||
# support for pip >= 20.1
|
||||
if '@' in line:
|
||||
# skip if we have nothing to add
|
||||
if str(req).strip() != str(new_req).strip():
|
||||
# if this is local file and use the version detection
|
||||
if req.local_file:
|
||||
lines[i] = '{}'.format(str(new_req))
|
||||
else:
|
||||
lines[i] = '{} # {}'.format(str(req), str(new_req))
|
||||
else:
|
||||
lines[i] = '{} # {}'.format(line, str(new_req))
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
return list_of_requirements
|
||||
|
||||
MAP = {
|
||||
"windows": {
|
||||
"cuda100": {
|
||||
"python3.7": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-win_amd64.whl"
|
||||
},
|
||||
"python3.6": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-win_amd64.whl"
|
||||
},
|
||||
"python3.5": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-win_amd64.whl"
|
||||
},
|
||||
"python2.7": PytorchResolutionError(
|
||||
"PyTorch does not support Python 2.7 on Windows"
|
||||
),
|
||||
},
|
||||
"cuda92": {
|
||||
"python3.7": {
|
||||
"0.4.1",
|
||||
"http://download.pytorch.org/whl/cu92/torch-0.4.1-cp37-cp37m-win_amd64.whl",
|
||||
},
|
||||
"python3.6": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-win_amd64.whl"
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-win_amd64.whl"
|
||||
},
|
||||
"python2.7": PytorchResolutionError(
|
||||
"PyTorch does not support Python 2.7 on Windows"
|
||||
),
|
||||
},
|
||||
"cuda91": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-win_amd64.whl"
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-win_amd64.whl"
|
||||
},
|
||||
"python2.7": PytorchResolutionError(
|
||||
"PyTorch does not support Python 2.7 on Windows"
|
||||
),
|
||||
},
|
||||
"cuda90": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-win_amd64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-win_amd64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-win_amd64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-win_amd64.whl",
|
||||
},
|
||||
"python2.7": PytorchResolutionError(
|
||||
"PyTorch does not support Python 2.7 on Windows"
|
||||
),
|
||||
},
|
||||
"cuda80": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-win_amd64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-win_amd64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp35-cp35m-win_amd64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-win_amd64.whl",
|
||||
},
|
||||
"python2.7": PytorchResolutionError(
|
||||
"PyTorch does not support Python 2.7 on Windows"
|
||||
),
|
||||
},
|
||||
"cudanone": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-win_amd64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-win_amd64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-win_amd64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-win_amd64.whl",
|
||||
},
|
||||
"python2.7": PytorchResolutionError(
|
||||
"PyTorch does not support Python 2.7 on Windows"
|
||||
),
|
||||
},
|
||||
},
|
||||
"macos": {
|
||||
"cuda100": PytorchResolutionError(
|
||||
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
|
||||
),
|
||||
"cuda92": PytorchResolutionError(
|
||||
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
|
||||
),
|
||||
"cuda91": PytorchResolutionError(
|
||||
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
|
||||
),
|
||||
"cuda90": PytorchResolutionError(
|
||||
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
|
||||
),
|
||||
"cuda80": PytorchResolutionError(
|
||||
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
|
||||
),
|
||||
"cudanone": {
|
||||
"python3.6": {"0.4.0": "torch"},
|
||||
"python3.5": {"0.4.0": "torch"},
|
||||
"python2.7": {"0.4.0": "torch"},
|
||||
},
|
||||
},
|
||||
"linux": {
|
||||
"cuda100": {
|
||||
"python3.7": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-linux_x86_64.whl",
|
||||
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp37-cp37m-linux_x86_64.whl",
|
||||
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl",
|
||||
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl",
|
||||
},
|
||||
"python3.6": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
|
||||
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp36-cp36m-linux_x86_64.whl",
|
||||
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl",
|
||||
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
|
||||
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp35-cp35m-linux_x86_64.whl",
|
||||
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp35-cp35m-linux_x86_64.whl",
|
||||
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp35-cp35m-manylinux1_x86_64.whl",
|
||||
},
|
||||
"python2.7": {
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp27-cp27mu-linux_x86_64.whl",
|
||||
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp27-cp27mu-manylinux1_x86_64.whl",
|
||||
},
|
||||
},
|
||||
"cuda92": {
|
||||
"python3.7": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl",
|
||||
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp37-cp37m-manylinux1_x86_64.whl"
|
||||
},
|
||||
"python3.6": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
|
||||
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl"
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
|
||||
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp35-cp35m-manylinux1_x86_64.whl"
|
||||
},
|
||||
"python2.7": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
|
||||
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp27-cp27mu-manylinux1_x86_64.whl"
|
||||
},
|
||||
},
|
||||
"cuda91": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-linux_x86_64.whl"
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl"
|
||||
},
|
||||
"python2.7": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl"
|
||||
},
|
||||
},
|
||||
"cuda90": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
|
||||
},
|
||||
"python2.7": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
},
|
||||
},
|
||||
"cuda80": {
|
||||
"python3.6": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
|
||||
"0.3.1": "torch==0.3.1",
|
||||
"0.3.0.post4": "torch==0.3.0.post4",
|
||||
"0.1.2.post1": "torch==0.1.2.post1",
|
||||
"0.1.2": "torch==0.1.2",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
|
||||
"0.3.1": "torch==0.3.1",
|
||||
"0.3.0.post4": "torch==0.3.0.post4",
|
||||
"0.1.2.post1": "torch==0.1.2.post1",
|
||||
"0.1.2": "torch==0.1.2",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
|
||||
},
|
||||
"python2.7": {
|
||||
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
|
||||
"0.3.1": "torch==0.3.1",
|
||||
"0.3.0.post4": "torch==0.3.0.post4",
|
||||
"0.1.2.post1": "torch==0.1.2.post1",
|
||||
"0.1.2": "torch==0.1.2",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
},
|
||||
},
|
||||
"cudanone": {
|
||||
"python3.6": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
|
||||
},
|
||||
"python3.5": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
|
||||
},
|
||||
"python2.7": {
|
||||
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
604
clearml_agent/helper/package/requirements.py
Normal file
604
clearml_agent/helper/package/requirements.py
Normal file
@@ -0,0 +1,604 @@
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy, copy
|
||||
from itertools import chain, starmap
|
||||
from operator import itemgetter
|
||||
from os import path
|
||||
from typing import Text, List, Type, Optional, Tuple, Dict
|
||||
|
||||
from pathlib2 import Path
|
||||
from pyhocon import ConfigTree
|
||||
|
||||
import six
|
||||
from trains_agent.definitions import PIP_EXTRA_INDICES
|
||||
from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
|
||||
from trains_agent.helper.process import Argv, PathLike
|
||||
from trains_agent.session import Session, normalize_cuda_version
|
||||
from trains_agent.external.requirements_parser import parse
|
||||
from trains_agent.external.requirements_parser.requirement import Requirement
|
||||
|
||||
from .translator import RequirementsTranslator
|
||||
|
||||
|
||||
class SpecsResolutionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FatalSpecsResolutionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@six.python_2_unicode_compatible
|
||||
class MarkerRequirement(object):
|
||||
|
||||
# if True pip version above 20.x and with support for "package @ scheme://link"
|
||||
# default is True
|
||||
pip_new_version = True
|
||||
|
||||
def __init__(self, req): # type: (Requirement) -> None
|
||||
self.req = req
|
||||
|
||||
@property
|
||||
def marker(self):
|
||||
match = re.search(r';\s*(.*)', self.req.line)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def tostr(self, markers=True):
|
||||
if not self.uri:
|
||||
parts = [self.name or self.line]
|
||||
|
||||
if self.extras:
|
||||
parts.append('[{0}]'.format(','.join(sorted(self.extras))))
|
||||
|
||||
if self.specifier:
|
||||
parts.append(self.format_specs())
|
||||
elif self.vcs:
|
||||
# leave the line as is, let pip handle it
|
||||
if self.line:
|
||||
return self.line
|
||||
else:
|
||||
# let's build the line manually
|
||||
parts = [
|
||||
self.uri,
|
||||
'@{}'.format(self.revision) if self.revision else '',
|
||||
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
|
||||
]
|
||||
elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
|
||||
# package @ file:///example.com/somewheel.whl
|
||||
# leave the line as is, let pip handle it
|
||||
return self.line
|
||||
else:
|
||||
parts = [self.uri]
|
||||
|
||||
if markers and self.marker:
|
||||
parts.append('; {0}'.format(self.marker))
|
||||
|
||||
return ''.join(parts)
|
||||
|
||||
def clone(self):
|
||||
return MarkerRequirement(copy(self.req))
|
||||
|
||||
__str__ = tostr
|
||||
|
||||
def __repr__(self):
|
||||
return '{self.__class__.__name__}[{self}]'.format(self=self)
|
||||
|
||||
def format_specs(self, num_parts=None, max_num_parts=None):
|
||||
max_num_parts = max_num_parts or num_parts
|
||||
if max_num_parts is None or not self.specs:
|
||||
return ','.join(starmap(operator.add, self.specs))
|
||||
|
||||
op, version = self.specs[0]
|
||||
for v in self._sub_versions_pep440:
|
||||
version = version.replace(v, '.')
|
||||
if num_parts:
|
||||
version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
|
||||
else:
|
||||
version = version.strip('.').split('.')[:max_num_parts]
|
||||
return op+'.'.join(version)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self.req, item)
|
||||
|
||||
@property
|
||||
def specs(self): # type: () -> List[Tuple[Text, Text]]
|
||||
return self.req.specs
|
||||
|
||||
@specs.setter
|
||||
def specs(self, value): # type: (List[Tuple[Text, Text]]) -> None
|
||||
self.req.specs = value
|
||||
|
||||
def fix_specs(self):
|
||||
def solve_by(func, op_is, specs):
|
||||
return func([(op, version) for op, version in specs if op == op_is])
|
||||
|
||||
def solve_equal(specs):
|
||||
if len(set(version for _, version in self.specs)) > 1:
|
||||
raise SpecsResolutionError('more than one "==" spec: {}'.format(specs))
|
||||
return specs
|
||||
greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs)
|
||||
smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs)
|
||||
equal = solve_by(solve_equal, '==', self.specs)
|
||||
if equal:
|
||||
self.specs = equal
|
||||
else:
|
||||
self.specs = greater + smaller
|
||||
|
||||
def compare_version(self, requested_version, op=None, num_parts=3):
|
||||
"""
|
||||
compare the requested version with the one we have in the spec,
|
||||
If the requested version is 1.2.3 the self.spec should be 1.2.3*
|
||||
If the requested version is 1.2 the self.spec should be 1.2*
|
||||
etc.
|
||||
|
||||
:param str requested_version:
|
||||
:param str op: '==', '>', '>=', '<=', '<', '~='
|
||||
:param int num_parts: number of parts to compare
|
||||
:return: True if we answer the requested version
|
||||
"""
|
||||
# if we have no specific version, we cannot compare, so assume it's okay
|
||||
if not self.specs:
|
||||
return True
|
||||
|
||||
version = self.specs[0][1]
|
||||
op = (op or self.specs[0][0]).strip()
|
||||
|
||||
return SimpleVersion.compare_versions(
|
||||
version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
|
||||
|
||||
|
||||
class SimpleVersion:
|
||||
_sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
|
||||
VERSION_PATTERN = r"""
|
||||
v?
|
||||
(?:
|
||||
(?:(?P<epoch>[0-9]+)!)? # epoch
|
||||
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
||||
(?P<pre> # pre-release
|
||||
[-_\.]?
|
||||
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
||||
[-_\.]?
|
||||
(?P<pre_n>[0-9]+)?
|
||||
)?
|
||||
(?P<post> # post release
|
||||
(?:-(?P<post_n1>[0-9]+))
|
||||
|
|
||||
(?:
|
||||
[-_\.]?
|
||||
(?P<post_l>post|rev|r)
|
||||
[-_\.]?
|
||||
(?P<post_n2>[0-9]+)?
|
||||
)
|
||||
)?
|
||||
(?P<dev> # dev release
|
||||
[-_\.]?
|
||||
(?P<dev_l>dev)
|
||||
[-_\.]?
|
||||
(?P<dev_n>[0-9]+)?
|
||||
)?
|
||||
)
|
||||
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
||||
"""
|
||||
_local_version_separators = re.compile(r"[\._-]")
|
||||
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
||||
|
||||
@classmethod
|
||||
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
|
||||
"""
|
||||
Compare two versions based on the op operator
|
||||
returns bool(version_a op version_b)
|
||||
Notice: Ignores a/b/rc/post/dev markers on the version
|
||||
|
||||
:param str version_a:
|
||||
:param str op: '==', '===', '>', '>=', '<=', '<', '~='
|
||||
:param str version_b:
|
||||
:param bool ignore_sub_versions: if true compare only major.minor.patch
|
||||
(ignore a/b/rc/post/dev in the comparison)
|
||||
:param int num_parts: number of parts to compare, split by . (dot)
|
||||
:return bool: version_a op version_b
|
||||
"""
|
||||
|
||||
if not version_b:
|
||||
return True
|
||||
|
||||
if op == '~=':
|
||||
num_parts = max(num_parts, 2)
|
||||
op = '=='
|
||||
ignore_sub_versions = True
|
||||
elif op == '===':
|
||||
op = '=='
|
||||
|
||||
try:
|
||||
version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
|
||||
version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
|
||||
except:
|
||||
# revert to string based
|
||||
for v in cls._sub_versions_pep440:
|
||||
version_a = version_a.replace(v, '.')
|
||||
version_b = version_b.replace(v, '.')
|
||||
|
||||
version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
|
||||
version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
|
||||
version_a_key = ''
|
||||
version_b_key = ''
|
||||
for i in range(num_parts):
|
||||
pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
|
||||
version_a_key += pad.format(version_a[i])
|
||||
version_b_key += pad.format(version_b[i])
|
||||
|
||||
if op == '==':
|
||||
return version_a_key == version_b_key
|
||||
if op == '<=':
|
||||
return version_a_key <= version_b_key
|
||||
if op == '>=':
|
||||
return version_a_key >= version_b_key
|
||||
if op == '>':
|
||||
return version_a_key > version_b_key
|
||||
if op == '<':
|
||||
return version_a_key < version_b_key
|
||||
raise ValueError('Unrecognized comparison operator [{}]'.format(op))
|
||||
|
||||
@staticmethod
|
||||
def _parse_letter_version(
|
||||
letter, # type: str
|
||||
number, # type: Union[str, bytes, SupportsInt]
|
||||
):
|
||||
# type: (...) -> Optional[Tuple[str, int]]
|
||||
|
||||
if letter:
|
||||
# We consider there to be an implicit 0 in a pre-release if there is
|
||||
# not a numeral associated with it.
|
||||
if number is None:
|
||||
number = 0
|
||||
|
||||
# We normalize any letters to their lower case form
|
||||
letter = letter.lower()
|
||||
|
||||
# We consider some words to be alternate spellings of other words and
|
||||
# in those cases we want to normalize the spellings to our preferred
|
||||
# spelling.
|
||||
if letter == "alpha":
|
||||
letter = "a"
|
||||
elif letter == "beta":
|
||||
letter = "b"
|
||||
elif letter in ["c", "pre", "preview"]:
|
||||
letter = "rc"
|
||||
elif letter in ["rev", "r"]:
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
if not letter and number:
|
||||
# We assume if we are given a number, but we are not given a letter
|
||||
# then this is using the implicit post release syntax (e.g. 1.0-1)
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
|
||||
return ()
|
||||
|
||||
@staticmethod
|
||||
def _get_match_key(match, num_parts, ignore_sub_versions):
|
||||
if ignore_sub_versions:
|
||||
return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
|
||||
(), (), (), (),)
|
||||
return (
|
||||
int(match.group("epoch")) if match.group("epoch") else 0,
|
||||
tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
|
||||
SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
|
||||
SimpleVersion._parse_letter_version(
|
||||
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
|
||||
),
|
||||
SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
|
||||
SimpleVersion._parse_local_version(match.group("local")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_local_version(local):
|
||||
# type: (str) -> Optional[LocalType]
|
||||
"""
|
||||
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
||||
"""
|
||||
if local is not None:
|
||||
return tuple(
|
||||
part.lower() if not part.isdigit() else int(part)
|
||||
for part in SimpleVersion._local_version_separators.split(local)
|
||||
)
|
||||
return ()
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
|
||||
class RequirementSubstitution(object):
|
||||
|
||||
_pip_extra_index_url = PIP_EXTRA_INDICES
|
||||
|
||||
def __init__(self, session):
|
||||
# type: (Session) -> ()
|
||||
self._session = session
|
||||
self.config = session.config # type: ConfigTree
|
||||
self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
|
||||
self.package_manager = self.config['agent.package_manager.type']
|
||||
|
||||
@abstractmethod
|
||||
def match(self, req): # type: (MarkerRequirement) -> bool
|
||||
"""
|
||||
Returns whether a requirement needs to be modified by this substitution.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def replace(self, req): # type: (MarkerRequirement) -> Text
|
||||
"""
|
||||
Replace a requirement
|
||||
"""
|
||||
pass
|
||||
|
||||
def post_scan_add_req(self): # type: () -> Optional[MarkerRequirement]
|
||||
"""
|
||||
Allows the RequirementSubstitution to add an extra line/requirements after
|
||||
the initial requirements scan is completed.
|
||||
Called only once per requirements.txt object
|
||||
"""
|
||||
return None
|
||||
|
||||
def post_install(self, session):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_pip_version(cls, package):
|
||||
output = Argv(
|
||||
'pip',
|
||||
'search',
|
||||
package,
|
||||
*(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
|
||||
).get_output()
|
||||
# ad-hoc pattern to duplicate the behavior of the old code
|
||||
return re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(package), output).group(1)
|
||||
|
||||
@property
|
||||
def cuda_version(self):
|
||||
return self.config['agent.cuda_version']
|
||||
|
||||
@property
|
||||
def cudnn_version(self):
|
||||
return self.config['agent.cudnn_version']
|
||||
|
||||
|
||||
class SimpleSubstitution(RequirementSubstitution):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self):
|
||||
pass
|
||||
|
||||
def match(self, req): # type: (MarkerRequirement) -> bool
|
||||
return (self.name == req.name or (
|
||||
req.uri and
|
||||
re.match(r'https?://', req.uri) and
|
||||
self.name in req.uri
|
||||
))
|
||||
|
||||
def replace(self, req): # type: (MarkerRequirement) -> Text
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
if req.uri:
|
||||
return re.sub(
|
||||
r'({})(.*?)(-cp)'.format(self.name),
|
||||
r'\1\2{}\3'.format(self.suffix),
|
||||
req.uri,
|
||||
count=1)
|
||||
|
||||
if req.specs:
|
||||
_, version_number = req.specs[0]
|
||||
# assert packaging_version.parse(version_number)
|
||||
else:
|
||||
version_number = self.get_pip_version(self.name)
|
||||
|
||||
req.specs = [('==', version_number + self.suffix)]
|
||||
return Text(req)
|
||||
|
||||
def replace_back(self, list_of_requirements): # type: (Dict) -> Dict
|
||||
"""
|
||||
:param list_of_requirements: {'pip': ['a==1.0', ]}
|
||||
:return: {'pip': ['a==1.0', ]}
|
||||
"""
|
||||
return list_of_requirements
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
|
||||
class CudaSensitiveSubstitution(SimpleSubstitution):
|
||||
|
||||
def match(self, req): # type: (MarkerRequirement) -> bool
|
||||
return self.cuda_version and self.cudnn_version and \
|
||||
super(CudaSensitiveSubstitution, self).match(req)
|
||||
|
||||
|
||||
class CudaNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequirementsManager(object):
|
||||
|
||||
def __init__(self, session, base_interpreter=None):
|
||||
# type: (Session, PathLike) -> ()
|
||||
self._session = session
|
||||
self.config = deepcopy(session.config) # type: ConfigTree
|
||||
self.handlers = [] # type: List[RequirementSubstitution]
|
||||
agent = self.config['agent']
|
||||
self.active = not agent.get('cpu_only', False)
|
||||
self.found_cuda = False
|
||||
if self.active:
|
||||
try:
|
||||
agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
|
||||
self.found_cuda = True
|
||||
except Exception:
|
||||
# if we have a cuda version, it is good enough (we dont have to have cudnn version)
|
||||
if agent.get('cuda_version'):
|
||||
self.found_cuda = True
|
||||
pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
|
||||
'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
|
||||
self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
|
||||
cache_dir=pip_cache_dir.as_posix())
|
||||
|
||||
def register(self, cls): # type: (Type[RequirementSubstitution]) -> None
|
||||
self.handlers.append(cls(self._session))
|
||||
|
||||
def _replace_one(self, req): # type: (MarkerRequirement) -> Optional[Text]
|
||||
match = re.search(r';\s*(.*)', Text(req))
|
||||
if match:
|
||||
req.markers = match.group(1).split(',')
|
||||
if not self.active:
|
||||
return None
|
||||
for handler in self.handlers:
|
||||
if handler.match(req):
|
||||
return handler.replace(req)
|
||||
return None
|
||||
|
||||
def replace(self, requirements): # type: (Text) -> Text
|
||||
def safe_parse(req_str):
|
||||
try:
|
||||
return next(parse(req_str))
|
||||
except Exception as ex:
|
||||
return Requirement(req_str)
|
||||
|
||||
parsed_requirements = tuple(
|
||||
map(
|
||||
MarkerRequirement,
|
||||
[safe_parse(line) for line in (requirements.splitlines()
|
||||
if isinstance(requirements, six.text_type) else requirements)]
|
||||
)
|
||||
)
|
||||
if not parsed_requirements:
|
||||
# return the original requirements just in case
|
||||
return requirements
|
||||
|
||||
def replace_one(i, req):
|
||||
# type: (int, MarkerRequirement) -> Optional[Text]
|
||||
try:
|
||||
return self._replace_one(req)
|
||||
except FatalSpecsResolutionError:
|
||||
warning('could not resolve python wheel replacement for {}'.format(req))
|
||||
raise
|
||||
except Exception:
|
||||
warning('could not resolve python wheel replacement for \"{}\", '
|
||||
'using original requirements line: {}'.format(req, i))
|
||||
return None
|
||||
|
||||
new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
|
||||
conda = is_conda(self.config)
|
||||
result = map(
|
||||
lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
|
||||
new_requirements,
|
||||
parsed_requirements
|
||||
)
|
||||
if not conda:
|
||||
result = map(self.translator.translate, result)
|
||||
|
||||
result = list(result)
|
||||
# add post scan add requirements call back
|
||||
for h in self.handlers:
|
||||
req = h.post_scan_add_req()
|
||||
if req:
|
||||
result.append(req.tostr())
|
||||
|
||||
return join_lines(result)
|
||||
|
||||
def post_install(self, session):
|
||||
for h in self.handlers:
|
||||
try:
|
||||
h.post_install(session)
|
||||
except Exception as ex:
|
||||
print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
|
||||
raise
|
||||
|
||||
def replace_back(self, requirements):
|
||||
if self.translator:
|
||||
requirements = self.translator.replace_back(requirements)
|
||||
|
||||
for h in self.handlers:
|
||||
try:
|
||||
requirements = h.replace_back(requirements)
|
||||
except Exception:
|
||||
pass
|
||||
return requirements
|
||||
|
||||
@staticmethod
|
||||
def get_cuda_version(config): # type: (ConfigTree) -> (Text, Text)
|
||||
# we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
|
||||
cuda_version = config['agent.cuda_version']
|
||||
cudnn_version = config['agent.cudnn_version']
|
||||
if cuda_version and cudnn_version:
|
||||
return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
|
||||
|
||||
if not cuda_version and is_windows_platform():
|
||||
try:
|
||||
cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
|
||||
if k.startswith('CUDA_PATH_V')]
|
||||
cuda_vers = max(cuda_vers)
|
||||
if cuda_vers > 40:
|
||||
cuda_version = cuda_vers
|
||||
except:
|
||||
pass
|
||||
|
||||
if not cuda_version:
|
||||
try:
|
||||
try:
|
||||
nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
|
||||
if is_windows_platform() and 'CUDA_PATH' in os.environ:
|
||||
nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)
|
||||
|
||||
output = Argv(nvcc, '--version').get_output()
|
||||
except OSError:
|
||||
raise CudaNotFound('nvcc not found')
|
||||
match = re.search(r'release (.{3})', output).group(1)
|
||||
cuda_version = Text(int(float(match) * 10))
|
||||
except:
|
||||
pass
|
||||
|
||||
if not cuda_version:
|
||||
try:
|
||||
try:
|
||||
output = Argv('nvidia-smi',).get_output()
|
||||
except OSError:
|
||||
raise CudaNotFound('nvcc not found')
|
||||
match = re.search(r'CUDA Version: ([0-9]+).([0-9]+)', output)
|
||||
match = match.group(1)+'.'+match.group(2)
|
||||
cuda_version = Text(int(float(match) * 10))
|
||||
except:
|
||||
pass
|
||||
|
||||
if not cudnn_version:
|
||||
try:
|
||||
cuda_lib = which('nvcc')
|
||||
if is_windows_platform:
|
||||
cudnn_h = path.sep.join(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h'])
|
||||
else:
|
||||
cudnn_h = path.join(path.sep, *(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h']))
|
||||
|
||||
cudnn_major, cudnn_minor = None, None
|
||||
try:
|
||||
include_file = open(cudnn_h)
|
||||
except OSError:
|
||||
raise CudaNotFound('Could not read cudnn.h')
|
||||
with include_file:
|
||||
for line in include_file:
|
||||
if 'CUDNN_MAJOR' in line:
|
||||
cudnn_major = line.split()[-1]
|
||||
if 'CUDNN_MINOR' in line:
|
||||
cudnn_minor = line.split()[-1]
|
||||
if cudnn_major and cudnn_minor:
|
||||
break
|
||||
cudnn_version = cudnn_major + (cudnn_minor or '0')
|
||||
except:
|
||||
pass
|
||||
|
||||
return (normalize_cuda_version(cuda_version or 0),
|
||||
normalize_cuda_version(cudnn_version or 0))
|
||||
|
||||
90
clearml_agent/helper/package/translator.py
Normal file
90
clearml_agent/helper/package/translator.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Text
|
||||
|
||||
from furl import furl
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains_agent.config import Config
|
||||
from .pip_api.system import SystemPip
|
||||
|
||||
|
||||
class RequirementsTranslator(object):
|
||||
|
||||
"""
|
||||
Translate explicit URLs to local URLs after downloading them to cache
|
||||
"""
|
||||
|
||||
SUPPORTED_SCHEMES = ["http", "https", "ftp"]
|
||||
|
||||
def __init__(self, session, interpreter=None, cache_dir=None):
|
||||
self._session = session
|
||||
config = session.config
|
||||
self.cache_dir = cache_dir or Path(config["agent.pip_download_cache.path"]).expanduser().as_posix()
|
||||
self.enabled = config["agent.pip_download_cache.enabled"]
|
||||
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
|
||||
self.config = Config()
|
||||
self.pip = SystemPip(interpreter=interpreter, session=self._session)
|
||||
self._translate_back = {}
|
||||
|
||||
def download(self, url):
|
||||
self.pip.download_package(url, cache_dir=self.cache_dir)
|
||||
|
||||
@classmethod
|
||||
def is_supported_link(cls, line):
|
||||
# type: (Text) -> bool
|
||||
"""
|
||||
Return whether requirement is a link that should be downloaded to cache
|
||||
"""
|
||||
url = furl(line)
|
||||
return (
|
||||
url.scheme
|
||||
and url.scheme.lower() in cls.SUPPORTED_SCHEMES
|
||||
and line.lstrip().lower().startswith(url.scheme.lower())
|
||||
)
|
||||
|
||||
def translate(self, line):
|
||||
"""
|
||||
If requirement is supported, download it to cache and return the download path
|
||||
"""
|
||||
if not (self.enabled and self.is_supported_link(line)):
|
||||
return line
|
||||
command = self.config.command
|
||||
command.log('Downloading "{}" to pip cache'.format(line))
|
||||
url = furl(line)
|
||||
try:
|
||||
wheel_name = url.path.segments[-1]
|
||||
except IndexError:
|
||||
command.error('Could not parse wheel name of "{}"'.format(line))
|
||||
return line
|
||||
try:
|
||||
self.download(line)
|
||||
downloaded = Path(self.cache_dir, wheel_name).expanduser().as_uri()
|
||||
except Exception:
|
||||
command.error('Could not download wheel name of "{}"'.format(line))
|
||||
return line
|
||||
|
||||
self._translate_back[str(downloaded)] = line
|
||||
|
||||
return downloaded
|
||||
|
||||
def replace_back(self, requirements):
|
||||
if not requirements:
|
||||
return requirements
|
||||
|
||||
for k in requirements:
|
||||
# k is either pip/conda
|
||||
if k not in ('pip', 'conda'):
|
||||
continue
|
||||
|
||||
original_requirements = requirements[k]
|
||||
new_requirements = []
|
||||
for line in original_requirements:
|
||||
local_file = [d for d in self._translate_back.keys() if d in line]
|
||||
if local_file:
|
||||
local_file = local_file[0]
|
||||
new_requirements.append(line.replace(local_file, self._translate_back[local_file]))
|
||||
else:
|
||||
new_requirements.append(line)
|
||||
|
||||
requirements[k] = new_requirements
|
||||
|
||||
return requirements
|
||||
115
clearml_agent/helper/package/venv_update_api.py
Normal file
115
clearml_agent/helper/package/venv_update_api.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import Optional, Text
|
||||
|
||||
import requests
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
from trains_agent.definitions import CONFIG_DIR
|
||||
from trains_agent.helper.process import Argv, DEVNULL
|
||||
from .pip_api.venv import VirtualenvPip
|
||||
|
||||
|
||||
class VenvUpdateAPI(VirtualenvPip):
|
||||
URL_FILE_PATH = Path(CONFIG_DIR, "venv-update-url.txt")
|
||||
SCRIPT_PATH = Path(CONFIG_DIR, "venv-update")
|
||||
|
||||
def __init__(self, url, *args, **kwargs):
|
||||
super(VenvUpdateAPI, self).__init__(*args, **kwargs)
|
||||
self.url = url
|
||||
self._script_path = None
|
||||
self._first_install = True
|
||||
|
||||
@property
|
||||
def downloaded_venv_url(self):
|
||||
# type: () -> Optional[Text]
|
||||
try:
|
||||
return self.URL_FILE_PATH.read_text()
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
@downloaded_venv_url.setter
|
||||
def downloaded_venv_url(self, value):
|
||||
self.URL_FILE_PATH.write_text(value)
|
||||
|
||||
def _check_script_validity(self, path):
|
||||
"""
|
||||
Make sure script in ``path`` is a valid python script
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
result = Argv(self.bin, path, "--version").call(
|
||||
stdout=DEVNULL, stderr=DEVNULL, stdin=DEVNULL
|
||||
)
|
||||
return result == 0
|
||||
|
||||
@property
|
||||
def script_path(self):
|
||||
# type: () -> Text
|
||||
if not self._script_path:
|
||||
self._script_path = self.SCRIPT_PATH
|
||||
if not (
|
||||
self._script_path.exists()
|
||||
and self.downloaded_venv_url
|
||||
and self.downloaded_venv_url == self.url
|
||||
and self._check_script_validity(self._script_path)
|
||||
):
|
||||
with self._script_path.open("wb") as f:
|
||||
for data in requests.get(self.url, stream=True):
|
||||
f.write(data)
|
||||
self.downloaded_venv_url = self.url
|
||||
return self._script_path
|
||||
|
||||
def install_from_file(self, path):
|
||||
first_install = (
|
||||
Argv(
|
||||
self.python,
|
||||
six.text_type(self.script_path),
|
||||
"venv=",
|
||||
"-p",
|
||||
self.python,
|
||||
self.path,
|
||||
)
|
||||
+ self.create_flags()
|
||||
+ ("install=", "-r", path)
|
||||
+ self.install_flags()
|
||||
)
|
||||
later_install = first_install + (
|
||||
"pip-command=",
|
||||
"pip-faster",
|
||||
"install",
|
||||
"--upgrade", # no --prune
|
||||
)
|
||||
self._choose_install(first_install, later_install)
|
||||
|
||||
def install_packages(self, *packages):
|
||||
first_install = (
|
||||
Argv(
|
||||
self.python,
|
||||
six.text_type(self.script_path),
|
||||
"venv=",
|
||||
self.path,
|
||||
"install=",
|
||||
)
|
||||
+ packages
|
||||
)
|
||||
later_install = first_install + (
|
||||
"pip-command=",
|
||||
"pip-faster",
|
||||
"install",
|
||||
"--upgrade", # no --prune
|
||||
)
|
||||
self._choose_install(first_install, later_install)
|
||||
|
||||
def _choose_install(self, first, rest):
|
||||
if self._first_install:
|
||||
command = first
|
||||
self._first_install = False
|
||||
else:
|
||||
command = rest
|
||||
command.check_call(stdin=DEVNULL)
|
||||
|
||||
def upgrade_pip(self):
|
||||
"""
|
||||
pip and venv-update versions are coupled, venv-update installs the latest compatible pip
|
||||
"""
|
||||
pass
|
||||
Reference in New Issue
Block a user