mirror of
https://github.com/clearml/clearml-agent
synced 2025-03-13 06:58:37 +00:00
Add Windows support and Conda package manager support
This commit is contained in:
parent
741be2ae42
commit
72c26499c0
@ -2,20 +2,24 @@ from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from distutils.spawn import find_executable
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
|
||||
|
||||
import six
|
||||
import yaml
|
||||
from time import time
|
||||
from attr import attrs, attrib, Factory
|
||||
from pathlib2 import Path
|
||||
from semantic_version import Version
|
||||
from requirements import parse
|
||||
from requirements.requirement import Requirement
|
||||
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform
|
||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
|
||||
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
|
||||
from trains_agent.session import Session
|
||||
from .base import PackageManager
|
||||
@ -36,7 +40,8 @@ def _package_diff(path, packages):
|
||||
|
||||
class CondaPip(VirtualenvPip):
|
||||
def __init__(self, source=None, *args, **kwargs):
|
||||
super(CondaPip, self).__init__(*args, **kwargs)
|
||||
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe") \
|
||||
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
|
||||
self.source = source
|
||||
|
||||
def run_with_env(self, command, output=False, **kwargs):
|
||||
@ -80,7 +85,7 @@ class CondaAPI(PackageManager):
|
||||
or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
|
||||
)
|
||||
try:
|
||||
output = Argv(self.conda, "--version").get_output()
|
||||
output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
|
||||
except subprocess.CalledProcessError as ex:
|
||||
raise CommandFailedError(
|
||||
"Unable to determine conda version: {ex}, output={ex.output}".format(
|
||||
@ -131,7 +136,7 @@ class CondaAPI(PackageManager):
|
||||
else ("activate", self.path)
|
||||
)
|
||||
conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
|
||||
if conda_env.is_file():
|
||||
if conda_env.is_file() and not is_windows_platform():
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
|
||||
# install cuda toolkit
|
||||
@ -161,6 +166,12 @@ class CondaAPI(PackageManager):
|
||||
except Exception:
|
||||
pass
|
||||
rm_tree(self.path)
|
||||
# if we failed removing the path, change it's name
|
||||
if is_windows_platform() and Path(self.path).exists():
|
||||
try:
|
||||
Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time()))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _install_from_file(self, path):
|
||||
"""
|
||||
@ -235,9 +246,43 @@ class CondaAPI(PackageManager):
|
||||
# create new environment file
|
||||
conda_env = dict()
|
||||
conda_env['channels'] = self.extra_channels
|
||||
reqs = [MarkerRequirement(next(parse(r))) for r in requirements['pip']]
|
||||
reqs = []
|
||||
if isinstance(requirements['pip'], six.string_types):
|
||||
requirements['pip'] = requirements['pip'].split('\n')
|
||||
has_torch = False
|
||||
has_matplotlib = False
|
||||
try:
|
||||
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
|
||||
except:
|
||||
cuda_version = 0
|
||||
|
||||
for r in requirements['pip']:
|
||||
marker = list(parse(r))
|
||||
if marker:
|
||||
m = MarkerRequirement(marker[0])
|
||||
if m.req.name.lower() == 'matplotlib':
|
||||
has_matplotlib = True
|
||||
elif m.req.name.lower().startswith('torch'):
|
||||
has_torch = True
|
||||
|
||||
if m.req.name.lower() in ('torch', 'pytorch'):
|
||||
has_torch = True
|
||||
m.req.name = 'pytorch'
|
||||
|
||||
if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'):
|
||||
has_torch = True
|
||||
m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'
|
||||
|
||||
reqs.append(m)
|
||||
pip_requirements = []
|
||||
|
||||
# Conda requirements Hacks:
|
||||
if has_matplotlib:
|
||||
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
|
||||
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
|
||||
if has_torch and cuda_version == 0:
|
||||
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
|
||||
|
||||
while reqs:
|
||||
conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs]
|
||||
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
|
||||
@ -308,6 +353,10 @@ class CondaAPI(PackageManager):
|
||||
:param kwargs: kwargs for Argv.get_output()
|
||||
:return: JSON output or text output
|
||||
"""
|
||||
def escape_ansi(line):
|
||||
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
|
||||
return ansi_escape.sub('', line)
|
||||
|
||||
command = Argv(*command) # type: Executable
|
||||
if not raw:
|
||||
command = (self.conda,) + command + ("--quiet", "--json")
|
||||
@ -320,7 +369,8 @@ class CondaAPI(PackageManager):
|
||||
result = e.output if hasattr(e, 'output') else ''
|
||||
if raw:
|
||||
return result
|
||||
result = json.loads(result) if result else {}
|
||||
|
||||
result = json.loads(escape_ansi(result)) if result else {}
|
||||
if result.get('success', False):
|
||||
print('Pass')
|
||||
elif result.get('error'):
|
||||
|
@ -288,10 +288,24 @@ class RequirementsManager(object):
|
||||
if cuda_version and cudnn_version:
|
||||
return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
|
||||
|
||||
if not cuda_version and is_windows_platform():
|
||||
try:
|
||||
cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
|
||||
if k.startswith('CUDA_PATH_V')]
|
||||
cuda_vers = max(cuda_vers)
|
||||
if cuda_vers > 40:
|
||||
cuda_version = cuda_vers
|
||||
except:
|
||||
pass
|
||||
|
||||
if not cuda_version:
|
||||
try:
|
||||
try:
|
||||
output = Argv('nvcc', '--version').get_output()
|
||||
nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
|
||||
if is_windows_platform() and 'CUDA_PATH' in os.environ:
|
||||
nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)
|
||||
|
||||
output = Argv(nvcc, '--version').get_output()
|
||||
except OSError:
|
||||
raise CudaNotFound('nvcc not found')
|
||||
match = re.search(r'release (.{3})', output).group(1)
|
||||
|
@ -137,6 +137,8 @@ class Argv(Executable):
|
||||
"""
|
||||
Returns a string of the shell command
|
||||
"""
|
||||
if is_windows_platform():
|
||||
return self.ARGV_SEPARATOR.join(map(double_quote, self))
|
||||
return self.ARGV_SEPARATOR.join(map(quote, self))
|
||||
|
||||
def call_subprocess(self, func, censor_password=False, *args, **kwargs):
|
||||
@ -157,6 +159,9 @@ class Argv(Executable):
|
||||
return "Executing: {}".format(self.argv)
|
||||
|
||||
def __iter__(self):
|
||||
if is_windows_platform():
|
||||
return (word.as_posix().replace('/', '\\') if isinstance(word, Path) else six.text_type(word)
|
||||
for word in self.argv)
|
||||
return (six.text_type(word) for word in self.argv)
|
||||
|
||||
def __getitem__(self, item):
|
||||
@ -237,7 +242,8 @@ class CommandSequence(Executable):
|
||||
return islice(chain.from_iterable(zip(repeat(delimiter), seq)), 1, None)
|
||||
|
||||
def normalize(command):
|
||||
return list(command) if is_windows_platform() else command.serialize()
|
||||
# return list(command) if is_windows_platform() else command.serialize()
|
||||
return command.serialize()
|
||||
|
||||
return ' '.join(list(intersperse(self.JOIN_COMMAND_OPERATOR, map(normalize, self.commands))))
|
||||
|
||||
@ -279,8 +285,6 @@ class CommandSequence(Executable):
|
||||
|
||||
def pretty(self):
|
||||
serialized = self.serialize()
|
||||
if is_windows_platform():
|
||||
return " ".join(serialized)
|
||||
return serialized
|
||||
|
||||
|
||||
@ -374,3 +378,18 @@ def quote(s):
|
||||
# use single quotes, and put single quotes into double quotes
|
||||
# the string $'b is then quoted as '$'"'"'b'
|
||||
return "'" + s.replace("'", "'\"'\"'") + "'"
|
||||
|
||||
|
||||
def double_quote(s):
|
||||
"""
|
||||
Backport of shlex.quote():
|
||||
Return a shell-escaped version of the string *s*.
|
||||
"""
|
||||
if not s:
|
||||
return "''"
|
||||
if _find_unsafe(s) is None:
|
||||
return s
|
||||
|
||||
# use single quotes, and put single quotes into double quotes
|
||||
# the string $"b is then quoted as "$"""b"
|
||||
return '"' + s.replace('"', '"\'\"\'"') + '"'
|
||||
|
Loading…
Reference in New Issue
Block a user