clearml-agent/trains_agent/helper/package/conda_api.py

436 lines
16 KiB
Python
Raw Normal View History

2019-10-25 19:28:44 +00:00
from __future__ import unicode_literals
import json
import re
import shutil
2019-10-25 19:28:44 +00:00
import subprocess
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
import six
2019-10-25 19:28:44 +00:00
import yaml
from time import time
2019-10-25 19:28:44 +00:00
from attr import attrs, attrib, Factory
from pathlib2 import Path
from packaging import version as packaging_version
2019-10-25 19:28:44 +00:00
from requirements import parse
from requirements.requirement import Requirement
2019-10-25 19:28:44 +00:00
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
2019-10-25 19:28:44 +00:00
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.session import Session
from .base import PackageManager
from .pip_api.venv import VirtualenvPip
from .requirements import RequirementsManager, MarkerRequirement
# Matches a bracketed ``[version='X']`` / ``[version="X"]`` decoration inside a
# package specification; group 1 captures the bare version text ``X``.
_VERSION_DECORATION = re.compile(r"""\[version=['"](.*)['"]\]""")

# Callable taking a package string and returning it with every bracketed
# version decoration replaced by the version text it wraps.
package_normalize = partial(_VERSION_DECORATION.sub, r"\1")
def package_set(packages):
    """
    Build a set out of ``packages`` after passing each entry through ``package_normalize``.

    :param packages: iterable of package specification strings
    :return: set of normalized package specification strings
    """
    return {package_normalize(spec) for spec in packages}
def _package_diff(path, packages):
    # type: (Union[Path, Text], Iterable[Text]) -> Set[Text]
    """
    Return the normalized lines of the file at ``path`` which are not in ``packages``.

    :param path: text file holding one package specification per line
    :param packages: package specifications to exclude from the result
    :return: normalized specifications present in the file but absent from ``packages``
    """
    listed_in_file = package_set(Path(path).read_text().splitlines())
    excluded = package_set(packages)
    return listed_in_file - excluded
class CondaPip(VirtualenvPip):
    """
    ``VirtualenvPip`` variant whose pip commands can be chained after a ``source`` command.
    """

    def __init__(self, source=None, *args, **kwargs):
        """
        :param source: command to run before every pip command (falsy to run pip the base-class way)
        :param args: positional arguments forwarded to ``VirtualenvPip``
        :param kwargs: keyword arguments forwarded to ``VirtualenvPip``; ``path`` is also used
            to point the interpreter at ``<path>/python.exe`` on Windows
        """
        env_path = kwargs.get('path')
        if is_windows_platform() and env_path:
            interpreter = Path(env_path, "python.exe")
        else:
            interpreter = None
        super(CondaPip, self).__init__(*args, interpreter=interpreter, **kwargs)
        self.source = source

    def run_with_env(self, command, output=False, **kwargs):
        """
        Run a pip command, chained after ``self.source`` when one is set.

        :param command: pip arguments (without the leading ``pip``)
        :param output: when true return the command output, otherwise only check the call
        :param kwargs: forwarded to the underlying command runner
        """
        if self.source:
            sequence = CommandSequence(self.source, Argv("pip", *command))
            runner = sequence.get_output if output else sequence.check_call
            return runner(stdin=DEVNULL, **kwargs)
        # no source command: defer entirely to the base implementation
        return super(CondaPip, self).run_with_env(command, output=output, **kwargs)
class CondaAPI(PackageManager):
    """
    A programmatic interface for controlling conda
    """

    MINIMUM_VERSION = packaging_version.parse("4.3.30")

    # ANSI escape sequences are stripped from conda output before JSON parsing;
    # compiled once here instead of on every command invocation
    _ANSI_ESCAPE = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')

    def __init__(self, session, path, python, requirements_manager):
        # type: (Session, PathLike, float, RequirementsManager) -> None
        """
        Locate the conda executable and verify it meets ``MINIMUM_VERSION``.

        :param session: session object, used here for its configuration
        :param python: base python version to use (e.g python3.6)
        :param path: path of env
        :param requirements_manager: requirements manager, also handed to the pip wrapper
        :raises CommandFailedError: if the conda version cannot be determined or is too old
        """
        self.session = session
        self.python = python
        self.source = None
        self.requirements_manager = requirements_manager
        self.path = path
        self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
        self.pip = CondaPip(
            session=self.session,
            source=self.source,
            python=self.python,
            requirements_manager=self.requirements_manager,
            path=self.path,
        )
        # fall back to the platform's lookup tool when conda is not found on PATH directly
        self.conda = (
            find_executable("conda")
            or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
        )
        try:
            output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as ex:
            raise CommandFailedError(
                "Unable to determine conda version: {ex}, output={ex.output}".format(
                    ex=ex
                )
            )
        self.conda_version = self.get_conda_version(output)
        if packaging_version.parse(self.conda_version) < self.MINIMUM_VERSION:
            raise CommandFailedError(
                "conda version '{}' is smaller than minimum supported conda version '{}'".format(
                    self.conda_version, self.MINIMUM_VERSION
                )
            )

    @staticmethod
    def get_conda_version(output):
        """
        Extract the first dotted version number (up to three components) from ``output``.

        :param output: text printed by ``conda --version``
        :return: the matched version string
        :raises CommandFailedError: if no version number is found
        """
        match = re.search(r"(\d+\.){0,2}\d+", output)
        if not match:
            # single formatted message, consistent with the other CommandFailedError call sites
            raise CommandFailedError("Unidentified conda version string: {}".format(output))
        return match.group(0)

    @property
    def bin(self):
        """Interpreter path as reported by the pip wrapper."""
        return self.pip.bin

    def upgrade_pip(self):
        """Delegate pip upgrade to the pip wrapper."""
        return self.pip.upgrade_pip()

    def create(self):
        """
        Create a new environment

        Runs ``conda create`` for ``self.path`` with the requested python version, derives the
        activation command from conda's output (falling back to ``activate <path>``) and stores
        it as the ``source`` command of both this object and the pip wrapper.
        If ``agent.cuda_version`` is configured with a positive value, the matching
        ``cudatoolkit`` package is installed on a best-effort basis.

        :return: self
        """
        output = Argv(
            self.conda,
            "create",
            "--yes",
            "--mkdir",
            "--prefix",
            self.path,
            "python={}".format(self.python),
        ).get_output(stderr=DEVNULL)
        match = re.search(
            r"\W*(.*activate) ({})".format(re.escape(str(self.path))), output
        )
        self.source = self.pip.source = (
            tuple(match.group(1).split()) + (match.group(2),)
            if match
            else ("activate", self.path)
        )
        # when conda ships a shell profile script, source it before activating (non-Windows only)
        conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
        if conda_env.is_file() and not is_windows_platform():
            self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)

        # install cuda toolkit
        try:
            # the configured value is divided by ten to get the toolkit version, e.g. 101 -> 10.1
            cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
        except Exception:
            # missing or non-numeric setting: treat as "no CUDA requested"
            cuda_version = 0
        if cuda_version > 0:
            try:
                self._install('cudatoolkit={:.1f}'.format(cuda_version))
            except Exception as ex:
                # best effort: environment creation must not fail because of the toolkit
                print('Warning: failed installing cudatoolkit={:.1f}: {}'.format(cuda_version, ex))
        return self

    def remove(self):
        r"""
        Delete a conda environment.
        Use 'conda env remove', then 'rm_tree' to be safe.

        Conda seems to load "vcruntime140.dll" from all its environment on startup.
        This means environment have to be deleted using 'conda env remove'.
        If necessary, conda can be fooled into deleting a partially-deleted environment by creating an empty file
        in '<ENV>\conda-meta\history' (value found in 'conda.gateways.disk.test.PREFIX_MAGIC_FILE').
        Otherwise, it complains that said directory is not a conda environment.

        See: https://github.com/conda/conda/issues/7682
        """
        try:
            self._run_command(("env", "remove", "-p", self.path))
        except Exception:
            # deliberate best effort: rm_tree below removes whatever conda left behind
            pass
        rm_tree(self.path)
        # if we failed removing the path, change its name
        if is_windows_platform() and Path(self.path).exists():
            try:
                Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time()))
            except Exception:
                # nothing more can be done; leave the directory in place
                pass

    def _install_from_file(self, path):
        """
        Install packages from requirement file.
        """
        self._install("--file", path)

    def _install(self, *args):
        # type: (*PathLike) -> ()
        """
        Run ``conda install`` into the environment, adding ``-c <channel>`` for every extra channel.

        :param args: extra arguments (package specifications or flags) appended to the command
        """
        channels_args = tuple(
            chain.from_iterable(("-c", channel) for channel in self.extra_channels)
        )
        self._run_command(("install", "-p", self.path) + channels_args + args)

    def _get_pip_packages(self, packages):
        # type: (Iterable[Text]) -> Sequence[Text]
        """
        Return subset of ``packages`` which are not available on conda
        """
        pips = []
        while True:
            with self.temp_file("conda_reqs", packages) as path:
                try:
                    self._install_from_file(path)
                except PackageNotFoundError as e:
                    # drop the missing package and retry with the remainder
                    pips.append(e.pkg)
                    packages = _package_diff(path, {e.pkg})
                else:
                    break
        return pips

    def install_packages(self, *packages):
        # type: (*Text) -> ()
        """Install the given package specifications with conda."""
        return self._install(*packages)

    def uninstall_packages(self, *packages):
        """
        Uninstall the given packages from the environment with conda.

        :param packages: names of the packages to remove
        :return: parsed JSON output of the conda command
        """
        # the package names must be part of the command, otherwise conda has nothing to remove
        return self._run_command(("uninstall", "-p", self.path) + tuple(packages))

    def install_from_file(self, path):
        """
        Try to install packages from conda. Install packages which are not available from conda with pip.
        """
        try:
            self._install_from_file(path)
            return
        except PackageNotFoundError as e:
            pip_packages = [e.pkg]
        except PackagesNotFoundError as e:
            pip_packages = package_set(e.packages)
        # retry conda with the remaining packages, then hand the missing ones to pip
        with self.temp_file("conda_reqs", _package_diff(path, pip_packages)) as reqs:
            self.install_from_file(reqs)
        with self.temp_file("pip_reqs", pip_packages) as reqs:
            self.pip.install_from_file(reqs)

    def freeze(self):
        """
        Return the pip freeze output, extended (best effort) with a ``'conda'`` entry.

        The ``'conda'`` entry lists ``name==version`` for every package reported by ``conda list``
        that does not come from the ``pypi`` channel and whose name also appears in the pip list.

        :return: requirements dictionary produced by the pip wrapper, possibly with a ``'conda'`` key
        """
        requirements = self.pip.freeze()
        try:
            conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
            # set for O(1) membership tests while filtering
            pip_names = {r.split('==')[0].strip().lower() for r in requirements['pip']}
            # skip pypi packages and packages which are not python packages at all
            requirements['conda'] = [
                '{0}{1}{2}'.format(pkg['name'], '==', pkg['version'])
                for pkg in conda_packages
                if pkg['channel'] != 'pypi' and pkg['name'].lower() in pip_names
            ]
        except Exception as ex:
            # best effort: fall back to the pip-only requirements
            print('Warning: failed listing conda packages: {}'.format(ex))
        return requirements

    def load_requirements(self, requirements):
        """
        Install ``requirements`` using conda first and pip for whatever conda cannot provide.

        :param requirements: dictionary with a ``'pip'`` entry and an optional ``'conda'`` entry,
            each either a list of requirement lines or a newline separated string
            (string entries are split in place)
        :return: True
        """
        # create new environment file
        conda_env = dict()
        conda_env['channels'] = self.extra_channels
        reqs = []
        if isinstance(requirements['pip'], six.string_types):
            requirements['pip'] = requirements['pip'].split('\n')
        if isinstance(requirements.get('conda'), six.string_types):
            requirements['conda'] = requirements['conda'].split('\n')
        has_torch = False
        has_matplotlib = False
        try:
            cuda_version = int(self.session.config.get('agent.cuda_version', 0))
        except Exception:
            cuda_version = 0

        # notice 'conda' entry with empty string is a valid conda requirements list, it means pip only
        # this should happen if experiment was executed on non-conda machine or old trains client
        conda_supported_req = requirements['pip'] if requirements.get('conda', None) is None else requirements['conda']
        conda_supported_req_names = []
        for line in conda_supported_req:
            marker = list(parse(line))
            if not marker:
                continue
            m = MarkerRequirement(marker[0])
            conda_supported_req_names.append(m.name.lower())
            if m.req.name.lower() == 'matplotlib':
                has_matplotlib = True
            elif m.req.name.lower().startswith('torch'):
                has_torch = True

            if m.req.name.lower() in ('torch', 'pytorch'):
                has_torch = True
                m.req.name = 'pytorch'

            if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'):
                # NOTE(review): setting has_torch for tensorflow also triggers the 'cpuonly'
                # hack below; kept as in the original - confirm this is intended
                has_torch = True
                m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'

            # conda always likes - not _
            m.req.name = m.req.name.replace('_', '-')
            reqs.append(m)

        pip_requirements = []
        # if we have a conda list, the rest should be installed with pip,
        if requirements.get('conda', None) is not None:
            for line in requirements['pip']:
                marker = list(parse(line))
                if marker:
                    m = MarkerRequirement(marker[0])
                    if m.name.lower() not in conda_supported_req_names:
                        pip_requirements.append(m)

        # Conda requirements Hacks:
        if has_matplotlib:
            reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
            reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
        if has_torch and cuda_version == 0:
            reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))

        while reqs:
            # notice, we give conda more freedom in version selection, to help it choose best combination
            conda_env['dependencies'] = [r.tostr().replace('==', '~=') for r in reqs]
            with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as env_file:
                print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
                result = self._run_command(
                    ("env", "update", "-p", self.path, "--file", env_file)
                )
            # check if we need to remove specific packages
            bad_req = self._parse_conda_result_bad_packges(result)
            if not bad_req:
                break

            solved = False
            for bad_r in bad_req:
                # strip any '[...]' decoration and '=version' suffix to get the bare name
                bad_name = bad_r.split('[')[0].split('=')[0]
                # look for name in requirements
                for r in reqs:
                    if r.name.lower() == bad_name.lower():
                        # move the offending requirement over to pip and retry conda without it
                        pip_requirements.append(r)
                        reqs.remove(r)
                        solved = True
                        break

            # we couldn't remove even one package,
            # nothing we can do but try pip
            if not solved:
                pip_requirements.extend(reqs)
                break

        if pip_requirements:
            try:
                pip_req_str = [r.tostr() for r in pip_requirements]
                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
                self.pip.load_requirements('\n'.join(pip_req_str))
            except Exception as e:
                print(e)
                # bare raise keeps the original traceback
                raise

        self.requirements_manager.post_install()
        return True

    def _parse_conda_result_bad_packges(self, result_dict):
        """
        Extract the list of problematic packages from a conda JSON result.

        Uses the ``'bad_deps'`` entry when present and non-empty; otherwise, for an error text
        starting with ``UnsatisfiableError:``, parses the lines between the first two blank
        lines of the error as YAML.

        :param result_dict: parsed JSON output of a conda command (may be empty or None)
        :return: the problematic packages, or None when none could be identified
        """
        if not result_dict:
            return None

        if result_dict.get('bad_deps'):
            return result_dict['bad_deps']

        if result_dict.get('error'):
            error_lines = result_dict['error'].split('\n')
            if error_lines[0].strip().lower().startswith("unsatisfiableerror:"):
                empty_lines = [i for i, l in enumerate(error_lines) if not l.strip()]
                if len(empty_lines) >= 2:
                    deps = error_lines[empty_lines[0]+1:empty_lines[1]]
                    try:
                        return yaml.load('\n'.join(deps), Loader=yaml.SafeLoader)
                    except Exception:
                        # unparsable dependency section: report nothing rather than fail
                        return None
        return None

    def _run_command(self, command, raw=False, **kwargs):
        # type: (Iterable[Text], bool, Any) -> Union[Dict, Text]
        """
        Run a conda command, returning JSON output.
        The command is prepended with 'conda' and run with JSON output flags.
        :param command: command to run
        :param raw: return text output and don't change command
        :param kwargs: kwargs for Argv.get_output()
        :return: JSON output or text output
        """
        command = Argv(*command)  # type: Executable
        if not raw:
            command = (self.conda,) + command + ("--quiet", "--json")
        try:
            print('Executing Conda: {}'.format(command.serialize()))
            result = command.get_output(stdin=DEVNULL, **kwargs)
        except Exception as e:
            if raw:
                raise
            # in JSON mode a failing command may still carry its report in the exception output
            result = e.output if hasattr(e, 'output') else ''
        if raw:
            return result

        result = json.loads(self._ANSI_ESCAPE.sub('', result)) if result else {}
        if result.get('success', False):
            print('Pass')
        elif result.get('error'):
            print('Conda error: {}'.format(result.get('error')))
        return result

    def get_python_command(self, extra=()):
        """
        Return the pip wrapper's python command chained after ``self.source``.

        :param extra: extra arguments forwarded to the pip wrapper's ``get_python_command``
        """
        return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
# Shared attrs decorator for the conda exception classes below.
# cmp=False keeps the exceptions hashable (pdb fails on unhashable exceptions).
exception = attrs(cmp=False, str=True)
@exception
class CondaException(Exception, NonStrictAttrs):
    """
    Base class of the conda exceptions defined in this module.
    """

    # the conda command this exception refers to (required)
    command = attrib()
    # optional message, None when not provided
    message = attrib(default=None)
@exception
class UnknownCondaError(CondaException):
    """
    Conda exception carrying an arbitrary ``data`` payload.
    """

    # extra payload; every instance gets its own fresh dict by default
    data = attrib(default=Factory(dict))
@exception
class PackagesNotFoundError(CondaException):
    """
    Conda 4.5 exception - this reports all missing packages.
    """

    # the missing packages, an empty tuple when not provided
    packages = attrib(default=())
def _single_yaml_package(val):
    """
    Convert a singleton YAML list into its only item with all spaces removed.

    Empty input (including the attribute's default ``""``) yields ``""`` instead of failing:
    attrs runs converters on default values too, and an empty YAML document loads as None.

    :param val: YAML text holding a one-item list
    :return: the first list item without spaces, or ``""`` when there is nothing to parse
    """
    parsed = yaml.load(val, Loader=yaml.SafeLoader) if val else None
    if not parsed:
        return ""
    return parsed[0].replace(" ", "")


@exception
class PackageNotFoundError(CondaException):
    """
    Conda 4.3 exception - this reports one missing package at a time,
    as a singleton YAML list.
    """

    # name of the missing package, parsed out of the YAML list conda reports
    pkg = attrib(default="", converter=_single_yaml_package)