Add Windows support and Conda package manager support

This commit is contained in:
allegroai 2019-11-15 23:24:04 +02:00
parent 741be2ae42
commit 72c26499c0
3 changed files with 93 additions and 10 deletions

View File

@ -2,20 +2,24 @@ from __future__ import unicode_literals
import json
import re
import shutil
import subprocess
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
import six
import yaml
from time import time
from attr import attrs, attrib, Factory
from pathlib2 import Path
from semantic_version import Version
from requirements import parse
from requirements.requirement import Requirement
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.session import Session
from .base import PackageManager
@ -36,7 +40,8 @@ def _package_diff(path, packages):
class CondaPip(VirtualenvPip):
def __init__(self, source=None, *args, **kwargs):
super(CondaPip, self).__init__(*args, **kwargs)
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe") \
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
self.source = source
def run_with_env(self, command, output=False, **kwargs):
@ -80,7 +85,7 @@ class CondaAPI(PackageManager):
or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
)
try:
output = Argv(self.conda, "--version").get_output()
output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as ex:
raise CommandFailedError(
"Unable to determine conda version: {ex}, output={ex.output}".format(
@ -131,7 +136,7 @@ class CondaAPI(PackageManager):
else ("activate", self.path)
)
conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
if conda_env.is_file():
if conda_env.is_file() and not is_windows_platform():
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
# install cuda toolkit
@ -161,6 +166,12 @@ class CondaAPI(PackageManager):
except Exception:
pass
rm_tree(self.path)
# if we failed removing the path, change it's name
if is_windows_platform() and Path(self.path).exists():
try:
Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time()))
except Exception:
pass
def _install_from_file(self, path):
"""
@ -235,9 +246,43 @@ class CondaAPI(PackageManager):
# create new environment file
conda_env = dict()
conda_env['channels'] = self.extra_channels
reqs = [MarkerRequirement(next(parse(r))) for r in requirements['pip']]
reqs = []
if isinstance(requirements['pip'], six.string_types):
requirements['pip'] = requirements['pip'].split('\n')
has_torch = False
has_matplotlib = False
try:
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
except:
cuda_version = 0
for r in requirements['pip']:
marker = list(parse(r))
if marker:
m = MarkerRequirement(marker[0])
if m.req.name.lower() == 'matplotlib':
has_matplotlib = True
elif m.req.name.lower().startswith('torch'):
has_torch = True
if m.req.name.lower() in ('torch', 'pytorch'):
has_torch = True
m.req.name = 'pytorch'
if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'):
has_torch = True
m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'
reqs.append(m)
pip_requirements = []
# Conda requirements Hacks:
if has_matplotlib:
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
if has_torch and cuda_version == 0:
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
while reqs:
conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs]
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
@ -308,6 +353,10 @@ class CondaAPI(PackageManager):
:param kwargs: kwargs for Argv.get_output()
:return: JSON output or text output
"""
def escape_ansi(line):
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
return ansi_escape.sub('', line)
command = Argv(*command) # type: Executable
if not raw:
command = (self.conda,) + command + ("--quiet", "--json")
@ -320,7 +369,8 @@ class CondaAPI(PackageManager):
result = e.output if hasattr(e, 'output') else ''
if raw:
return result
result = json.loads(result) if result else {}
result = json.loads(escape_ansi(result)) if result else {}
if result.get('success', False):
print('Pass')
elif result.get('error'):

View File

@ -288,10 +288,24 @@ class RequirementsManager(object):
if cuda_version and cudnn_version:
return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
if not cuda_version and is_windows_platform():
try:
cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
if k.startswith('CUDA_PATH_V')]
cuda_vers = max(cuda_vers)
if cuda_vers > 40:
cuda_version = cuda_vers
except:
pass
if not cuda_version:
try:
try:
output = Argv('nvcc', '--version').get_output()
nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
if is_windows_platform() and 'CUDA_PATH' in os.environ:
nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)
output = Argv(nvcc, '--version').get_output()
except OSError:
raise CudaNotFound('nvcc not found')
match = re.search(r'release (.{3})', output).group(1)

View File

@ -137,6 +137,8 @@ class Argv(Executable):
"""
Returns a string of the shell command
"""
if is_windows_platform():
return self.ARGV_SEPARATOR.join(map(double_quote, self))
return self.ARGV_SEPARATOR.join(map(quote, self))
def call_subprocess(self, func, censor_password=False, *args, **kwargs):
@ -157,6 +159,9 @@ class Argv(Executable):
return "Executing: {}".format(self.argv)
def __iter__(self):
if is_windows_platform():
return (word.as_posix().replace('/', '\\') if isinstance(word, Path) else six.text_type(word)
for word in self.argv)
return (six.text_type(word) for word in self.argv)
def __getitem__(self, item):
@ -237,7 +242,8 @@ class CommandSequence(Executable):
return islice(chain.from_iterable(zip(repeat(delimiter), seq)), 1, None)
def normalize(command):
return list(command) if is_windows_platform() else command.serialize()
# return list(command) if is_windows_platform() else command.serialize()
return command.serialize()
return ' '.join(list(intersperse(self.JOIN_COMMAND_OPERATOR, map(normalize, self.commands))))
@ -279,8 +285,6 @@ class CommandSequence(Executable):
def pretty(self):
serialized = self.serialize()
if is_windows_platform():
return " ".join(serialized)
return serialized
@ -374,3 +378,18 @@ def quote(s):
# use single quotes, and put single quotes into double quotes
# the string $'b is then quoted as '$'"'"'b'
return "'" + s.replace("'", "'\"'\"'") + "'"
def double_quote(s):
"""
Backport of shlex.quote():
Return a shell-escaped version of the string *s*.
"""
if not s:
return "''"
if _find_unsafe(s) is None:
return s
# use single quotes, and put single quotes into double quotes
# the string $"b is then quoted as "$"""b"
return '"' + s.replace('"', '"\'\"\'"') + '"'