Compare commits

13 Commits

Author      SHA1        Message                                                                     Date
allegroai   8ce621cc44  version bump                                                                2019-12-15 15:42:46 +02:00
allegroai   7c0a2c4d50  version bump                                                                2019-12-15 00:04:27 +02:00
allegroai   5e063c9195  Add docker build command and improve k8s integration                        2019-12-15 00:04:15 +02:00
allegroai   24329a21fe  Fix docker CUDA support                                                     2019-12-15 00:03:39 +02:00
allegroai   3a301b0b6c  Improve docker support and add docker build command                         2019-12-15 00:03:04 +02:00
allegroai   1f0bb4906b  Improve configuration wizard                                                2019-12-15 00:02:04 +02:00
allegroai   88f1031e5d  Sync with trains default configuration                                      2019-12-15 00:01:47 +02:00
allegroai   fc2842c9a2  Add initial Poetry support                                                  2019-12-15 00:00:55 +02:00
allegroai   e9d3aab115  version bump                                                                2019-11-23 23:39:02 +02:00
allegroai   0ed7b2a0c8  Fix support for shared cache folder between multiple nodes in the cluster  2019-11-23 23:38:36 +02:00
allegroai   bd73be928a  Improve trains-agent config wizard                                          2019-11-23 23:37:41 +02:00
Allegro AI  79babdd149  Update README.md                                                            2019-11-15 23:38:45 +02:00
Allegro AI  02a21ba826  Update README.md                                                            2019-11-15 23:37:01 +02:00
12 changed files with 435 additions and 145 deletions

View File

@@ -1,5 +1,5 @@
# TRAINS Agent
## Deep Learning DevOps For Everyone
## Deep Learning DevOps For Everyone - Now supports all platforms (Linux, macOS, and Windows)
"All the Deep-Learning DevOps your research needs, and then some... Because ain't nobody got time for that"

View File

@@ -229,6 +229,9 @@ sdk {
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True
# Default Task output_uri. If output_uri is not provided to Task.init, default_output_uri will be used instead.
default_output_uri: ""
# Development mode worker
worker {
# Status report period in seconds
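A minimal sketch of how the new default_output_uri fallback is consumed, assuming the trains SDK is installed and the value above is set in trains.conf (project and task names are placeholders):

from trains import Task

# No output_uri argument here, so the default_output_uri configured above
# (if set) determines where task outputs and model snapshots are uploaded.
task = Task.init(project_name='examples', task_name='default-output-uri demo')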

View File

@@ -358,7 +358,7 @@ class ServiceCommandSection(BaseCommandSection):
**locals())
self.exit(message)
message = 'Could not find {} with name "{}"'.format(service.rstrip('s'), name)
message = 'Could not find {} with name/id "{}"'.format(service.rstrip('s'), name)
if not response:
raise NameResolutionError(message)

View File

@@ -1,19 +1,21 @@
from __future__ import print_function
from six.moves import input
from pyhocon import ConfigFactory
from pyhocon import ConfigFactory, ConfigMissingException
from pathlib2 import Path
from six.moves.urllib.parse import urlparse
from trains_agent.backend_api.session import Session
from trains_agent.backend_api.session.defs import ENV_HOST
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
description = """
Please create new credentials using the web app: {}/profile
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile)
In the profile page, press "Create new credentials", then press "Copy to clipboard".
Paste credentials here: """
Paste copied configuration here:
"""
def_host = 'http://localhost:8080'
try:
@@ -38,20 +40,39 @@ def main():
print('Leaving setup, feel free to edit the configuration file.')
return
print(host_description)
web_host = input_url('Web Application Host', '')
parsed_host = verify_url(web_host)
print(description, end='')
sentinel = ''
parse_input = '\n'.join(iter(input, sentinel))
credentials = None
api_host = None
web_server = None
# noinspection PyBroadException
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
# Take the credentials in raw form or from api section
credentials = get_parsed_field(parsed, ["credentials"])
api_host = get_parsed_field(parsed, ["api_server", "host"])
web_server = get_parsed_field(parsed, ["web_server"])
except Exception:
credentials = credentials or None
api_host = api_host or None
web_server = web_server or None
if parsed_host.port == 8008:
print('Port 8008 is the api port. Replacing 8008 with 8080 for Web application')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
elif parsed_host.port == 8080:
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
elif parsed_host.netloc.startswith('demoapp.'):
while not credentials or set(credentials) != {"access_key", "secret_key"}:
print('Could not parse credentials, please try entering them manually.')
credentials = read_manual_credentials()
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'][0:4] + "***"))
if api_host:
api_host = input_url('API Host', api_host)
else:
print(host_description)
api_host = input_url('API Host', '')
parsed_host = verify_url(api_host)
if parsed_host.netloc.startswith('demoapp.'):
# this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
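The port heuristic above maps one entered address onto the three TRAINS services: 8008 for the API server, 8080 for the web application, and 8081 for the file store. A standalone sketch of the same derivation (the helper name is hypothetical):

from six.moves.urllib.parse import urlparse

def derive_hosts(url):  # hypothetical helper mirroring the wizard logic
    p = urlparse(url)
    base = p.scheme + "://" + p.netloc
    if p.port == 8008:  # user entered the API address
        return (base + p.path,
                base.replace(':8008', ':8080', 1) + p.path,
                base.replace(':8008', ':8081', 1) + p.path)
    if p.port == 8080:  # user entered the web app address
        return (base.replace(':8080', ':8008', 1) + p.path,
                base + p.path,
                base.replace(':8080', ':8081', 1) + p.path)
    return None, None, None  # fall through to the manual prompts

# derive_hosts('http://localhost:8080') ->
#   ('http://localhost:8008', 'http://localhost:8080', 'http://localhost:8081')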
@@ -73,61 +94,50 @@ def main():
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
elif parsed_host.port == 8008:
print('Port 8008 is the api port. Replacing 8008 with 8080 for Web application')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
elif parsed_host.port == 8080:
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
else:
api_host = ''
web_host = ''
files_host = ''
if not parsed_host.port:
print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
print('Host port not detected, do you wish to use the default 8080 port n/[y]? ', end='')
replace_port = input().lower()
if not replace_port or replace_port == 'y' or replace_port == 'yes':
api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
elif not replace_port or replace_port.lower() == 'n' or replace_port.lower() == 'no':
web_host = input_host_port("Web", parsed_host)
api_host = input_host_port("API", parsed_host)
files_host = input_host_port("Files", parsed_host)
if not api_host:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
api_host = input_url('API Host', api_host)
web_host = input_url('Web Application Host', web_server if web_server else web_host)
files_host = input_url('File Store Host', files_host)
print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
api_host, web_host, files_host))
print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
web_host, api_host, files_host))
while True:
print(description.format(web_host), end='')
parse_input = input()
# check if these are valid credentials
credentials = None
# noinspection PyBroadException
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
credentials = parsed.get("credentials", None)
except Exception:
credentials = None
if not credentials or set(credentials) != {"access_key", "secret_key"}:
print('Could not parse user credentials, try again one after the other.')
credentials = {}
# parse individual
print('Enter user access key: ', end='')
credentials['access_key'] = input()
print('Enter user secret: ', end='')
credentials['secret_key'] = input()
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'], ))
from trains_agent.backend_api.session import Session
# noinspection PyBroadException
try:
print('Verifying credentials ...')
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
print('Credentials verified!')
retry = 1
max_retries = 2
while retry <= max_retries: # Up to 2 tries by the user
if verify_credentials(api_host, credentials):
break
except Exception:
print('Error: could not verify credentials: host={} access={} secret={}'.format(
api_host, credentials['access_key'], credentials['secret_key']))
retry += 1
if retry < max_retries + 1:
credentials = read_manual_credentials()
else:
print('Exiting setup without creating configuration file')
return
# get GIT User/Pass for cloning
print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
@@ -182,18 +192,72 @@ def main():
print('TRAINS-AGENT setup completed successfully.')
def verify_credentials(api_host, credentials):
"""check if the credentials are valid"""
# noinspection PyBroadException
try:
print('Verifying credentials ...')
if api_host:
Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
print('Credentials verified!')
return True
else:
print("Can't verify credentials")
return False
except Exception:
print('Error: could not verify credentials: key={} secret={}'.format(
credentials.get('access_key'), credentials.get('secret_key')))
return False
def get_parsed_field(parsed_config, fields):
"""
Parsed the value from web profile page, 'copy to clipboard' option
:param parsed_config: The parsed value from the web ui
:type parsed_config: Config object
:param fields: list of values to parse, will parse by the list order
:type fields: List[str]
:return: parsed value if found, None else
"""
try:
return parsed_config.get("api").get(fields[0])
except ConfigMissingException: # fallback - parse the field as in older versions of the web UI
if len(fields) == 1:
return parsed_config.get(fields[0])
elif len(fields) == 2:
return parsed_config.get(fields[1])
else:
return None
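A hedged usage example of get_parsed_field() on a pasted "copy to clipboard" snippet; the configuration values below are placeholders:

from pyhocon import ConfigFactory

pasted = '''
api {
    web_server: "http://localhost:8080"
    api_server: "http://localhost:8008"
    credentials {
        access_key: "EXAMPLEKEY"
        secret_key: "EXAMPLESECRET"
    }
}
'''
parsed = ConfigFactory.parse_string(pasted)
credentials = get_parsed_field(parsed, ["credentials"])      # -> {'access_key': ..., 'secret_key': ...}
api_host = get_parsed_field(parsed, ["api_server", "host"])  # falls back to "host" on older formats
web_server = get_parsed_field(parsed, ["web_server"])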
def read_manual_credentials():
print('Enter user access key: ', end='')
access_key = input()
print('Enter user secret: ', end='')
secret_key = input()
return {"access_key": access_key, "secret_key": secret_key}
def input_url(host_type, host=None):
while True:
print('{} configured to: [{}] '.format(host_type, host), end='')
parse_input = input()
if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'):
break
if parse_input and verify_url(parse_input):
host = parse_input
parsed_host = verify_url(parse_input) if parse_input else None
if parse_input and parsed_host:
host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
break
return host
def input_host_port(host_type, parsed_host):
print('Enter port for {} host '.format(host_type), end='')
replace_port = input().lower()
return parsed_host.scheme + "://" + parsed_host.netloc + (':{}'.format(replace_port) if replace_port else '') + \
parsed_host.path
def verify_url(parse_input):
try:
if not parse_input.startswith('http://') and not parse_input.startswith('https://'):

View File

@@ -3,6 +3,7 @@ from __future__ import print_function, division, unicode_literals
import errno
import json
import logging
import os
import os.path
import re
import signal
@@ -16,7 +17,6 @@ from datetime import datetime
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from os import environ, getpid
from tempfile import gettempdir, mkdtemp
from time import sleep, time
from typing import Text, Optional, Any, Tuple
@@ -59,7 +59,7 @@ from trains_agent.helper.base import (
is_conda,
named_temporary_file,
ExecutionInfo,
HOCONEncoder, error)
HOCONEncoder, error, get_python_path)
from trains_agent.helper.console import ensure_text
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.package.conda_api import CondaAPI
@@ -78,7 +78,7 @@ from trains_agent.helper.process import (
Argv,
COMMAND_SUCCESS,
Executable,
get_bash_output, shutdown_docker_process)
get_bash_output, shutdown_docker_process, get_docker_id, commit_docker)
from trains_agent.helper.package.cython_req import CythonRequirement
from trains_agent.helper.repo import clone_repository_cached, RepoInfo, VCS
from trains_agent.helper.resource_monitor import ResourceMonitor
@@ -326,7 +326,7 @@ class Worker(ServiceCommandSection):
extra_url = [extra_url]
# put external pip url before default ones, so we first look for packages there
for e in reversed(extra_url):
PIP_EXTRA_INDICES.insert(0, e)
self._pip_extra_index_url.insert(0, e)
except Exception:
self.log.warning('Failed adding extra-index-url to pip environment: {}'.format(extra_url))
# update pip install command
@@ -339,7 +339,7 @@ class Worker(ServiceCommandSection):
)
self.pip_install_cmd = tuple(pip_install_cmd)
self.worker_id = self._session.config["agent.worker_id"] or "{}:{}".format(
self._session.config["agent.worker_name"], getpid()
self._session.config["agent.worker_name"], os.getpid()
)
self._last_stats = defaultdict(lambda: 0)
self._last_report_timestamp = psutil.time.time()
@@ -355,6 +355,7 @@ class Worker(ServiceCommandSection):
self._docker_image = None
self._docker_arguments = None
self._daemon_foreground = None
self._standalone_mode = None
def _get_requirements_manager(self, os_override=None, base_interpreter=None):
requirements_manager = RequirementsManager(
@@ -425,24 +426,27 @@ class Worker(ServiceCommandSection):
lines=['Running Task {} inside docker: {}\n'.format(task_id, task_docker_cmd)],
level="INFO")
task_docker_cmd = task_docker_cmd.split(' ')
full_docker_cmd = self.docker_image_func(docker_image=task_docker_cmd[0],
docker_arguments=task_docker_cmd[1:])
docker_image = task_docker_cmd[0]
docker_arguments = task_docker_cmd[1:]
else:
self.send_logs(task_id=task_id,
lines=['running Task {} inside default docker image: {} {}\n'.format(
task_id, self._docker_image, self._docker_arguments or '')],
level="INFO")
full_docker_cmd = self.docker_image_func(docker_image=self._docker_image,
docker_arguments=self._docker_arguments)
# Update docker command
try:
docker_cmd = ' '.join([self._docker_image] + self._docker_arguments)
self._session.send_api(
tasks_api.EditRequest(task_id, execution=dict(docker_cmd=docker_cmd), force=True))
except Exception:
pass
docker_image = self._docker_image
docker_arguments = self._docker_arguments
full_docker_cmd[-1] = full_docker_cmd[-1] + 'execute --disable-monitoring --id ' + task_id
# Update docker command
full_docker_cmd = self.docker_image_func(docker_image=docker_image, docker_arguments=docker_arguments)
try:
self._session.send_api(
tasks_api.EditRequest(task_id, force=True, execution=dict(
docker_cmd=' '.join([docker_image] + docker_arguments) if docker_arguments else docker_image)))
except Exception:
pass
full_docker_cmd[-1] = full_docker_cmd[-1] + 'execute --disable-monitoring {} --id {}'.format(
'--standalone-mode' if self._standalone_mode else '', task_id)
cmd = Argv(*full_docker_cmd)
else:
cmd = worker_args.get_argv_for_command("execute") + (
@@ -489,7 +493,7 @@ class Worker(ServiceCommandSection):
safe_remove_file(temp_stdout_name)
safe_remove_file(temp_stderr_name)
if self.docker_image_func:
shutdown_docker_process(docker_cmd_ending='--id {}\'\"'.format(task_id))
shutdown_docker_process(docker_cmd_contains='--id {}\'\"'.format(task_id))
def run_tasks_loop(self, queues, worker_params):
"""
@@ -600,6 +604,8 @@ class Worker(ServiceCommandSection):
# check if we have the latest version
start_check_update_daemon()
self._standalone_mode = kwargs.get('standalone_mode', False)
self.check(**kwargs)
self.log.debug("starting resource monitor thread")
print("Worker \"{}\" - ".format(self.worker_id), end='')
@@ -863,15 +869,21 @@ class Worker(ServiceCommandSection):
def build(
self,
task_id,
target_folder=None,
target=None,
python_version=None,
docker=None,
**_
):
if not task_id:
raise CommandFailedError("Worker build must have valid task id")
if not check_if_command_exists("virtualenv"):
raise CommandFailedError("Worker must have virtualenv installed")
self._session.print_configuration()
if docker is not False and docker is not None:
return self._build_docker(docker, target, task_id)
current_task = self._session.api_client.tasks.get_by_id(task_id)
execution = self.get_execution_info(current_task)
@@ -882,7 +894,7 @@ class Worker(ServiceCommandSection):
requirements = None
# TODO: make sure we pass the correct python_version
venv_folder, requirements_manager = self.install_virtualenv(venv_dir=target_folder,
venv_folder, requirements_manager = self.install_virtualenv(venv_dir=target,
requested_python_version=python_version)
if self._default_pip:
@@ -913,6 +925,72 @@ class Worker(ServiceCommandSection):
return 0
def _build_docker(self, docker, target, task_id):
self.temp_config_path = safe_mkstemp(
suffix=".cfg", prefix=".trains_agent.", text=True, name_only=True
)
if not target:
ValueError("--target container name must be provided for docker build")
temp_config, docker_image_func = self.get_docker_config_cmd(docker)
self.dump_config(temp_config)
self.docker_image_func = docker_image_func
try:
response = get_task(self._session, task_id, only_fields=["execution.docker_cmd"])
task_docker_cmd = response.execution.docker_cmd
task_docker_cmd = task_docker_cmd.strip() if task_docker_cmd else None
except Exception:
task_docker_cmd = None
if task_docker_cmd:
print('Building Task {} inside docker: {}\n'.format(task_id, task_docker_cmd))
task_docker_cmd = task_docker_cmd.split(' ')
full_docker_cmd = self.docker_image_func(docker_image=task_docker_cmd[0],
docker_arguments=task_docker_cmd[1:])
else:
print('Building Task {} inside default docker image: {} {}\n'.format(
task_id, self._docker_image, self._docker_arguments or ''))
full_docker_cmd = self.docker_image_func(docker_image=self._docker_image,
docker_arguments=self._docker_arguments)
end_of_build_marker = "build.done=true"
docker_cmd_suffix = ' build --id {} ; ' \
'echo "" >> /root/trains.conf ; ' \
'echo {} >> /root/trains.conf ; ' \
'bash'.format(task_id, end_of_build_marker)
full_docker_cmd[-1] = full_docker_cmd[-1] + docker_cmd_suffix
cmd = Argv(*full_docker_cmd)
# we will be checking the configuration file for changes
temp_config = Path(self.temp_config_path)
base_time_stamp = temp_config.stat().st_mtime
# start the docker
print('Starting docker build')
cmd.call_subprocess(subprocess.Popen)
# now we need to wait until the marker line shows up in our configuration file.
while True:
while temp_config.stat().st_mtime == base_time_stamp:
sleep(5.0)
with open(temp_config.as_posix()) as f:
lines = [l.strip() for l in f.readlines()]
if end_of_build_marker in lines:
break
base_time_stamp = temp_config.stat().st_mtime
print('\nDocker build done')
# get the docker id.
docker_id = get_docker_id(docker_cmd_contains='--id {} '.format(task_id))
if not docker_id:
print("Error: cannot locate docker for storage")
return
print('Committing docker container to: {}'.format(target))
print(commit_docker(container_name=target, docker_id=docker_id))
shutdown_docker_process(docker_id=docker_id)
return
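The build flow above mounts the temporary trains.conf into the container, runs the in-container build, and treats the 'build.done=true' line appended to that mounted file as a completion signal before committing the container. A minimal standalone sketch of the same sentinel-file pattern, under those assumptions:

from pathlib2 import Path
from time import sleep

def wait_for_marker(config_path, marker, poll=5.0):
    # block until the container appends the marker line to the mounted file
    cfg = Path(config_path)
    stamp = cfg.stat().st_mtime
    while True:
        while cfg.stat().st_mtime == stamp:
            sleep(poll)
        with open(cfg.as_posix()) as f:
            if marker in (line.strip() for line in f):
                return
        stamp = cfg.stat().st_mtime

# wait_for_marker('/tmp/.trains_agent.xyz.cfg', 'build.done=true')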
@resolve_names
def execute(
self,
@@ -921,13 +999,33 @@ class Worker(ServiceCommandSection):
optimization=0,
disable_monitoring=False,
full_monitoring=False,
require_queue=False,
log_file=None,
standalone_mode=None,
**_
):
if not task_id:
raise CommandFailedError("Worker execute must have valid task id")
if not check_if_command_exists("virtualenv"):
raise CommandFailedError("Worker must have virtualenv installed")
try:
current_task = self._session.api_client.tasks.get_by_id(task_id)
if not current_task.id:
pass
except Exception:
raise ValueError("Could not find task id={}".format(task_id))
# make sure this task is not stuck in an execution queue, it shouldn't have been, but just in case.
try:
res = self._session.api_client.tasks.dequeue(task=current_task.id)
if require_queue and res.meta.result_code != 200:
raise ValueError("Execution required enqueued task, "
"but task id={} is not queued.".format(current_task.id))
except Exception:
if require_queue:
raise
if full_monitoring:
worker_params = WorkerParams(
log_level=log_level,
@@ -942,13 +1040,8 @@ class Worker(ServiceCommandSection):
return
self._session.print_configuration()
current_task = self._session.api_client.tasks.get_by_id(task_id)
try:
if not current_task.id:
pass
except Exception:
raise ValueError("Could not find task id={}".format(task_id))
# now mark the task as started
self._session.api_client.tasks.started(
task=current_task.id,
status_reason="worker started execution",
@@ -966,12 +1059,13 @@ class Worker(ServiceCommandSection):
except AttributeError:
requirements = None
venv_folder, requirements_manager = self.install_virtualenv()
venv_folder, requirements_manager = self.install_virtualenv(standalone_mode=standalone_mode)
if self._default_pip:
self.package_api.install_packages(*self._default_pip)
if not standalone_mode:
if self._default_pip:
self.package_api.install_packages(*self._default_pip)
print("\n")
print("\n")
directory, vcs, repo_info = self.get_repo_info(
execution, current_task, venv_folder
@@ -979,12 +1073,14 @@ class Worker(ServiceCommandSection):
print("\n")
self.install_requirements(
execution,
repo_info,
requirements_manager=requirements_manager,
cached_requirements=requirements,
)
if not standalone_mode:
self.install_requirements(
execution,
repo_info,
requirements_manager=requirements_manager,
cached_requirements=requirements,
)
# do not update the task packages if we are using conda,
# it will most likely make the task environment unreproducible
freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None)
@@ -1016,7 +1112,7 @@ class Worker(ServiceCommandSection):
"log_to_backend": "0",
"config_file": self._session.config_file, # The config file is the tmp file that trains_agent created
}
environ.update(
os.environ.update(
{
sdk_key: str(value)
for key, value in sdk_env.items()
@@ -1027,6 +1123,11 @@ class Worker(ServiceCommandSection):
if repo_info:
self._update_commit_id(task_id, execution, repo_info)
# Add the script CWD to the python path
python_path = get_python_path(script_dir, execution.entry_point, self.package_api)
if python_path:
os.environ['PYTHONPATH'] = python_path
print("Starting Task Execution:\n".format(task_id))
exit_code = -1
try:
@@ -1155,7 +1256,8 @@ class Worker(ServiceCommandSection):
)
except CommandFailedError:
raise
except Exception:
except Exception as ex:
print('Repository cloning failed: {}'.format(ex))
task.failed(
status_reason="failed cloning repository",
status_message=self._task_status_change_message,
@@ -1317,8 +1419,8 @@ class Worker(ServiceCommandSection):
self.package_api.load_requirements(cached_requirements)
except Exception as e:
self.log_traceback(e)
self.error("Could not install task requirements! Trying to install requirements from repository")
cached_requirements_failed = True
raise ValueError("Could not install task requirements!")
else:
self.log("task requirements installation passed")
return
@@ -1484,7 +1586,7 @@ class Worker(ServiceCommandSection):
)
)
def install_virtualenv(self, venv_dir=None, requested_python_version=None):
def install_virtualenv(self, venv_dir=None, requested_python_version=None, standalone_mode=False):
# type: (str, str, bool) -> Tuple[Path, RequirementsManager]
"""
Install a new python virtual environment, removing the old one if exists
@@ -1501,7 +1603,7 @@ class Worker(ServiceCommandSection):
self._session.config.put("agent.default_python", executable_version)
self._session.config.put("agent.python_binary", executable_name)
first_time = (
first_time = not standalone_mode and (
is_windows_platform()
or self.is_conda
or not venv_dir.is_dir()
@@ -1531,6 +1633,10 @@ class Worker(ServiceCommandSection):
if first_time:
self.package_api.remove()
self.package_api.create()
elif standalone_mode:
# conda with standalone mode
get_conda = partial(CondaAPI, **package_manager_params)
self.package_api = get_conda()
else:
get_conda = partial(CondaAPI, **package_manager_params)
@@ -1577,7 +1683,8 @@ class Worker(ServiceCommandSection):
args.update(kwargs)
return self._get_docker_cmd(**args)
docker_image = str(self._session.config.get("agent.default_docker.image", "nvidia/cuda")) \
docker_image = str(os.environ.get("TRAINS_DOCKER_IMAGE") or os.environ.get("ALG_DOCKER_IMAGE") or
self._session.config.get("agent.default_docker.image", "nvidia/cuda")) \
if not docker_args else docker_args[0]
docker_arguments = docker_image.split(' ') if docker_image else []
if len(docker_arguments) > 1:
@@ -1590,8 +1697,8 @@ class Worker(ServiceCommandSection):
python_version = '3'
if not python_version.startswith('python'):
python_version = 'python'+python_version
print("Running in Docker mode (v19.03 and above) - using default docker image: {} running {}\n".format(
docker_image, python_version))
print("Running in Docker {} mode (v19.03 and above) - using default docker image: {} running {}\n".format(
'*standalone*' if self._standalone_mode else '', docker_image, python_version))
temp_config = self._session.config.copy()
mounted_cache_dir = '/root/.trains/cache'
mounted_pip_dl_dir = '/root/.trains/pip-download-cache'
@@ -1647,7 +1754,8 @@ class Worker(ServiceCommandSection):
host_ssh_cache=host_ssh_cache,
host_cache=host_cache, mounted_cache=mounted_cache_dir,
host_pip_dl=host_pip_dl, mounted_pip_dl=mounted_pip_dl_dir,
host_vcs_cache=host_vcs_cache, mounted_vcs_cache=mounted_vcs_cache)
host_vcs_cache=host_vcs_cache, mounted_vcs_cache=mounted_vcs_cache,
standalone_mode=self._standalone_mode)
return temp_config, partial(docker_cmd_functor, docker_cmd)
@staticmethod
@@ -1658,7 +1766,7 @@ class Worker(ServiceCommandSection):
host_ssh_cache,
host_cache, mounted_cache,
host_pip_dl, mounted_pip_dl,
host_vcs_cache, mounted_vcs_cache):
host_vcs_cache, mounted_vcs_cache, standalone_mode=False):
docker = 'docker'
base_cmd = [docker, 'run', '-t']
@@ -1680,23 +1788,41 @@ class Worker(ServiceCommandSection):
if host_ssh_cache:
base_cmd += ['-v', host_ssh_cache+':/root/.ssh', ]
# if we are running an RC version, install the same version in the docker,
# because the default (latest) will be a release version (not an RC)
specify_version = ''
try:
from trains_agent.version import __version__
_version_parts = __version__.split('.')
if 'rc' in _version_parts[-1].lower() or 'rc' in _version_parts[-2].lower():
specify_version = '=={}'.format(__version__)
except:
pass
if standalone_mode:
update_scheme = ""
else:
update_scheme = \
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean ; " \
"chown -R root /root/.cache/pip ; " \
"apt-get update ; " \
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0 {python_single_digit}-pip ; " \
"{python} -m pip install -U pip ; " \
"{python} -m pip install -U trains-agent{specify_version} ; ".format(
python_single_digit=python_version.split('.')[0],
python=python_version, specify_version=specify_version)
base_cmd += [
'-v', conf_file+':/root/trains.conf',
'-v', host_apt_cache+':/var/cache/apt/archives',
'-v', host_pip_cache+':/root/.cache/pip',
'-v', host_pip_dl+':'+mounted_pip_dl,
'-v', host_cache+':'+mounted_cache,
'-v', host_vcs_cache+':'+mounted_vcs_cache,
'--rm', docker_image, 'bash', '-c',
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean ; "
"chown -R root /root/.cache/pip ; "
"apt-get update ; "
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0 {python_single_digit}-pip ; "
"{python} -m pip install -U pip ; "
"{python} -m pip install -U trains-agent ; "
"NVIDIA_VISIBLE_DEVICES=all CUDA_VISIBLE_DEVICES= {python} -u -m trains_agent ".format(
python_single_digit=python_version.split('.')[0],
python=python_version)]
'-v', conf_file+':/root/trains.conf',
'-v', host_apt_cache+':/var/cache/apt/archives',
'-v', host_pip_cache+':/root/.cache/pip',
'-v', host_pip_dl+':'+mounted_pip_dl,
'-v', host_cache+':'+mounted_cache,
'-v', host_vcs_cache+':'+mounted_vcs_cache,
'--rm', docker_image, 'bash', '-c',
update_scheme +
"NVIDIA_VISIBLE_DEVICES=all {python} -u -m trains_agent ".format(python=python_version)
]
return base_cmd
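For reference, an illustrative reconstruction of the argv this builds for a default, non-standalone run; all host-side cache paths and the task id are placeholders:

# illustrative result of the docker command assembly above
base_cmd = [
    'docker', 'run', '-t',
    '-v', '/tmp/.trains_agent.abc.cfg:/root/trains.conf',
    '-v', '/home/user/.trains/apt-cache:/var/cache/apt/archives',
    '-v', '/home/user/.trains/pip-cache:/root/.cache/pip',
    '-v', '/home/user/.trains/pip-download-cache:/root/.trains/pip-download-cache',
    '-v', '/home/user/.trains/cache:/root/.trains/cache',
    '-v', '/home/user/.trains/vcs-cache:/root/.trains/vcs-cache',
    '--rm', 'nvidia/cuda', 'bash', '-c',
    # update_scheme (the apt-get/pip bootstrap of trains-agent) is prepended
    # here unless standalone mode is set; the worker appends the execute part
    'NVIDIA_VISIBLE_DEVICES=all python3 -u -m trains_agent '
    'execute --disable-monitoring --id <task-id>',
]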

View File

@@ -176,6 +176,25 @@ def safe_remove_file(filename, error_message=None):
print(error_message)
def get_python_path(script_dir, entry_point, package_api):
try:
python_path_sep = ';' if is_windows_platform() else ':'
python_path_cmd = package_api.get_python_command(
["-c", "import sys; print('{}'.join(sys.path))".format(python_path_sep)])
org_python_path = python_path_cmd.get_output(cwd=script_dir)
# Add path of the script directory and executable directory
python_path = '{}{python_path_sep}{}{python_path_sep}'.format(
Path(script_dir).absolute().as_posix(),
(Path(script_dir) / Path(entry_point)).parent.absolute().as_posix(),
python_path_sep=python_path_sep)
if is_windows_platform():
return python_path.replace('/', '\\') + org_python_path
return python_path + org_python_path
except Exception:
return None
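A short usage sketch showing how execute() consumes this helper (paths are illustrative; package_api is the worker's active package manager):

import os

python_path = get_python_path('/root/.trains/vcs-cache/my_repo', 'src/train.py', package_api)
if python_path:
    # script dir + entry-point dir prepended to the interpreter's own sys.path
    os.environ['PYTHONPATH'] = python_path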
class Singleton(ABCMeta):
_instances = {}

View File

@@ -96,3 +96,15 @@ class PoetryAPI(object):
def get_python_command(self, extra):
return Argv("poetry", "run", "python", *extra)
def upgrade_pip(self, *args, **kwargs):
pass
def set_selected_package_manager(self, *args, **kwargs):
pass
def out_of_scope_install_package(self, *args, **kwargs):
pass
def install_from_file(self, *args, **kwargs):
pass

View File

@@ -59,17 +59,47 @@ def kill_all_child_processes(pid=None):
parent.kill()
def shutdown_docker_process(docker_cmd_ending):
def get_docker_id(docker_cmd_contains):
try:
containers_running = get_bash_output(cmd='docker ps --no-trunc --format \"{{.ID}}: {{.Command}}\"')
for docker_line in containers_running.split('\n'):
parts = docker_line.split(':')
if parts[-1].endswith(docker_cmd_ending):
# we found our docker, stop it
get_bash_output(cmd='docker stop -t 1 {}'.format(parts[0]))
return
if docker_cmd_contains in parts[-1]:
# we found our docker, return it
return parts[0]
except Exception:
pass
return None
def shutdown_docker_process(docker_cmd_contains=None, docker_id=None):
try:
if not docker_id:
docker_id = get_docker_id(docker_cmd_contains=docker_cmd_contains)
if docker_id:
# we found our docker, stop it
get_bash_output(cmd='docker stop -t 1 {}'.format(docker_id))
except Exception:
pass
def commit_docker(container_name, docker_cmd_contains=None, docker_id=None):
try:
if not docker_id:
docker_id = get_docker_id(docker_cmd_contains=docker_cmd_contains)
if not docker_id:
print("Failed locating requested docker")
return False
if docker_id:
# we found our docker, commit it
output = get_bash_output(cmd='docker commit {} {}'.format(docker_id, container_name))
return output
except Exception:
pass
print("Failed storing requested docker")
return False
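A short usage sketch tying the three helpers together, mirroring what the new build command does once the in-container build finishes (the task id and image name are placeholders):

docker_id = get_docker_id(docker_cmd_contains='--id <task-id> ')
if docker_id:
    # snapshot the prepared container as a reusable image, then stop it
    print(commit_docker(container_name='my-prepared-image', docker_id=docker_id))
    shutdown_docker_process(docker_id=docker_id)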
def check_if_command_exists(cmd):

View File

@@ -263,8 +263,9 @@ class VCS(object):
"""
self._set_ssh_url()
clone_command = ("clone", self.url_with_auth, self.location) + self.clone_flags
if branch:
clone_command += ("-b", branch)
# clone all branches regardless of which branch we later want to checkout
# if branch:
# clone_command += ("-b", branch)
if self.session.debug_mode:
self.call(*clone_command)
return
@@ -453,13 +454,13 @@ class Git(VCS):
)
def pull(self):
self.call("fetch", "origin", cwd=self.location)
self.call("fetch", "--all", cwd=self.location)
info_commands = dict(
url=Argv("git", "remote", "get-url", "origin"),
branch=Argv("git", "rev-parse", "--abbrev-ref", "HEAD"),
commit=Argv("git", "rev-parse", "HEAD"),
root=Argv("git", "rev-parse", "--show-toplevel"),
url=Argv(executable_name, "ls-remote", "--get-url", "origin"),
branch=Argv(executable_name, "rev-parse", "--abbrev-ref", "HEAD"),
commit=Argv(executable_name, "rev-parse", "HEAD"),
root=Argv(executable_name, "rev-parse", "--show-toplevel"),
)
patch_base = ("apply",)
@@ -493,10 +494,10 @@ class Hg(VCS):
)
info_commands = dict(
url=Argv("hg", "paths", "--verbose"),
branch=Argv("hg", "--debug", "id", "-b"),
commit=Argv("hg", "--debug", "id", "-i"),
root=Argv("hg", "root"),
url=Argv(executable_name, "paths", "--verbose"),
branch=Argv(executable_name, "--debug", "id", "-b"),
commit=Argv(executable_name, "--debug", "id", "-i"),
root=Argv(executable_name, "root"),
)
@@ -537,8 +538,6 @@ def clone_repository_cached(session, execution, destination):
vcs.clone() # branch=execution.branch)
vcs.pull()
vcs.checkout()
rm_tree(destination)
shutil.copytree(Text(cached_repo_path), Text(clone_folder))
if not clone_folder.is_dir():
@@ -548,6 +547,10 @@ def clone_repository_cached(session, execution, destination):
)
)
# checkout in the newly copied destination
vcs.location = Text(clone_folder)
vcs.checkout()
repo_info = vcs.get_repository_copy_info(clone_folder)
# make sure we have no user/pass in the returned repository structure
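Net effect of this change: the shared cache clone always holds all branches and is refreshed with "git fetch --all", while branch checkout now happens only in the per-task copy. A condensed sketch of the resulting sequence inside clone_repository_cached (names taken from the surrounding code):

vcs.clone()        # cache clone now fetches all branches
vcs.pull()         # fetch --all keeps the cached repo current
shutil.copytree(Text(cached_repo_path), Text(clone_folder))
vcs.location = Text(clone_folder)
vcs.checkout()     # the requested branch is checked out only in the task's copy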

View File

@@ -49,7 +49,7 @@ DAEMON_ARGS = dict({
'--docker': {
'help': 'Run execution task inside a docker (v19.03 and above). Optional args <image> <arguments> or '
'specify default docker image in agent.default_docker.image / agent.default_docker.arguments'
'set NVIDIA_VISIBLE_DEVICES to limit gpu visibility for docker',
'. Use --gpus/--cpu-only (or set NVIDIA_VISIBLE_DEVICES) to limit gpu visibility for docker',
'nargs': '*',
'default': False,
},
@@ -60,6 +60,11 @@ DAEMON_ARGS = dict({
'dest': 'queues',
'type': foreign_object_id('queues'),
},
'--standalone-mode': {
'help': 'Do not use any network connections; assume everything is pre-installed',
'action': 'store_true',
},
}, **WORKER_ARGS)
COMMANDS = {
@@ -83,6 +88,15 @@ COMMANDS = {
'help': 'Full environment setup log & task logging & monitoring (stdout is still visible)',
'action': 'store_true',
},
'--require-queue': {
'help': 'If the specified task is not queued (in any Queue), the execution will fail. '
'(Used for 3rd party scheduler integration, e.g. K8s, SLURM, etc.)',
'action': 'store_true',
},
'--standalone-mode': {
'help': 'Do not use any network connections; assume everything is pre-installed',
'action': 'store_true',
},
}, **WORKER_ARGS),
},
'build': {
@@ -96,8 +110,25 @@ COMMANDS = {
'dest': 'task_id',
'type': foreign_object_id('tasks'),
},
'--target-folder': {
'help': 'Where to build the task\'s virtual environment and source code',
'--target': {
'help': 'Where to build the task\'s virtual environment and source code. '
'When used with --docker, target docker image name to create',
},
'--docker': {
'help': 'Build the experiment inside a docker (v19.03 and above). Optional args <image> <arguments> or '
'specify default docker image in agent.default_docker.image / agent.default_docker.arguments'
'. Use --gpus/--cpu-only (or set NVIDIA_VISIBLE_DEVICES) to limit gpu visibility for docker',
'nargs': '*',
'default': False,
},
'--gpus': {
'help': 'Specify active GPUs for the docker to use. '
'Equivalent to setting NVIDIA_VISIBLE_DEVICES. '
'Examples: --gpus 0 or --gpus 0,1,2 or --gpus all',
},
'--cpu-only': {
'help': 'Disable GPU access (cpu only) for the docker',
'action': 'store_true',
},
'--python-version': {
'help': 'Virtual environment python version to use',
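These tables are declarative argparse specs: each key is a flag and its dict is passed through as add_argument() keywords. A hedged sketch of how one of the new build flags maps onto argparse (parser wiring simplified):

import argparse

parser = argparse.ArgumentParser(prog='trains-agent build')
# equivalent of the '--cpu-only' entry above
parser.add_argument('--cpu-only', action='store_true',
                    help='Disable GPU access (cpu only) for the docker')
args = parser.parse_args(['--cpu-only'])
assert args.cpu_only is True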

View File

@@ -99,10 +99,12 @@ class Session(_Session):
if not self.config.get('api.host', None) and self.config.get('api.api_server', None):
self.config['api']['host'] = self.config.get('api.api_server')
# initialize nvidia visibility variable
# initialize nvidia visibility variables
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
if os.environ.get('NVIDIA_VISIBLE_DEVICES') and not os.environ.get('CUDA_VISIBLE_DEVICES'):
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
# do not create CUDA_VISIBLE_DEVICES if it doesn't exist, it breaks TF/PyTorch CUDA detection
# os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
pass
elif os.environ.get('CUDA_VISIBLE_DEVICES') and not os.environ.get('NVIDIA_VISIBLE_DEVICES'):
os.environ['NVIDIA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES')
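The rule encoded above: mirror CUDA_VISIBLE_DEVICES into NVIDIA_VISIBLE_DEVICES when only the former is set, but never create CUDA_VISIBLE_DEVICES from scratch, since introducing it can break TF/PyTorch CUDA detection. A standalone sketch (the helper name is hypothetical):

import os

def sync_gpu_visibility():  # hypothetical helper mirroring the Session logic
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    nvidia = os.environ.get('NVIDIA_VISIBLE_DEVICES')
    cuda = os.environ.get('CUDA_VISIBLE_DEVICES')
    if cuda and not nvidia:
        os.environ['NVIDIA_VISIBLE_DEVICES'] = cuda
    # the reverse (creating CUDA_VISIBLE_DEVICES) is deliberately skipped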

View File

@@ -1 +1 @@
__version__ = '0.12.1'
__version__ = '0.12.2'