From 72c26499c0289344541089a4ef4ff3e0e2de876c Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 15 Nov 2019 23:24:04 +0200 Subject: [PATCH] Add Windows support and Conda package manager support --- trains_agent/helper/package/conda_api.py | 62 +++++++++++++++++++-- trains_agent/helper/package/requirements.py | 16 +++++- trains_agent/helper/process.py | 25 ++++++++- 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/trains_agent/helper/package/conda_api.py b/trains_agent/helper/package/conda_api.py index 6994732..e6d9b70 100644 --- a/trains_agent/helper/package/conda_api.py +++ b/trains_agent/helper/package/conda_api.py @@ -2,20 +2,24 @@ from __future__ import unicode_literals import json import re +import shutil import subprocess from distutils.spawn import find_executable from functools import partial from itertools import chain from typing import Text, Iterable, Union, Dict, Set, Sequence, Any +import six import yaml +from time import time from attr import attrs, attrib, Factory from pathlib2 import Path from semantic_version import Version from requirements import parse +from requirements.requirement import Requirement from trains_agent.errors import CommandFailedError -from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform +from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike from trains_agent.session import Session from .base import PackageManager @@ -36,7 +40,8 @@ def _package_diff(path, packages): class CondaPip(VirtualenvPip): def __init__(self, source=None, *args, **kwargs): - super(CondaPip, self).__init__(*args, **kwargs) + super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe") \ + if is_windows_platform() and kwargs.get('path') else None, **kwargs) self.source = source def run_with_env(self, command, output=False, **kwargs): @@ -80,7 +85,7 @@ class CondaAPI(PackageManager): or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip() ) try: - output = Argv(self.conda, "--version").get_output() + output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT) except subprocess.CalledProcessError as ex: raise CommandFailedError( "Unable to determine conda version: {ex}, output={ex.output}".format( @@ -131,7 +136,7 @@ class CondaAPI(PackageManager): else ("activate", self.path) ) conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh' - if conda_env.is_file(): + if conda_env.is_file() and not is_windows_platform(): self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source) # install cuda toolkit @@ -161,6 +166,12 @@ class CondaAPI(PackageManager): except Exception: pass rm_tree(self.path) + # if we failed removing the path, change it's name + if is_windows_platform() and Path(self.path).exists(): + try: + Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time())) + except Exception: + pass def _install_from_file(self, path): """ @@ -235,9 +246,43 @@ class CondaAPI(PackageManager): # create new environment file conda_env = dict() conda_env['channels'] = self.extra_channels - reqs = [MarkerRequirement(next(parse(r))) for r in requirements['pip']] + reqs = [] + if isinstance(requirements['pip'], six.string_types): + requirements['pip'] = requirements['pip'].split('\n') + has_torch = False + has_matplotlib = False + try: + cuda_version = int(self.session.config.get('agent.cuda_version', 0)) + except: + cuda_version = 0 + + for r in requirements['pip']: + marker = list(parse(r)) + if marker: + m = MarkerRequirement(marker[0]) + if m.req.name.lower() == 'matplotlib': + has_matplotlib = True + elif m.req.name.lower().startswith('torch'): + has_torch = True + + if m.req.name.lower() in ('torch', 'pytorch'): + has_torch = True + m.req.name = 'pytorch' + + if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'): + has_torch = True + m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow' + + reqs.append(m) pip_requirements = [] + # Conda requirements Hacks: + if has_matplotlib: + reqs.append(MarkerRequirement(Requirement.parse('graphviz'))) + reqs.append(MarkerRequirement(Requirement.parse('kiwisolver'))) + if has_torch and cuda_version == 0: + reqs.append(MarkerRequirement(Requirement.parse('cpuonly'))) + while reqs: conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs] with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name: @@ -308,6 +353,10 @@ class CondaAPI(PackageManager): :param kwargs: kwargs for Argv.get_output() :return: JSON output or text output """ + def escape_ansi(line): + ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', line) + command = Argv(*command) # type: Executable if not raw: command = (self.conda,) + command + ("--quiet", "--json") @@ -320,7 +369,8 @@ class CondaAPI(PackageManager): result = e.output if hasattr(e, 'output') else '' if raw: return result - result = json.loads(result) if result else {} + + result = json.loads(escape_ansi(result)) if result else {} if result.get('success', False): print('Pass') elif result.get('error'): diff --git a/trains_agent/helper/package/requirements.py b/trains_agent/helper/package/requirements.py index 8a0c048..a3439fd 100644 --- a/trains_agent/helper/package/requirements.py +++ b/trains_agent/helper/package/requirements.py @@ -288,10 +288,24 @@ class RequirementsManager(object): if cuda_version and cudnn_version: return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version) + if not cuda_version and is_windows_platform(): + try: + cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys() + if k.startswith('CUDA_PATH_V')] + cuda_vers = max(cuda_vers) + if cuda_vers > 40: + cuda_version = cuda_vers + except: + pass + if not cuda_version: try: try: - output = Argv('nvcc', '--version').get_output() + nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc' + if is_windows_platform() and 'CUDA_PATH' in os.environ: + nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc) + + output = Argv(nvcc, '--version').get_output() except OSError: raise CudaNotFound('nvcc not found') match = re.search(r'release (.{3})', output).group(1) diff --git a/trains_agent/helper/process.py b/trains_agent/helper/process.py index d5feeeb..4e80e69 100644 --- a/trains_agent/helper/process.py +++ b/trains_agent/helper/process.py @@ -137,6 +137,8 @@ class Argv(Executable): """ Returns a string of the shell command """ + if is_windows_platform(): + return self.ARGV_SEPARATOR.join(map(double_quote, self)) return self.ARGV_SEPARATOR.join(map(quote, self)) def call_subprocess(self, func, censor_password=False, *args, **kwargs): @@ -157,6 +159,9 @@ class Argv(Executable): return "Executing: {}".format(self.argv) def __iter__(self): + if is_windows_platform(): + return (word.as_posix().replace('/', '\\') if isinstance(word, Path) else six.text_type(word) + for word in self.argv) return (six.text_type(word) for word in self.argv) def __getitem__(self, item): @@ -237,7 +242,8 @@ class CommandSequence(Executable): return islice(chain.from_iterable(zip(repeat(delimiter), seq)), 1, None) def normalize(command): - return list(command) if is_windows_platform() else command.serialize() + # return list(command) if is_windows_platform() else command.serialize() + return command.serialize() return ' '.join(list(intersperse(self.JOIN_COMMAND_OPERATOR, map(normalize, self.commands)))) @@ -279,8 +285,6 @@ class CommandSequence(Executable): def pretty(self): serialized = self.serialize() - if is_windows_platform(): - return " ".join(serialized) return serialized @@ -374,3 +378,18 @@ def quote(s): # use single quotes, and put single quotes into double quotes # the string $'b is then quoted as '$'"'"'b' return "'" + s.replace("'", "'\"'\"'") + "'" + + +def double_quote(s): + """ + Backport of shlex.quote(): + Return a shell-escaped version of the string *s*. + """ + if not s: + return "''" + if _find_unsafe(s) is None: + return s + + # use single quotes, and put single quotes into double quotes + # the string $"b is then quoted as "$"""b" + return '"' + s.replace('"', '"\'\"\'"') + '"'