Initial release

This commit is contained in:
allegroai
2019-10-25 22:28:44 +03:00
parent 4d808bedc5
commit 224a709f40
101 changed files with 26263 additions and 0 deletions

13
.gitignore vendored Normal file
View File

@@ -0,0 +1,13 @@
# Mac
.DS_Store
# IntelliJ
.idea
# Python
*.pyc
__pycache__
build/
dist/
*.egg-info

235
docs/trains.conf Normal file
View File

@@ -0,0 +1,235 @@
# TRAINS-AGENT configuration file
api {
api_server: https://demoapi.trainsai.io
web_server: https://demoapp.trainsai.io
files_server: https://demofiles.trainsai.io
# Credentials are generated in the webapp, https://demoapp.trainsai.io/profile
# Overridden with os environment: TRAINS_API_ACCESS_KEY / TRAINS_API_SECRET_KEY
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
}
agent {
# Set GIT user/pass credentials
# leave blank for GIT SSH credentials
git_user=""
git_pass=""
# unique name of this worker, if None, created based on hostname:process_id
# Overridden with os environment: TRAINS_WORKER_ID
# worker_id: "trains-agent-machine1:gpu0"
worker_id: ""
# worker name, replaces the hostname when creating a unique name for this worker
# Overridden with os environment: TRAINS_WORKER_NAME
# worker_name: "trains-agent-machine1"
worker_name: ""
# select python package manager:
# currently supported pip and conda
# poetry is used if pip selected and repository contains poetry.lock file
package_manager: {
# supported options: pip, conda
type: pip,
# virtual environment inherits packages from system
system_site_packages: false,
# install with --upgrade
force_upgrade: false,
# additional artifact repositories to use when installing python packages
# extra_index_url: ["https://allegroai.jfrog.io/trainsai/api/pypi/public/simple"]
extra_index_url: []
# additional conda channels to use when installing with conda package manager
conda_channels: ["pytorch", "conda-forge", ]
},
# target folder for virtual environments builds, created when executing experiment
venvs_dir = ~/.trains/venvs-builds
# cached git clone folder
vcs_cache: {
enabled: true,
path: ~/.trains/vcs-cache
},
# use venv-update in order to accelerate python virtual environment building
# Still in beta, turned off by default
venv_update: {
enabled: false,
},
# cached folder for specific python package download (mostly pytorch versions)
pip_download_cache {
enabled: true,
path: ~/.trains/pip-download-cache
},
translate_ssh: true,
# reload configuration file every daemon execution
reload_config: false,
# pip cache folder used mapped into docker, for python package caching
docker_pip_cache = ~/.trains/pip-cache
# apt cache folder used mapped into docker, for ubuntu package caching
docker_apt_cache = ~/.trains/apt-cache
default_docker: {
# default docker image to use when running in docker mode
image: "nvidia/cuda"
# optional arguments to pass to docker image
# arguments: ["--ipc=host"]
}
}
sdk {
# TRAINS - default SDK configuration
storage {
cache {
# Defaults to system temp folder / cache
default_base_dir: "~/.trains/cache"
}
direct_access: [
# Objects matching are considered to be available for direct access, i.e. they will not be downloaded
# or cached, and any download request will return a direct reference.
# Objects are specified in glob format, available for url and content_type.
{ url: "file://*" } # file-urls are always directly referenced
]
}
metrics {
# History size for debug files per metric/variant. For each metric/variant combination with an attached file
# (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
# X files are stored in the upload destination for each metric/variant combination.
file_history_size: 100
# Settings for generated debug images
images {
format: JPEG
quality: 87
subsampling: 0
}
}
network {
metrics {
# Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
# a specific iteration
file_upload_threads: 4
# Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
# being sent for upload
file_upload_starvation_warning_sec: 120
}
iteration {
# Max number of retries when getting frames if the server returned an error (http code 500)
max_retries_on_server_error: 5
# Backoff factory for consecutive retry attempts.
# SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
retry_backoff_factor_sec: 10
}
}
aws {
s3 {
# S3 credentials, used for read/write access by various SDK elements
# default, used for any bucket not specified below
key: ""
secret: ""
region: ""
credentials: [
# specifies key/secret credentials to use when handling s3 urls (read or write)
# {
# bucket: "my-bucket-name"
# key: "my-access-key"
# secret: "my-secret-key"
# },
# {
# # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
# host: "my-minio-host:9000"
# key: "12345678"
# secret: "12345678"
# multipart: false
# secure: false
# }
]
}
boto3 {
pool_connections: 512
max_multipart_concurrency: 16
}
}
google.storage {
# # Default project and credentials file
# # Will be used when no bucket configuration is found
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# # Specific credentials per bucket and sub directory
# credentials = [
# {
# bucket: "my-bucket"
# subdir: "path/in/bucket" # Not required
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# },
# ]
}
azure.storage {
# containers: [
# {
# account_name: "trains"
# account_key: "secret"
# # container_name:
# }
# ]
}
log {
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
null_log_propagate: False
task_log_buffer_capacity: 66
# disable urllib info and lower levels
disable_urllib3_info: True
}
development {
# Development-mode options
# dev task reuse window
task_reuse_time_window_in_hours: 72.0
# Run VCS repository detection asynchronously
vcs_repo_detect_async: True
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff_on_train: True
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True
# Development mode worker
worker {
# Status report period in seconds
report_period_sec: 2
# ping to the server - check connectivity
ping_period_sec: 30
# Log all stdout & stderr
log_stdout: True
}
}
}

6
main.py Normal file
View File

@@ -0,0 +1,6 @@
import sys
from trains_agent import __main__
if __name__ == "__main__":
sys.exit(__main__.main())

23
requirements.txt Normal file
View File

@@ -0,0 +1,23 @@
attrs>=18.0
enum34>=0.9 ; python_version < '3.6'
furl>=2.0.0
future>=0.16.0
humanfriendly>=2.1
jsonmodels>=2.2
jsonschema>=2.6.0
pathlib2>=2.3.0
psutil>=3.4.2
pyhocon>=0.3.38
pyparsing>=2.0.3
python-dateutil>=2.4.2
pyjwt>=1.6.4
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.20.0
requirements_parser>=0.2.0
semantic_version>=2.6.0
six>=1.11.0
tqdm>=4.19.5
typing>=3.6.4
urllib3>=1.21.1
virtualenv>=16

76
setup.py Normal file
View File

@@ -0,0 +1,76 @@
"""
TRAINS - Artificial Intelligence Version Control
TRAINS-AGENT DevOps for machine/deep learning
https://github.com/allegroai/trains-agent
"""
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
from six import exec_
from pathlib2 import Path
here = Path(__file__).resolve().parent
# Get the long description from the README file
long_description = (here / 'README.md').read_text()
def read_version_string():
result = {}
exec_((here / 'trains_agent/version.py').read_text(), result)
return result['__version__']
version = read_version_string()
requirements = (here / 'requirements.txt').read_text().splitlines()
setup(
name='trains_agent',
version=version,
description='Trains-Agent DevOps for deep learning (DevOps for TRAINS)',
long_description=long_description,
# The project's main homepage.
url='https://github.com/allegroai/trains-agent',
author='Allegroai',
author_email='trains@allegro.ai',
license='Apache License 2.0',
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: System Administrators',
'Intended Audience :: Science/Research',
'Operating System :: POSIX :: Linux',
'Operating System :: MacOS :: MacOS X',
'Operating System :: Microsoft',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Image Recognition',
'Topic :: System :: Logging',
'Topic :: System :: Monitoring',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'License :: OSI Approved :: Apache Software License',
],
keywords='trains devops machine deep learning agent automation hpc cluster',
packages=find_packages(exclude=['contrib', 'docs', 'data', 'examples', 'tests*']),
install_requires=requirements,
extras_require={
},
package_data={
'trains_agent': ['backend_api/config/default/*.conf']
},
include_package_data=True,
# To provide executable scripts, use entry points in preference to the
# "scripts" keyword. Entry points provide cross-platform support and allow
# pip to create the appropriate form of executable for the target platform.
entry_points={
'console_scripts': ['trains-agent=trains_agent.__main__:main'],
},
)

0
tests/__init__.py Normal file
View File

51
tests/conftest.py Normal file
View File

@@ -0,0 +1,51 @@
from argparse import Namespace
from contextlib import contextmanager
import pytest
import yaml
from pathlib2 import Path
# Repository root: tests/ lives one level below it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
@pytest.fixture(scope='function')
def run_trains_agent(script_runner):
    """Return a callable that executes the trains_agent app and returns the run result.

    Args:
        script_runner (object): a pytest plugin for testing python scripts
            installed via console_scripts entry point of setup.py.
            It can run the scripts under test in a separate process or using
            the interpreter that's running the test suite. The former mode
            ensures the script runs in an environment identical to normal
            execution, whereas the latter allows much quicker test runs during
            development while simulating the real runs as much as possible.
            For more details: https://pypi.python.org/pypi/pytest-console-scripts

    Returns:
        callable: runs trains_agent.sh with the given arguments and returns
        the script_runner result (which carries the stdout output).
    """
    def _method(*args):
        launcher = str(PROJECT_ROOT / "trains_agent.sh")
        return script_runner.run(launcher, *args)

    return _method
@pytest.fixture(scope='function')
def trains_agentyaml(tmpdir):
    """
    Fixture returning a context manager over a YAML config template.

    The yielded namespace carries the parsed template (``code``) and the
    target path (``file``); the (possibly mutated) template is dumped to
    that path when the caller's ``with`` block exits, so tests can edit
    ``code`` before it is persisted.
    """
    @contextmanager
    def _method(template_file):
        file = tmpdir.join("trains_agent.yaml")
        with (PROJECT_ROOT / "tests/templates" / template_file).open() as f:
            # NOTE(review): yaml.load without an explicit Loader is deprecated
            # (PyYAML >= 5.1) and unsafe on untrusted input; these are local
            # test templates, but consider yaml.safe_load.
            code = yaml.load(f)
        yield Namespace(code=code, file=file.strpath)
        # Written only after the caller's body ran — see fixture docstring.
        file.write(yaml.dump(code))
    return _method

View File

View File

@@ -0,0 +1,35 @@
import pytest

from trains_agent.helper.repo import VCS


# `None` in the table means "expected to pass through unchanged".
@pytest.mark.parametrize(
    ["url", "expected"],
    (
        ("a", None),
        ("foo://a/b", None),
        ("foo://a/b/", None),
        ("https://a/b/", None),
        ("https://example.com/a/b", None),
        ("https://example.com/a/b/", None),
        ("ftp://example.com/a/b", None),
        ("ftp://example.com/a/b/", None),
        ("github.com:foo/bar.git", "https://github.com/foo/bar.git"),
        ("git@github.com:foo/bar.git", "https://github.com/foo/bar.git"),
        ("bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
        ("hg@bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
        ("ssh://bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
        ("ssh://git@github.com/foo/bar.git", "https://github.com/foo/bar.git"),
        ("ssh://user@github.com/foo/bar.git", "https://user@github.com/foo/bar.git"),
        ("ssh://git:password@github.com/foo/bar.git", "https://git:password@github.com/foo/bar.git"),
        ("ssh://user:password@github.com/foo/bar.git", "https://user:password@github.com/foo/bar.git"),
        ("ssh://hg@bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
        ("ssh://user@bitbucket.org/foo/bar.git", "https://user@bitbucket.org/foo/bar.git"),
        ("ssh://hg:password@bitbucket.org/foo/bar.git", "https://hg:password@bitbucket.org/foo/bar.git"),
        ("ssh://user:password@bitbucket.org/foo/bar.git", "https://user:password@bitbucket.org/foo/bar.git"),
    ),
)
def test(url, expected):
    """SSH-style repo URLs are rewritten to HTTPS; everything else is untouched."""
    resolved = VCS.resolve_ssh_url(url)
    assert resolved == (expected or url)

View File

@@ -0,0 +1,76 @@
import pytest
from furl import furl

from trains_agent.helper.package.translator import RequirementsTranslator


# Cartesian product of URL components — every combination must be recognized
# as a supported (download-cacheable) link.
@pytest.mark.parametrize(
    "line",
    (
        furl()
        .set(
            scheme=scheme,
            host=host,
            path=path,
            query=query,
            fragment=fragment,
            port=port,
            username=username,
            password=password,
        )
        .url
        for scheme in ("http", "https", "ftp")
        for host in ("a", "example.com")
        for path in (None, "/", "a", "a/", "a/b", "a/b/", "a b", "a b ")
        for query in (None, "foo", "foo=3", "foo=3&bar")
        for fragment in (None, "foo")
        for port in (None, 1337)
        for username in (None, "", "user")
        for password in (None, "", "password")
    ),
)
def test_supported(line):
    """Any well-formed http/https/ftp URL must be classified as supported."""
    # sanity: furl really produced an absolute URL
    assert "://" in line
    assert RequirementsTranslator.is_supported_link(line)


# Plain package names, non-http(s)/ftp schemes, editable installs, and
# leading-whitespace variants must NOT be treated as supported links.
@pytest.mark.parametrize(
    "line",
    [
        "pytorch",
        "foo",
        "foo1",
        "bar",
        "bar1",
        "foo-bar",
        "foo-bar1",
        "foo-bar-1",
        "foo_bar",
        "foo_bar1",
        "foo_bar_1",
        " https://a",
        " https://a/b",
        " http://a",
        " http://a/b",
        " ftp://a/b",
        "file://a/b",
        "ssh://a/b",
        "foo://a/b",
        "git//a/b",
        "git+https://a/b",
        "https+git://a/b",
        "git+http://a/b",
        "http+git://a/b",
        "",
        " ",
        "-e ",
        "-e x",
        "-e http://a",
        "-e http://a/b",
        "-e https://a",
        "-e https://a/b",
        "-e file://a/b",
    ],
)
def test_not_supported(line):
    """Everything that is not a plain absolute http/https/ftp URL is unsupported."""
    assert not RequirementsTranslator.is_supported_link(line)

View File

@@ -0,0 +1,33 @@
import attr
import pytest
import requests
from furl import furl
import six

from trains_agent.helper.package.pytorch import PytorchRequirement


@attr.s
class PytorchURLWheel(object):
    # One (os, cuda, python, pytorch) -> wheel URL record flattened out of
    # PytorchRequirement.MAP.
    os = attr.ib()
    cuda = attr.ib()
    python = attr.ib()
    pytorch = attr.ib()
    url = attr.ib()


# Flatten the nested MAP, keeping only leaves that are actual URL strings
# (entries may also be dicts or non-URL markers).
wheels = [
    PytorchURLWheel(os=os, cuda=cuda, python=python, pytorch=pytorch_version, url=url)
    for os, os_d in PytorchRequirement.MAP.items()
    for cuda, cuda_d in os_d.items()
    if isinstance(cuda_d, dict)
    for python, python_d in cuda_d.items()
    if isinstance(python_d, dict)
    for pytorch_version, url in python_d.items()
    if isinstance(url, six.string_types) and furl(url).scheme
]


@pytest.mark.parametrize('wheel', wheels, ids=[','.join(map(str, attr.astuple(wheel))) for wheel in wheels])
def test_map(wheel):
    """Every wheel URL in the map must be reachable (HEAD request only)."""
    assert requests.head(wheel.url).ok

View File

@@ -0,0 +1,75 @@
import pytest

from trains_agent.helper.repo import Git

# Sentinel meaning "output URL is expected to equal the input URL".
NO_CHANGE = object()


def param(url, expected, user=False, password=False):
    """
    Helper function for creating parametrization arguments.

    :param url: input url
    :param expected: expected output URL or NO_CHANGE if the same as input URL
    :param user: Add `agent.git_user=USER` to config
    :param password: Add `agent.git_pass=PASSWORD` to config
    """
    # Fix: the docstring previously said `agent.git_password`, but the config
    # key actually consumed below is `git_pass`.
    expected_repr = "NO_CHANGE" if expected is NO_CHANGE else None
    user = "USER" if user else None
    password = "PASSWORD" if password else None
    return pytest.param(
        url,
        expected,
        user,
        password,
        id="-".join(filter(None, (url, user, password, expected_repr))),
    )


@pytest.mark.parametrize(
    "url,expected,user,password",
    [
        param("https://bitbucket.org/company/repo", NO_CHANGE),
        param("https://bitbucket.org/company/repo", NO_CHANGE, user=True),
        param("https://bitbucket.org/company/repo", NO_CHANGE, password=True),
        param(
            "https://bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
        ),
        param("https://user@bitbucket.org/company/repo", NO_CHANGE),
        param("https://user@bitbucket.org/company/repo", NO_CHANGE, user=True),
        param("https://user@bitbucket.org/company/repo", NO_CHANGE, password=True),
        param(
            "https://user@bitbucket.org/company/repo",
            "https://USER:PASSWORD@bitbucket.org/company/repo",
            user=True,
            password=True,
        ),
        param("https://user:password@bitbucket.org/company/repo", NO_CHANGE),
        param("https://user:password@bitbucket.org/company/repo", NO_CHANGE, user=True),
        param(
            "https://user:password@bitbucket.org/company/repo", NO_CHANGE, password=True
        ),
        param(
            "https://user:password@bitbucket.org/company/repo",
            NO_CHANGE,
            user=True,
            password=True,
        ),
        param("ssh://git@bitbucket.org/company/repo", NO_CHANGE),
        param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True),
        param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, password=True),
        param(
            "ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
        ),
        param("git@bitbucket.org:company/repo.git", NO_CHANGE),
        param("git@bitbucket.org:company/repo.git", NO_CHANGE, user=True),
        param("git@bitbucket.org:company/repo.git", NO_CHANGE, password=True),
        param(
            "git@bitbucket.org:company/repo.git", NO_CHANGE, user=True, password=True
        ),
    ],
)
def test(url, user, password, expected):
    """Per the table above, only https URLs carrying a bare username are
    rewritten, and only when both user and password are configured."""
    config = {"agent": {"git_user": user, "git_pass": password}}
    result = Git.add_auth(config, url)
    expected = result if expected is NO_CHANGE else expected
    assert result == expected

View File

@@ -0,0 +1,254 @@
"""
Test handling of jupyter notebook tasks.
Logging is enabled in `trains_agent/tests/pytest.ini`. Search for `pytest live logging` for more info.
"""
import logging
import re
import select
import subprocess
import time
from contextlib import contextmanager
from typing import Iterator, ContextManager, Sequence, IO, Text
from uuid import uuid4
from trains_agent.backend_api.services.tasks import Script
from trains_agent.backend_api.session.client import APIClient
from pathlib2 import Path
from pytest import fixture
from trains_agent.helper.process import Argv
# Silence urllib3's per-request logging in the live (log_cli) output.
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
log = logging.getLogger(__name__)
# Minimal payload accepted by tasks.create; individual tests add `script`.
DEFAULT_TASK_ARGS = {"type": "testing", "name": "test", "input": {"view": {}}}
HERE = Path(__file__).resolve().parent
# Seconds to wait for agent output in the fast test cases.
SHORT_TIMEOUT = 30
@fixture(scope="session")
def client():
return APIClient(api_version="2.2")
@contextmanager
def create_task(client, **params):
    """
    Create a task in the backend; force-delete it on context exit.
    """
    log.info("creating new task")
    created = client.tasks.create(**params)
    try:
        yield created
    finally:
        log.info("deleting task, id=%s", created.id)
        created.delete(force=True)
def select_read(file_obj, timeout):
    """Wait up to `timeout` seconds for `file_obj` to become readable;
    return the (possibly empty) list of readable objects."""
    readable, _, _ = select.select([file_obj], [], [], timeout)
    return readable
def run_task(task):
    """Build the trains-agent command line that executes `task` by its id."""
    return Argv("trains_agent", "--debug", "worker", "execute", "--id", task.id)
@contextmanager
def iterate_output(timeout, command):
    # type: (int, Argv) -> ContextManager[Iterator[Text]]
    """
    Run `command` in a subprocess and yield an iterator over its output
    lines. If `timeout` seconds have passed, the iterator ends and the
    process is killed.

    :param timeout: maximum amount of time to wait for command to end
    :param command: command to run
    """
    log.info("running: %s", command)
    process = command.call_subprocess(
        subprocess.Popen, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
    try:
        yield _iterate_output(timeout, process)
    finally:
        status = process.poll()
        if status is None:
            # Still running past the consumer's interest — kill it.
            log.info("killing command %s", command)
            process.kill()
        else:
            log.info("command %s terminated, status: %s", command, status)
def _iterate_output(timeout, process):
# type: (int, subprocess.Popen) -> Iterator[Text]
"""
Return an iterator over process's output lines.
over its output's lines.
If `timeout` seconds have passed, iterator ends and the process is killed.
:param timeout: maximum amount of time to wait for command to end
:param process: process to iterate over its output lines
"""
start = time.time()
exit_loop = [] # type: Sequence[IO]
def loop_helper(file_obj):
# type: (IO) -> Sequence[IO]
diff = timeout - (time.time() - start)
if diff <= 0:
return exit_loop
return select_read(file_obj, timeout=diff)
buffer = ""
for output, in iter(lambda: loop_helper(process.stdout), exit_loop):
try:
added = output.read(1024).decode("utf8")
except EOFError:
if buffer:
yield buffer
return
buffer += added
lines = buffer.split("\n", 1)
while len(lines) > 1:
line, buffer = lines
log.debug("--- %s", line)
yield line
lines = buffer.split("\n", 1)
def search_lines(lines, search_for, error):
    # type: (Iterator[Text], Text, Text) -> None
    """
    Fail test if `search_for` string appears nowhere in `lines`.
    Consumes `lines` up to the point where `search_for` is found.

    :param lines: lines to search in
    :param search_for: string to search lines for
    :param error: error to show if not found
    """
    # any() stops at the first match, preserving the original's consumption
    # semantics (remaining lines stay available to the caller).
    found = any(search_for in line for line in lines)
    assert found, error
def search_lines_pattern(lines, pattern, error):
    # type: (Iterator[Text], Text, Text) -> None
    """
    Like `search_lines` but searches for a regex pattern.

    :param lines: lines to search in
    :param pattern: pattern to search lines for
    :param error: error to show if not found
    """
    matched = any(re.search(pattern, line) for line in lines)
    assert matched, error
def test_entry_point_warning(client):
    """
    non-empty script.entry_point should output a warning
    """
    with create_task(
        client,
        script=Script(diff="print('hello')", entry_point="foo.py", repository=""),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
        # Consistency: use the shared helper instead of an inline scan loop.
        search_lines(
            output,
            "found non-empty script.entry_point",
            "did not find warning in output",
        )


def test_run_no_dirs(client):
    """
    The arbitrary `code` directory should be selected when there is no `script.repository`
    """
    uuid = uuid4().hex
    script = "print('{}')".format(uuid)
    with create_task(
        client,
        script=Script(diff=script, entry_point="", repository="", working_dir=""),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
        search_lines(
            output,
            "found literal script",
            "task was not recognized as a literal script",
        )
        search_lines_pattern(
            output,
            r"selected execution directory:.*code",
            # fixed grammar in failure message ("did not selected" -> "did not select")
            r"did not select empty `code` dir as execution dir",
        )
        search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))


def test_run_working_dir(client):
    """
    Literal script tasks should respect `working_dir`
    """
    uuid = uuid4().hex
    script = "print('{}')".format(uuid)
    with create_task(
        client,
        script=Script(
            diff=script,
            entry_point="",
            repository="git@bitbucket.org:seematics/roee_test_git.git",
            working_dir="space dir",
        ),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(120, run_task(task)) as output:
        search_lines(
            output,
            "found literal script",
            "task was not recognized as a literal script",
        )
        search_lines_pattern(
            output,
            r"selected execution directory:.*space dir",
            # fixed grammar in failure message ("did not selected" -> "did not select")
            r"did not select working_dir as set in execution_info",
        )
        search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))


def test_regular_task(client):
    """
    Test a plain old task
    """
    with create_task(
        client,
        script=Script(
            entry_point="noop.py",
            repository="git@bitbucket.org:seematics/roee_test_git.git",
        ),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
        message = "Done"
        search_lines(
            output, message, "did not reach dummy output message {}".format(message)
        )


def test_regular_task_nested(client):
    """
    `entry_point` should be relative to `working_dir` if present
    """
    with create_task(
        client,
        script=Script(
            entry_point="noop_nested.py",
            working_dir="no_reqs",
            repository="git@bitbucket.org:seematics/roee_test_git.git",
        ),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
        message = "Done"
        search_lines(
            output, message, "did not reach dummy output message {}".format(message)
        )

11
tests/pytest.ini Normal file
View File

@@ -0,0 +1,11 @@
[pytest]
python_files=*.py
script_launch_mode = subprocess
env =
PYTHONIOENCODING=utf-8
log_cli = 1
log_cli_level = DEBUG
log_cli_format = [%(name)s] [%(asctime)s] [%(levelname)8s] %(message)s
log_print = False
log_cli_date_format=%Y-%m-%d %H:%M:%S
addopts = -p no:warnings

2
tests/requirements.txt Normal file
View File

@@ -0,0 +1,2 @@
pytest
-r ../requirements.txt

View File

View File

@@ -0,0 +1,13 @@
import os
def pytest_addoption(parser):
    """Register the --cpu flag (run the suite as if no CUDA stack is installed;
    see pytest_configure)."""
    parser.addoption('--cpu', action='store_true')
def pytest_configure(config):
    """When --cpu was given, hide CUDA from the tests: strip cuda entries from
    PATH and blank the CUDA/CUDNN version variables."""
    if not config.option.cpu:
        return
    entries = os.environ['PATH'].split(':')
    os.environ['PATH'] = ':'.join(entry for entry in entries if 'cuda' not in entry)
    os.environ['CUDA_VERSION'] = ''
    os.environ['CUDNN_VERSION'] = ''

View File

@@ -0,0 +1,350 @@
from __future__ import unicode_literals
import re
import subprocess
from itertools import chain
from os import path
from os.path import sep
from sys import platform as sys_platform
import pytest
import requirements
from trains_agent.commands.worker import Worker
from trains_agent.helper.package.pytorch import PytorchRequirement
from trains_agent.helper.package.requirements import RequirementsManager, \
RequirementSubstitution, MarkerRequirement
from trains_agent.helper.process import get_bash_output
from trains_agent.session import Session
# Packages whose requirement lines get a ".post<cuda>.dev<cudnn>" version
# suffix in the legacy substitution logic below.
_cuda_based_packages_hack = ('seematics.caffe', 'lightnet')
def old_get_suffix(session):
    """
    Legacy reference implementation: derive the cuda/cudnn version suffix.

    Returns ``(pkg_suffix_ver, nvcc_ver, cudnn_ver)`` where ``pkg_suffix_ver``
    is ``'.post<cuda>.dev<cudnn>'``. Versions come from the agent config when
    both are set; otherwise they are probed from the local toolchain
    (``nvcc --version`` plus the installed ``cudnn.h`` macros).
    """
    cuda_version = session.config['agent.cuda_version']
    cudnn_version = session.config['agent.cudnn_version']
    if cuda_version and cudnn_version:
        nvcc_ver = cuda_version.strip()
        cudnn_ver = cudnn_version.strip()
    else:
        # Probe nvcc for the cuda version ("release X.Y" -> "XY").
        if sys_platform == 'win32':
            nvcc_ver = subprocess.check_output('nvcc --version'.split()).decode('utf-8'). \
                replace('\r', '').split('\n')
        else:
            nvcc_ver = subprocess.check_output(
                'nvcc --version'.split()).decode('utf-8').split('\n')
        nvcc_ver = [l for l in nvcc_ver if 'release ' in l]
        nvcc_ver = nvcc_ver[0][nvcc_ver[0].find(
            'release ') + len('release '):][:3]
        nvcc_ver = str(int(float(nvcc_ver) * 10))
        # Locate cudnn.h relative to the nvcc binary.
        if sys_platform == 'win32':
            cuda_lib = subprocess.check_output('where nvcc'.split()).decode('utf-8'). \
                replace('\r', '').split('\n')[0]
            cudnn_h = sep.join(cuda_lib.split(
                sep)[:-2] + ['include', 'cudnn.h'])
        else:
            cuda_lib = subprocess.check_output(
                'which nvcc'.split()).decode('utf-8').split('\n')[0]
            cudnn_h = path.join(
                sep, *(cuda_lib.split(sep)[:-2] + ['include', 'cudnn.h']))
        cudnn_mj, cudnn_mi = None, None
        # BUG FIX: the original leaked the header file handle; use a context manager.
        with open(cudnn_h, 'r') as header:
            for l in header:
                if 'CUDNN_MAJOR' in l:
                    cudnn_mj = l.split()[-1]
                if 'CUDNN_MINOR' in l:
                    cudnn_mi = l.split()[-1]
                if cudnn_mj and cudnn_mi:
                    break
        cudnn_ver = cudnn_mj + ('0' if not cudnn_mi else cudnn_mi)
    # build cuda + cudnn version suffix
    # make sure these are integers, someone else will catch the exception
    # NOTE(review): nesting reconstructed from a whitespace-mangled dump — verify.
    pkg_suffix_ver = '.post' + \
        str(int(nvcc_ver)) + '.dev' + str(int(cudnn_ver))
    return pkg_suffix_ver, nvcc_ver, cudnn_ver
def old_replace(session, line):
    """
    Legacy reference implementation: patch a requirements line for packages in
    `_cuda_based_packages_hack` with the cuda/cudnn version suffix.

    Returns the line unchanged when no suffix can be computed, the line is a
    comment, or it mentions none of the cuda-based packages.
    """
    try:
        cuda_ver_suffix, cuda_ver, cuda_cudnn_ver = old_get_suffix(session)
    except Exception:
        # No toolchain / config info available — leave the line untouched.
        return line
    if line.lstrip().startswith('#'):
        return line
    # NOTE(review): loop/try nesting reconstructed from a whitespace-mangled
    # dump — verify against the original file.
    for package_name in _cuda_based_packages_hack:
        if package_name not in line:
            continue
        try:
            line_lstrip = line.lstrip()
            if line_lstrip.startswith('http://') or line_lstrip.startswith('https://'):
                # Direct wheel URL: splice the suffix in before the "-cp" tag.
                pos = line.find(package_name) + len(package_name)
                # patch line with specific version
                line = line[:pos] + \
                    line[pos:].replace('-cp', cuda_ver_suffix + '-cp', 1)
            else:
                # this is a pypi package
                tokens = line.replace('=', ' ').replace('<', ' ').replace('>', ' ').replace(';', ' '). \
                    replace('!', ' ').split()
                if package_name != tokens[0]:
                    # how did we get here, probably a mistake
                    found_cuda_based_package = False  # (dead store kept from original)
                    continue
                version_number = None
                if len(tokens) > 1:
                    # get the package version info
                    test_version_number = tokens[1]
                    # check if we have a valid version, i.e. does not contain post/dev
                    version_number = '.'.join([v for v in test_version_number.split('.')
                                               if v and '0' <= v[0] <= '9'])
                    if version_number != test_version_number:
                        raise ValueError()
                # we have no version, but we have to have one
                if not version_number:
                    # get the latest version from the extra index list
                    pip_search_cmd = ['pip', 'search']
                    if Worker._pip_extra_index_url:
                        pip_search_cmd.extend(
                            chain.from_iterable(('-i', x) for x in Worker._pip_extra_index_url))
                    pip_search_cmd += [package_name]
                    pip_search_output = get_bash_output(
                        ' '.join(pip_search_cmd), strip=True)
                    version_number = pip_search_output.split(package_name)[1]
                    version_number = version_number.replace(
                        '(', ' ').replace(')', ' ').split()[0]
                    version_number = '.'.join([v for v in version_number.split('.')
                                               if v and '0' <= v[0] <= '9'])
                    if not version_number:
                        # somewhere along the way we failed
                        raise ValueError()
                package_name_version = package_name + '==' + version_number + cuda_ver_suffix
                if version_number in line:
                    # make sure we have the specific version not >=
                    tokens = line.split(';')
                    line = ';'.join([package_name_version] + tokens[1:])
                else:
                    # add version to the package_name
                    line = line.replace(package_name, package_name_version, 1)
        except ValueError:
            pass
    return line
# Environment-marker suffix used by the EXTRA parametrization.
# NOTE(review): "win_32" is not a real sys_platform value ("win32" is);
# reproduced verbatim from the original data — confirm intent.
win_condition = 'sys_platform != "win_32"'

versions = ('1', '1.4', '1.4.9', '1.5.3.dev0',
            '1.5.3.dev3', '43.1.2.dev0.post1')


def normalize(result):
    """Canonicalize spacing around ';' markers; falsy inputs pass through unchanged."""
    if not result:
        return result
    return re.sub(' ?; ?', ';', result)


def parse_one(requirement):
    """Parse a single requirement line into a MarkerRequirement."""
    parsed = requirements.parse(requirement)
    return MarkerRequirement(next(parsed))
def compare(manager, current_session, pytest_config, requirement, expected):
    """
    Run `requirement` through the new RequirementsManager substitution and the
    legacy `old_replace`; assert the new result matches `expected` (or is None
    on --cpu runs), and that both paths agree for cases not listed in FAILURES.
    """
    try:
        res1 = normalize(manager._replace_one(parse_one(requirement)))
    except ValueError:
        res1 = None
    res2 = old_replace(current_session, requirement)
    if res2 == requirement:
        # the legacy path signals "no substitution" by returning its input
        res2 = None
    res2 = normalize(res2)
    expected = normalize(expected)
    if pytest_config.option.cpu:
        assert res1 is None
    else:
        assert res1 == expected
        # NOTE(review): nesting reconstructed from a whitespace-mangled dump;
        # the FAILURES cross-check is assumed to apply only on non-cpu runs.
        if requirement not in FAILURES:
            assert res1 == res2
    return res1
# Requirement-line templates taking {op}{version}{extra} placeholders.
ARG_VERSIONED = pytest.mark.parametrize('arg', (
'nothing{op}{version}{extra}',
'something_else{op}{version}{extra}',
'something-else{op}{version}{extra}',
'something.else{op}{version}{extra}',
'seematics.caffe{op}{version}{extra}',
'lightnet{op}{version}{extra}',
))
# Comparison operators applied to the versioned templates.
OP = pytest.mark.parametrize('op', ('==', '<=', '>='))
VERSION = pytest.mark.parametrize('version', versions)
# Optional environment-marker suffix appended to a requirement line.
EXTRA = pytest.mark.parametrize('extra', ('', ' ; ' + win_condition))
# Plain requirement lines and direct wheel URLs (no version placeholders).
ARG_PLAIN = pytest.mark.parametrize('arg', (
'nothing',
'something_else',
'something-else',
'something.else',
'seematics.caffe',
'lightnet',
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl',
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl',
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl',
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl',
'https://s3.amazonaws.com/seematics-pip/public/seematics.config-1.0.2-py2.py3-none-any.whl',
'https://s3.amazonaws.com/seematics-pip/public/seematics.api-1.0.2-py2.py3-none-any.whl',
'https://s3.amazonaws.com/seematics-pip/public/seematics.sdk-1.1.0-py2.py3-none-any.whl',
))
@ARG_VERSIONED
@OP
@VERSION
@EXTRA
def test_with_version(manager, current_session, pytestconfig, arg, op, version, extra):
    """Versioned requirement lines must substitute identically in old and new paths."""
    requirement = arg.format(op=op, version=version, extra=extra)
    expected = get_expected(current_session, (arg, op, version))
    expected = expected and expected + extra
    compare(manager, current_session, pytestconfig, requirement, expected)
@ARG_PLAIN
@EXTRA
def test_plain(arg, current_session, pytestconfig, extra, manager):
    """Plain (unversioned) requirement lines and direct wheel URLs."""
    expected = get_expected(current_session, arg)
    expected = expected and expected + extra
    compare(manager, current_session, pytestconfig, arg + extra, expected)
def get_expected(session, key):
    """Look up the expected substitution for `key` in EXPECTED, filling the
    '{suffix}' placeholder from the session's CUDA/cuDNN configuration.

    Returns None/'' unchanged when no expectation is registered for `key`.
    """
    template = EXPECTED.get(key)
    cuda = session.config['agent.cuda_version']
    cudnn = session.config['agent.cudnn_version']
    if cuda and cudnn:
        suffix = '.post{}.dev{}'.format(cuda, cudnn)
    else:
        suffix = ''
    return template and template.format(suffix=suffix)
@ARG_VERSIONED
@OP
@VERSION
@EXTRA
def test_str_versioned(arg, op, version, extra):
    """Round trip: stringifying a parsed versioned requirement preserves it."""
    requirement = arg.format(op=op, version=version, extra=extra)
    round_tripped = str(parse_one(requirement))
    assert normalize(round_tripped) == normalize(requirement)
@ARG_PLAIN
@EXTRA
def test_str_plain(arg, extra, manager):
    """Round trip: stringifying a parsed plain requirement preserves it."""
    # Plain args carry no format placeholders, so this is effectively `arg`.
    requirement = arg.format(**locals())
    round_tripped = str(parse_one(requirement))
    assert normalize(round_tripped) == normalize(requirement)
@pytest.fixture(scope='session')
def current_session(pytestconfig):
    # Session-wide backend Session. When the suite is run with --cpu, the
    # CUDA/cuDNN versions are cleared so no GPU-specific '.postX.devY' suffix
    # is expected in substitutions (note the early return on the GPU path).
    session = Session()
    if not pytestconfig.option.cpu:
        return session
    session.config['agent.cuda_version'] = None
    session.config['agent.cudnn_version'] = None
    return session
@pytest.fixture(scope='session')
def manager(current_session):
    """Session-wide RequirementsManager with the PyTorch handler registered."""
    requirements_manager = RequirementsManager(current_session.config)
    requirements_manager.register(PytorchRequirement)
    return requirements_manager
# Version-specifier strings exercised by test_pytorch_success below;
# each entry is appended to 'torch' to form a requirement line.
SPECS = (
    # plain
    '',
    # greater than
    '>=0', '>0.1', '>=0.1', '>0.1.2', '>=0.1.2', '>0.1.2.post30', '>=0.1.3.post30', '>0.post30.dev40',
    '>=0.post30.dev40', '>0.post30-dev40', '>=0.post30-dev40', '>0.0', '>=0.0', '>0.3', '>=0.3', '>=0.4',
    # smaller than
    '<4', '<4.0', '<4.1.0', '<4.1.0.post80.dev60', '<4.1.0-post80.dev60', '<3', '<2.0', '<1.0.3', '<=4', '<=4.0',
    '<=4.1.0', '<=4.1.0.post80.dev60', '<=4.1.0-post80.dev60', '<=3', '<=2.0', '<=1.0.3',
    # equals
    '==0.4.0',
    # equals and
    '==0.4.0,>=0', '==0.4.0,>=0.1', '==0.4.0,>0.1', '==0.4.0,>=0.1.2', '==0.4.0,>0.1.2', '==0.4.0,<4', '==0.4.0,<=4',
    '==0.4.0,<4.0', '==0.4.0,<=4.0', '==0.4.0,<4.0.2',
    # smaller and greater
    '>=0,<4', '>=0,<4.0', '>0.1,<1', '>0.1,<1.1.2', '>0.1,<1.1.2', '>=0,<=4', '>=0,<4.0',
    '>=0.1,<1', '>0.1,<=1.1.2', '>=0.1,<1.1.2',
)
@pytest.mark.parametrize('package_manager', ('pip', 'conda'))
@pytest.mark.parametrize('os', ('linux', 'windows', 'macos'))
@pytest.mark.parametrize('cuda', ('80', '90', '91'))
@pytest.mark.parametrize('python', ('2.7', '3.5', '3.6'))
@pytest.mark.parametrize('spec', SPECS)
@pytest.mark.parametrize('condition', ('', ' ; ' + win_condition))
def test_pytorch_success(manager, package_manager, os, cuda, python, spec, condition):
    # The pytorch handler is the last (only) one registered by the manager fixture.
    pytorch_handler = manager.handlers[-1]
    # Force the handler's environment fields instead of relying on auto-detection.
    pytorch_handler.package_manager = package_manager
    pytorch_handler.os = os
    pytorch_handler.cuda = cuda_ver = 'cuda{}'.format(cuda)
    pytorch_handler.python = python_ver = 'python{}'.format(python)
    req = 'torch{}{}'.format(spec, condition)
    # MAP holds either a {version: url} table or an Exception instance for
    # environment combinations the handler does not support.
    expected = pytorch_handler.MAP[package_manager][os][cuda_ver][python_ver]
    if isinstance(expected, Exception):
        with pytest.raises(type(expected)):
            manager._replace_one(parse_one(req))
    else:
        # Every SPECS entry resolves to torch 0.4.0 in the fixture MAP.
        expected = expected['0.4.0']
        result = manager._replace_one(parse_one(req))
        assert result == expected
# Shortcut to the installed-version lookup used when building EXPECTED below.
get_pip_version = RequirementSubstitution.get_pip_version
# Expected substitution results, keyed either by a plain requirement string /
# wheel URL, or by an (arg-template, op, version) tuple. The '{suffix}'
# placeholder is filled by get_expected() with the CUDA/cuDNN-derived
# '.postX.devY' suffix (empty when CUDA/cuDNN are not configured).
EXPECTED = {
    'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl':
        'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp35'
        '-cp35m-win_amd64.whl',
    'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl':
        'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp27'
        '-cp27m-win_amd64.whl',
    'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl':
        'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp35-cp35m'
        '-linux_x86_64.whl',
    'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl':
        'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp27-cp27mu'
        '-linux_x86_64.whl',
    'seematics.caffe': 'seematics.caffe=={}{{suffix}}'.format(get_pip_version('seematics.caffe')),
    'lightnet': 'lightnet=={}{{suffix}}'.format(get_pip_version('lightnet')),
    ('seematics.caffe{op}{version}{extra}', '<=', '1'): 'seematics.caffe==1{suffix}',
    ('seematics.caffe{op}{version}{extra}', '<=', '1.4'): 'seematics.caffe==1.4{suffix}',
    ('seematics.caffe{op}{version}{extra}', '<=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
    ('seematics.caffe{op}{version}{extra}', '==', '1'): 'seematics.caffe==1{suffix}',
    ('seematics.caffe{op}{version}{extra}', '==', '1.4'): 'seematics.caffe==1.4{suffix}',
    ('seematics.caffe{op}{version}{extra}', '==', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
    ('seematics.caffe{op}{version}{extra}', '>=', '1'): 'seematics.caffe==1{suffix}',
    ('seematics.caffe{op}{version}{extra}', '>=', '1.4'): 'seematics.caffe==1.4{suffix}',
    ('seematics.caffe{op}{version}{extra}', '>=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
    ('lightnet{op}{version}{extra}', '<=', '1'): 'lightnet==1{suffix}',
    ('lightnet{op}{version}{extra}', '<=', '1.4'): 'lightnet==1.4{suffix}',
    ('lightnet{op}{version}{extra}', '<=', '1.4.9'): 'lightnet==1.4.9{suffix}',
    ('lightnet{op}{version}{extra}', '==', '1'): 'lightnet==1{suffix}',
    ('lightnet{op}{version}{extra}', '==', '1.4'): 'lightnet==1.4{suffix}',
    ('lightnet{op}{version}{extra}', '==', '1.4.9'): 'lightnet==1.4.9{suffix}',
    ('lightnet{op}{version}{extra}', '>=', '1'): 'lightnet==1{suffix}',
    ('lightnet{op}{version}{extra}', '>=', '1.4'): 'lightnet==1.4{suffix}',
    ('lightnet{op}{version}{extra}', '>=', '1.4.9'): 'lightnet==1.4.9{suffix}',
}
# Requirements whose substitution is expected to fail (Windows-only environment
# marker on packages without a matching expectation); consumed by compare().
FAILURES = {
    'seematics.caffe ; {}'.format(win_condition),
    'lightnet ; {}'.format(win_condition)
}

View File

View File

@@ -0,0 +1,7 @@
def main():
    # NOTE(review): under Python 3, 1 / 2 == 0.5, so this assert fails with
    # AssertionError; it only passes under Python 2 integer division. This
    # looks like a deliberate fixture script used to exercise task failure
    # handling -- confirm before "fixing".
    assert 1 / 2 == 0
    print('success')
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,8 @@
def main():
    # Succeeds only under true division (Python 3, where 1 / 2 == 0.5);
    # raises ValueError under Python 2 integer division. Appears to be the
    # success-path counterpart of the failing fixture script above -- presumably
    # used to test task execution; confirm against the test harness.
    if not (1 / 2 == 0.5):
        raise ValueError('failure')
    print('success')
if __name__ == "__main__":
    main()

1
trains_agent/__init__.py Normal file
View File

@@ -0,0 +1 @@

99
trains_agent/__main__.py Normal file
View File

@@ -0,0 +1,99 @@
from __future__ import print_function, unicode_literals, absolute_import
import argparse
import sys
import warnings
from trains_agent.backend_api.session.datamodel import UnusedKwargsWarning
import trains_agent
from trains_agent.config import get_config
from trains_agent.definitions import FileBuffering, CONFIG_FILE
from trains_agent.helper.base import reverse_home_folder_expansion, chain_map, named_temporary_file
from trains_agent.helper.process import ExitStatus
from . import interface, session, definitions, commands
from .errors import ConfigFileNotFound, Sigterm, APIError
from .helper.trace import PackageTrace
from .interface import get_parser
def run_command(parser, args, command_name):
    """
    Resolve the requested command class and method from the `commands` module
    and execute it, translating known failures into exit statuses.

    :param parser: top-level argument parser (used to strip its own results
        from the kwargs forwarded to the command method)
    :param args: parsed argparse namespace
    :param command_name: either a bare name (dispatched on commands.Worker),
        a subparser command whose method name is in args.func, or a
        "<class>.<method>" dotted pair
    :return: the command method's return value, or an ExitStatus member
    """
    debug = args.debug
    if len(command_name.split('.')) < 2:
        # Bare command name: dispatch on the default Worker command class.
        command_class = commands.Worker
    elif hasattr(args, 'func') and getattr(args, 'func'):
        # Subparser supplied the concrete method name via args.func.
        command_class = getattr(commands, command_name.capitalize())
        command_name = args.func
    else:
        # "<class>.<method>" form.
        command_class, command_name = command_name.split('.')
        command_class = getattr(commands, command_class.capitalize())
    args_dict = dict(vars(args))
    parser.remove_top_level_results(args_dict)
    # Command constructors receive the full namespace; ignore extra kwargs.
    warnings.simplefilter('ignore', UnusedKwargsWarning)
    try:
        command = command_class(**vars(args))
        get_config()['command'] = command
        # The session's debug mode overrides the command-line flag.
        debug = command._session.debug_mode
        func = getattr(command, command_name)
        return func(**args_dict)
    except ConfigFileNotFound:
        message = 'Cannot find configuration file in "{}".\n' \
                  'To create a configuration file, run:\n' \
                  '$ trains_agent init'.format(reverse_home_folder_expansion(CONFIG_FILE))
        command_class.exit(message)
    except APIError as api_error:
        # In debug mode re-raise with the server-side traceback printed first.
        if not debug:
            command_class.error(api_error)
            return ExitStatus.failure
        traceback = api_error.format_traceback()
        if traceback:
            print(traceback)
            print('Own traceback:')
        raise
    except Exception as e:
        if debug:
            raise
        command_class.error(e)
        return ExitStatus.failure
    except (KeyboardInterrupt, Sigterm):
        return ExitStatus.interrupted
def main():
    """CLI entry point: parse arguments and run the requested command,
    optionally recording an execution trace when --trace is given."""
    parser = get_parser()
    args = parser.parse_args()
    try:
        command_name = args.command
        if not command_name:
            return parser.print_help()
    except AttributeError:
        # argparse._ is argparse's gettext alias; mimics its own
        # "too few arguments" error message.
        parser.error(argparse._('too few arguments'))
    if not args.trace:
        return run_command(parser, args, command_name)
    # --trace: record a line-level trace of the trains_agent package to a
    # persistent temp file (delete=False so it survives the run).
    with named_temporary_file(
        mode='w',
        buffering=FileBuffering.LINE_BUFFERING,
        prefix='.trains_agent_trace_',
        suffix='.txt',
        delete=False,
    ) as output:
        print(
            'Saving trace for command '
            '"{definitions.PROGRAM_NAME} {command_name} {args.func}" to "{output.name}"'.format(
                **chain_map(locals(), globals())))
        tracer = PackageTrace(
            package=trains_agent,
            out_file=output,
            ignore_submodules=(__name__, interface, definitions, session))
        return tracer.runfunc(run_command, parser, args, command_name)
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,2 @@
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
from .config import load as load_config

View File

@@ -0,0 +1,16 @@
from ...backend_config import Config
from pathlib2 import Path
def load(*additional_module_paths):
    # type: (*str) -> Config
    """
    Load configuration with the API defaults, using the additional module paths provided

    :param additional_module_paths: Additional config paths for modules whose default
        configurations should be loaded as well
    :return: Config object
    """
    config = Config(verbose=False)
    # Defaults are resolved relative to this module's own directory.
    this_module_path = Path(__file__).parent
    config.load_relative_to(this_module_path, *additional_module_paths)
    return config

View File

@@ -0,0 +1,79 @@
{
# unique name of this worker, if None, created based on hostname:process_id
# Override with os environment: TRAINS_WORKER_ID
# worker_id: "trains-agent-machine1:gpu0"
worker_id: ""
# worker name, replaces the hostname when creating a unique name for this worker
# Override with os environment: TRAINS_WORKER_NAME
# worker_name: "trains-agent-machine1"
worker_name: ""
# Set GIT user/pass credentials for cloning code, leave blank for GIT SSH credentials.
# git_user: ""
# git_pass: ""
# select python package manager:
# currently supported pip and conda
# poetry is used if pip selected and repository contains poetry.lock file
package_manager: {
# supported options: pip, conda
type: pip,
# virtual environment inherits packages from the system
system_site_packages: false,
# install with --upgrade
force_upgrade: false,
# additional artifact repositories to use when installing python packages
# extra_index_url: ["https://allegroai.jfrog.io/trainsai/api/pypi/public/simple"]
extra_index_url: []
# additional conda channels to use when installing with conda package manager
conda_channels: ["defaults", "conda-forge", "pytorch", ]
},
# target folder for virtual environments builds, created when executing experiment
venvs_dir = ~/.trains/venvs-builds
# cached git clone folder
vcs_cache: {
enabled: true,
path: ~/.trains/vcs-cache
},
# use venv-update in order to accelerate python virtual environment building
# Still in beta, turned off by default
venv_update: {
enabled: false,
},
# cached folder for specific python package download (used for pytorch package caching)
pip_download_cache {
enabled: true,
path: ~/.trains/pip-download-cache
},
translate_ssh: true,
# reload configuration file every daemon execution
reload_config: false,
# pip cache folder used mapped into docker, for python package caching
docker_pip_cache = ~/.trains/pip-cache
# apt cache folder used mapped into docker, for ubuntu package caching
docker_apt_cache = ~/.trains/apt-cache
default_docker: {
# default docker image to use when running in docker mode
image: "nvidia/cuda"
# optional arguments to pass to docker image
# arguments: ["--ipc=host", ]
}
# cuda versions used for solving pytorch wheel packages
# should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
# cuda_version: 10.1
# cudnn_version: 7.6
}

View File

@@ -0,0 +1,37 @@
{
version: 1.5
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
# default version assigned to requests with no specific version. this is not expected to change
# as it keeps us backwards compatible.
default_version: 1.5
http {
max_req_size = 15728640 # request size limit (smaller than that configured in api server)
retries {
# retry values (int, 0 means fail on first retry)
total: 240 # Total number of retries to allow. Takes precedence over other counts.
connect: 240 # How many times to retry on connection-related errors (never reached server)
read: 240 # How many times to retry on read errors (waiting for server)
redirect: 240 # How many redirects to perform (HTTP response with a status code 301, 302, 303, 307 or 308)
status: 240 # How many times to retry on bad status codes
# backoff parameters
# timeout between retries is min({backoff_max}, {backoff factor} * (2 ^ ({number of total retries} - 1)))
backoff_factor: 1.0
backoff_max: 120.0
}
wait_on_maintenance_forever: true
pool_maxsize: 512
pool_connections: 512
}
auth {
# When creating a request, if token will expire in less than this value, try to refresh the token
token_expiration_threshold_sec = 360
}
}

View File

@@ -0,0 +1,8 @@
{
version: 1
loggers {
urllib3 {
level: ERROR
}
}
}

View File

@@ -0,0 +1,150 @@
{
# TRAINS - default SDK configuration
storage {
cache {
# Defaults to system temp folder / cache
default_base_dir: "~/.trains/cache"
size {
# max_used_bytes = -1
min_free_bytes = 10GB
# cleanup_margin_percent = 5%
}
}
direct_access: [
# Objects matching are considered to be available for direct access, i.e. they will not be downloaded
# or cached, and any download request will return a direct reference.
# Objects are specified in glob format, available for url and content_type.
{ url: "file://*" } # file-urls are always directly referenced
]
}
metrics {
# History size for debug files per metric/variant. For each metric/variant combination with an attached file
# (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
# X files are stored in the upload destination for each metric/variant combination.
file_history_size: 100
# Settings for generated debug images
images {
format: JPEG
quality: 87
subsampling: 0
}
}
network {
metrics {
# Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
# a specific iteration
file_upload_threads: 4
# Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
# being sent for upload
file_upload_starvation_warning_sec: 120
}
iteration {
# Max number of retries when getting frames if the server returned an error (http code 500)
max_retries_on_server_error: 5
# Backoff factor for consecutive retry attempts.
# SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
retry_backoff_factor_sec: 10
}
}
aws {
s3 {
# S3 credentials, used for read/write access by various SDK elements
# default, used for any bucket not specified below
key: ""
secret: ""
region: ""
credentials: [
# specifies key/secret credentials to use when handling s3 urls (read or write)
# {
# bucket: "my-bucket-name"
# key: "my-access-key"
# secret: "my-secret-key"
# },
# {
# # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
# host: "my-minio-host:9000"
# key: "12345678"
# secret: "12345678"
# multipart: false
# secure: false
# }
]
}
boto3 {
pool_connections: 512
max_multipart_concurrency: 16
}
}
google.storage {
# # Default project and credentials file
# # Will be used when no bucket configuration is found
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# # Specific credentials per bucket and sub directory
# credentials = [
# {
# bucket: "my-bucket"
# subdir: "path/in/bucket" # Not required
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# },
# ]
}
azure.storage {
# containers: [
# {
# account_name: "trains"
# account_key: "secret"
# # container_name:
# }
# ]
}
log {
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
null_log_propagate: False
task_log_buffer_capacity: 66
# disable urllib info and lower levels
disable_urllib3_info: True
}
development {
# Development-mode options
# dev task reuse window
task_reuse_time_window_in_hours: 72.0
# Run VCS repository detection asynchronously
vcs_repo_detect_async: True
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff_on_train: True
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True
# Development mode worker
worker {
# Status report period in seconds
report_period_sec: 2
# ping to the server - check connectivity
ping_period_sec: 30
# Log all stdout & stderr
log_stdout: True
}
}
}

View File

@@ -0,0 +1,38 @@
import re
from functools import partial
import attr
from attr.converters import optional as optional_converter
from attr.validators import instance_of, optional, and_
from six import string_types
# noinspection PyTypeChecker
sequence = instance_of((list, tuple))
def sequence_of(types):
    """
    Build an attrs validator asserting the attribute is a list/tuple whose
    elements are all instances of `types`.

    :param types: a type or tuple of types each element must match
    :return: composed attrs validator (sequence check, then element check)
    """
    def validator(_, attrib, value):
        # A bare `assert` is stripped under `python -O`, which would silently
        # disable element validation -- raise explicitly instead.
        if not all(isinstance(x, types) for x in value):
            raise TypeError(
                '{}: all elements must be instances of {!r}'.format(attrib.name, types)
            )
    return and_(sequence, validator)
@attr.s
class Action(object):
    # Schema-driven description of a single API service action version.
    name = attr.ib()  # action name
    version = attr.ib()  # action version number
    service = attr.ib()  # owning service name
    # names of the schema definitions referenced by this action's request/response
    definitions_keys = attr.ib(validator=sequence)
    authorize = attr.ib(validator=instance_of(bool), default=True)
    log_data = attr.ib(validator=instance_of(bool), default=True)
    log_result_data = attr.ib(validator=instance_of(bool), default=True)
    internal = attr.ib(default=False)
    # optional whitelist of roles allowed to call this action
    allow_roles = attr.ib(default=None, validator=optional(sequence_of(string_types)))
    request = attr.ib(validator=optional(instance_of(dict)), default=None)
    batch_request = attr.ib(validator=optional(instance_of(dict)), default=None)
    response = attr.ib(validator=optional(instance_of(dict)), default=None)
    method = attr.ib(default=None)
    description = attr.ib(
        default=None,
        validator=optional(instance_of(string_types)),
    )

View File

@@ -0,0 +1,201 @@
import itertools
import re
import attr
import six
import pyhocon
from .action import Action
class Service(object):
    """ Service schema handler """

    # matches JSON-schema internal references, e.g. "#/definitions/task"
    __jsonschema_ref_ex = re.compile("^#/definitions/(.*)$")

    @property
    def default(self):
        # default action configuration (the service's "_default" section)
        return self._default

    @property
    def actions(self):
        # {action_name: {version: Action}}
        return self._actions

    @property
    def definitions(self):
        """ Raw service definitions (each might be dependent on some of its siblings) """
        return self._definitions

    @property
    def definitions_refs(self):
        # {definition name: set of definition names it references}
        return self._definitions_refs

    @property
    def name(self):
        return self._name

    @property
    def doc(self):
        return self._doc

    def __init__(self, name, service_config):
        self._name = name
        self._default = None
        self._actions = []
        self._definitions = None
        self._definitions_refs = None
        self._doc = None
        self.parse(service_config)

    @classmethod
    def get_ref_name(cls, ref_string):
        # "#/definitions/foo" -> "foo"; returns None for non-reference strings
        m = cls.__jsonschema_ref_ex.match(ref_string)
        if m:
            return m.group(1)

    def parse(self, service_config):
        """Parse a service config tree into defaults, doc, definitions and actions."""
        self._default = service_config.get(
            "_default", pyhocon.ConfigTree()
        ).as_plain_ordered_dict()
        self._doc = '{} service'.format(self.name)
        description = service_config.get('_description', '')
        if description:
            self._doc += '\n\n{}'.format(description)
        self._definitions = service_config.get(
            "_definitions", pyhocon.ConfigTree()
        ).as_plain_ordered_dict()
        self._definitions_refs = {
            k: self._get_schema_references(v) for k, v in self._definitions.items()
        }
        # every reference used by any definition must itself be defined
        all_refs = set(itertools.chain(*self.definitions_refs.values()))
        if not all_refs.issubset(self.definitions):
            raise ValueError(
                "Unresolved references (%s) in %s/definitions"
                % (", ".join(all_refs.difference(self.definitions)), self.name)
            )
        # non-underscore top-level keys are action names
        actions = {
            k: v.as_plain_ordered_dict()
            for k, v in service_config.items()
            if not k.startswith("_")
        }
        # keep only actions that produced at least one generated version
        self._actions = {
            action_name: action
            for action_name, action in (
                (action_name, self._parse_action_versions(action_name, action_versions))
                for action_name, action_versions in actions.items()
            )
            if action
        }

    def _parse_action_versions(self, action_name, action_versions):
        # Parse each version entry of a single action into {version: Action}.
        def parse_version(action_version):
            try:
                return float(action_version)
            except (ValueError, TypeError) as ex:
                raise ValueError(
                    "Failed parsing version number {} ({}) in {}/{}".format(
                        action_version, ex.args[0], self.name, action_name
                    )
                )

        def add_internal(cfg):
            # propagate an action-level "internal" flag into each version config
            if "internal" in action_versions:
                cfg.setdefault("internal", action_versions["internal"])
            return cfg

        # "internal"/"allow_roles"/"authorize" are action-level keys, not versions
        return {
            parsed_version: action
            for parsed_version, action in (
                (parsed_version, self._parse_action(action_name, parsed_version, add_internal(cfg)))
                for parsed_version, cfg in (
                    (parse_version(version), cfg)
                    for version, cfg in action_versions.items()
                    if version not in ["internal", "allow_roles", "authorize"]
                )
            )
            if action
        }

    def _get_schema_references(self, s):
        # Recursively collect "#/definitions/<name>" references from a schema dict.
        refs = set()
        if isinstance(s, dict):
            for k, v in s.items():
                if isinstance(v, six.string_types):
                    m = self.__jsonschema_ref_ex.match(v)
                    if m:
                        refs.add(m.group(1))
                    continue
                elif k in ("oneOf", "anyOf") and isinstance(v, list):
                    # each alternative may carry its own references
                    refs.update(*map(self._get_schema_references, v))
                refs.update(self._get_schema_references(v))
        return refs

    def _expand_schema_references_with_definitions(self, schema, refs=None):
        # Compute the transitive closure of service definitions required by
        # `schema` beyond those already present in its own "definitions".
        definitions = schema.get("definitions", {})
        refs = refs if refs is not None else self._get_schema_references(schema)
        required_refs = set(refs).difference(definitions)
        if not required_refs:
            return required_refs
        if not required_refs.issubset(self.definitions):
            raise ValueError(
                "Unresolved references (%s)"
                % ", ".join(required_refs.difference(self.definitions))
            )
        # update required refs with all sub requirements (fixed-point iteration)
        last_required_refs = None
        while last_required_refs != required_refs:
            last_required_refs = required_refs.copy()
            additional_refs = set(
                itertools.chain(
                    *(self.definitions_refs.get(ref, []) for ref in required_refs)
                )
            )
            required_refs.update(additional_refs)
        return required_refs

    def _resolve_schema_references(self, schema, refs=None):
        # Copy the service-level definitions named in `refs` into the schema.
        definitions = schema.get("definitions", {})
        definitions.update({k: v for k, v in self.definitions.items() if k in refs})
        schema["definitions"] = definitions

    def _parse_action(self, action_name, action_version, action_config):
        # Build an Action from defaults + per-version config. Returns None when
        # the config opts out of generation via "generate: false".
        data = self.default.copy()
        data.update(action_config)
        if not action_config.get("generate", True):
            return None
        definitions_keys = set()
        for schema_key in ("request", "response"):
            if schema_key in action_config:
                try:
                    schema = action_config[schema_key]
                    refs = self._expand_schema_references_with_definitions(schema)
                    self._resolve_schema_references(schema, refs=refs)
                    definitions_keys.update(refs)
                except ValueError as ex:
                    # qualify the error with "<service>.<action>/<version>/<key>"
                    name = "%s.%s/%.1f/%s" % (
                        self.name,
                        action_name,
                        action_version,
                        schema_key,
                    )
                    raise ValueError("%s in %s" % (str(ex), name))
        return Action(
            name=action_name,
            version=action_version,
            definitions_keys=list(definitions_keys),
            service=self.name,
            **(
                {
                    # forward only keys Action actually declares
                    key: value
                    for key, value in data.items()
                    if key in attr.fields_dict(Action)
                }
            )
        )

View File

@@ -0,0 +1,13 @@
from .v2_4 import auth
from .v2_4 import debug
from .v2_4 import queues
from .v2_4 import tasks
from .v2_4 import workers
__all__ = [
'auth',
'debug',
'queues',
'tasks',
'workers',
]

View File

@@ -0,0 +1,623 @@
"""
auth service
This service provides authentication management and authorization
validation for the entire system.
"""
import six
import types
from datetime import datetime
import enum
from dateutil.parser import parse as parse_datetime
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
class Credentials(NonStrictDataModel):
    """
    Access-key/secret-key credential pair data model.

    :param access_key: Credentials access key
    :type access_key: str
    :param secret_key: Credentials secret key
    :type secret_key: str
    """
    # JSON schema this model validates against (both fields nullable strings)
    _schema = {
        'properties': {
            'access_key': {
                'description': 'Credentials access key',
                'type': ['string', 'null'],
            },
            'secret_key': {
                'description': 'Credentials secret key',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, access_key=None, secret_key=None, **kwargs):
        super(Credentials, self).__init__(**kwargs)
        self.access_key = access_key
        self.secret_key = secret_key

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        # None clears the property; otherwise the value must be a string
        if value is None:
            self._property_access_key = None
            return
        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property('secret_key')
    def secret_key(self):
        return self._property_secret_key

    @secret_key.setter
    def secret_key(self, value):
        if value is None:
            self._property_secret_key = None
            return
        self.assert_isinstance(value, "secret_key", six.string_types)
        self._property_secret_key = value
class CredentialKey(NonStrictDataModel):
    """
    Credential key info (no secret), with last-usage metadata.

    :param access_key:
    :type access_key: str
    :param last_used:
    :type last_used: datetime.datetime
    :param last_used_from:
    :type last_used_from: str
    """
    _schema = {
        'properties': {
            'access_key': {'description': '', 'type': ['string', 'null']},
            'last_used': {
                'description': '',
                'format': 'date-time',
                'type': ['string', 'null'],
            },
            'last_used_from': {'description': '', 'type': ['string', 'null']},
        },
        'type': 'object',
    }

    def __init__(
            self, access_key=None, last_used=None, last_used_from=None, **kwargs):
        super(CredentialKey, self).__init__(**kwargs)
        self.access_key = access_key
        self.last_used = last_used
        self.last_used_from = last_used_from

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        if value is None:
            self._property_access_key = None
            return
        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property('last_used')
    def last_used(self):
        return self._property_last_used

    @last_used.setter
    def last_used(self, value):
        # Accepts either a datetime or an ISO-ish date string, which is parsed
        # into a datetime via dateutil.
        if value is None:
            self._property_last_used = None
            return
        self.assert_isinstance(value, "last_used", six.string_types + (datetime,))
        if not isinstance(value, datetime):
            value = parse_datetime(value)
        self._property_last_used = value

    @schema_property('last_used_from')
    def last_used_from(self):
        return self._property_last_used_from

    @last_used_from.setter
    def last_used_from(self, value):
        if value is None:
            self._property_last_used_from = None
            return
        self.assert_isinstance(value, "last_used_from", six.string_types)
        self._property_last_used_from = value
class CreateCredentialsRequest(Request):
    """
    Creates a new set of credentials for the authenticated user.
    New key/secret is returned.
    Note: Secret will never be returned in any other API call.
    If a secret is lost or compromised, the key should be revoked
    and a new set of credentials can be created.
    """

    _service = "auth"
    _action = "create_credentials"
    _version = "2.1"
    # Request takes no properties (and allows none).
    _schema = {
        'additionalProperties': False,
        'definitions': {},
        'properties': {},
        'type': 'object',
    }
class CreateCredentialsResponse(Response):
    """
    Response of auth.create_credentials endpoint.

    :param credentials: Created credentials
    :type credentials: Credentials
    """
    _service = "auth"
    _action = "create_credentials"
    _version = "2.1"
    _schema = {
        'definitions': {
            'credentials': {
                'properties': {
                    'access_key': {
                        'description': 'Credentials access key',
                        'type': ['string', 'null'],
                    },
                    'secret_key': {
                        'description': 'Credentials secret key',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'credentials': {
                'description': 'Created credentials',
                'oneOf': [{'$ref': '#/definitions/credentials'}, {'type': 'null'}],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, credentials=None, **kwargs):
        super(CreateCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        # Accepts a raw dict (coerced into a Credentials model) or a
        # Credentials instance.
        if value is None:
            self._property_credentials = None
            return
        if isinstance(value, dict):
            value = Credentials.from_dict(value)
        else:
            self.assert_isinstance(value, "credentials", Credentials)
        self._property_credentials = value
class EditUserRequest(Request):
    """
    Edit a users' auth data properties

    :param user: User ID
    :type user: str
    :param role: The new user's role within the company
    :type role: str
    """

    _service = "auth"
    _action = "edit_user"
    _version = "2.1"
    _schema = {
        'definitions': {},
        'properties': {
            'role': {
                'description': "The new user's role within the company",
                'enum': ['admin', 'superuser', 'user', 'annotator'],
                'type': ['string', 'null'],
            },
            'user': {'description': 'User ID', 'type': ['string', 'null']},
        },
        'type': 'object',
    }

    def __init__(
            self, user=None, role=None, **kwargs):
        super(EditUserRequest, self).__init__(**kwargs)
        self.user = user
        self.role = role

    @schema_property('user')
    def user(self):
        return self._property_user

    @user.setter
    def user(self, value):
        if value is None:
            self._property_user = None
            return
        self.assert_isinstance(value, "user", six.string_types)
        self._property_user = value

    @schema_property('role')
    def role(self):
        return self._property_role

    @role.setter
    def role(self, value):
        # NOTE: only type is validated here; the schema enum (admin/superuser/
        # user/annotator) is not enforced by this setter.
        if value is None:
            self._property_role = None
            return
        self.assert_isinstance(value, "role", six.string_types)
        self._property_role = value
class EditUserResponse(Response):
    """
    Response of auth.edit_user endpoint.

    :param updated: Number of users updated (0 or 1)
    :type updated: float
    :param fields: Updated fields names and values
    :type fields: dict
    """
    _service = "auth"
    _action = "edit_user"
    _version = "2.1"
    _schema = {
        'definitions': {},
        'properties': {
            'fields': {
                'additionalProperties': True,
                'description': 'Updated fields names and values',
                'type': ['object', 'null'],
            },
            'updated': {
                'description': 'Number of users updated (0 or 1)',
                'enum': [0, 1],
                'type': ['number', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, updated=None, fields=None, **kwargs):
        super(EditUserResponse, self).__init__(**kwargs)
        self.updated = updated
        self.fields = fields

    @schema_property('updated')
    def updated(self):
        return self._property_updated

    @updated.setter
    def updated(self, value):
        # schema type is "number": accept both int and float
        if value is None:
            self._property_updated = None
            return
        self.assert_isinstance(value, "updated", six.integer_types + (float,))
        self._property_updated = value

    @schema_property('fields')
    def fields(self):
        return self._property_fields

    @fields.setter
    def fields(self, value):
        if value is None:
            self._property_fields = None
            return
        self.assert_isinstance(value, "fields", (dict,))
        self._property_fields = value
class GetCredentialsRequest(Request):
    """
    Returns all existing credential keys for the authenticated user.
    Note: Only credential keys are returned.
    """

    _service = "auth"
    _action = "get_credentials"
    _version = "2.1"
    # Request takes no properties (and allows none).
    _schema = {
        'additionalProperties': False,
        'definitions': {},
        'properties': {},
        'type': 'object',
    }
class GetCredentialsResponse(Response):
    """
    Response of auth.get_credentials endpoint.

    :param credentials: List of credentials, each with an empty secret field.
    :type credentials: Sequence[CredentialKey]
    """
    _service = "auth"
    _action = "get_credentials"
    _version = "2.1"
    _schema = {
        'definitions': {
            'credential_key': {
                'properties': {
                    'access_key': {'description': '', 'type': ['string', 'null']},
                    'last_used': {
                        'description': '',
                        'format': 'date-time',
                        'type': ['string', 'null'],
                    },
                    'last_used_from': {
                        'description': '',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'credentials': {
                'description': 'List of credentials, each with an empty secret field.',
                'items': {'$ref': '#/definitions/credential_key'},
                'type': ['array', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, credentials=None, **kwargs):
        super(GetCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        # Accepts a list/tuple of CredentialKey instances or of raw dicts;
        # dicts are coerced into CredentialKey models item by item.
        if value is None:
            self._property_credentials = None
            return
        self.assert_isinstance(value, "credentials", (list, tuple))
        if any(isinstance(v, dict) for v in value):
            value = [CredentialKey.from_dict(v) if isinstance(v, dict) else v for v in value]
        else:
            self.assert_isinstance(value, "credentials", CredentialKey, is_array=True)
        self._property_credentials = value
class LoginRequest(Request):
    """
    Get a token based on supplied credentials (key/secret).
    Intended for use by users with key/secret credentials that wish to obtain a token
    for use with other services. Token will be limited by the same permissions that
    exist for the credentials used in this call.

    :param expiration_sec: Requested token expiration time in seconds. Not
        guaranteed, might be overridden by the service
    :type expiration_sec: int
    """
    _service = "auth"
    _action = "login"
    _version = "2.1"
    # JSON schema of the request payload (auto-generated from the API spec).
    _schema = {
        'definitions': {},
        'properties': {
            'expiration_sec': {
                'description': 'Requested token expiration time in seconds. \n                Not guaranteed,  might be overridden by the service',
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, expiration_sec=None, **kwargs):
        super(LoginRequest, self).__init__(**kwargs)
        self.expiration_sec = expiration_sec

    @schema_property('expiration_sec')
    def expiration_sec(self):
        """Requested token expiration time in seconds (int), or None."""
        return self._property_expiration_sec

    @expiration_sec.setter
    def expiration_sec(self, value):
        # None clears the field; whole-number floats are coerced to int.
        if value is None:
            self._property_expiration_sec = None
            return
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "expiration_sec", six.integer_types)
        self._property_expiration_sec = value
class LoginResponse(Response):
    """
    Response of auth.login endpoint.

    :param token: Token string
    :type token: str
    """
    _service = "auth"
    _action = "login"
    _version = "2.1"
    # JSON schema of the response payload (auto-generated from the API spec).
    _schema = {
        'definitions': {},
        'properties': {
            'token': {'description': 'Token string', 'type': ['string', 'null']},
        },
        'type': 'object',
    }

    def __init__(
            self, token=None, **kwargs):
        super(LoginResponse, self).__init__(**kwargs)
        self.token = token

    @schema_property('token')
    def token(self):
        """Token string, or None."""
        return self._property_token

    @token.setter
    def token(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is None:
            self._property_token = None
            return

        self.assert_isinstance(value, "token", six.string_types)
        self._property_token = value
class LogoutRequest(Request):
    """
    Removes the authentication cookie from the current session
    """
    # Parameterless endpoint descriptor (empty payload schema).
    _service = "auth"
    _action = "logout"
    _version = "2.2"
    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}
class LogoutResponse(Response):
    """
    Response of auth.logout endpoint.
    """
    # Empty-payload response descriptor.
    _service = "auth"
    _action = "logout"
    _version = "2.2"
    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}
class RevokeCredentialsRequest(Request):
    """
    Revokes (and deletes) a set (key, secret) of credentials for
    the authenticated user.

    :param access_key: Credentials key
    :type access_key: str
    """
    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.1"
    # NOTE(review): the schema marks 'key_id' as required but the only declared
    # property is 'access_key' -- as written, validate() could never pass.
    # Confirm against the server-side schema before changing.
    _schema = {
        'definitions': {},
        'properties': {
            'access_key': {
                'description': 'Credentials key',
                'type': ['string', 'null'],
            },
        },
        'required': ['key_id'],
        'type': 'object',
    }

    def __init__(
            self, access_key=None, **kwargs):
        super(RevokeCredentialsRequest, self).__init__(**kwargs)
        self.access_key = access_key

    @schema_property('access_key')
    def access_key(self):
        """Credentials access key (str), or None."""
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is None:
            self._property_access_key = None
            return

        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value
class RevokeCredentialsResponse(Response):
    """
    Response of auth.revoke_credentials endpoint.

    :param revoked: Number of credentials revoked
    :type revoked: int
    """
    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.1"
    # JSON schema of the response payload; 'revoked' is constrained to 0 or 1.
    _schema = {
        'definitions': {},
        'properties': {
            'revoked': {
                'description': 'Number of credentials revoked',
                'enum': [0, 1],
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, revoked=None, **kwargs):
        super(RevokeCredentialsResponse, self).__init__(**kwargs)
        self.revoked = revoked

    @schema_property('revoked')
    def revoked(self):
        """Number of credentials revoked (int), or None."""
        return self._property_revoked

    @revoked.setter
    def revoked(self, value):
        # None clears the field; whole-number floats are coerced to int.
        if value is None:
            self._property_revoked = None
            return
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "revoked", six.integer_types)
        self._property_revoked = value
# Maps each request class of this service to its matching response class;
# used to resolve the response type when parsing call results.
response_mapping = {
    LoginRequest: LoginResponse,
    LogoutRequest: LogoutResponse,
    CreateCredentialsRequest: CreateCredentialsResponse,
    GetCredentialsRequest: GetCredentialsResponse,
    RevokeCredentialsRequest: RevokeCredentialsResponse,
    EditUserRequest: EditUserResponse,
}

View File

@@ -0,0 +1,194 @@
"""
debug service
Debugging utilities
"""
import six
import types
from datetime import datetime
import enum
from dateutil.parser import parse as parse_datetime
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
class ApiexRequest(Request):
    """
    Auto-generated descriptor of the parameterless debug.apiex endpoint.
    """
    _service = "debug"
    _action = "apiex"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
class ApiexResponse(Response):
    """
    Response of debug.apiex endpoint.
    """
    # Empty-payload response descriptor.
    _service = "debug"
    _action = "apiex"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class EchoRequest(Request):
    """
    Return request data
    """
    # Parameterless endpoint descriptor (empty payload schema).
    _service = "debug"
    _action = "echo"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class EchoResponse(Response):
    """
    Response of debug.echo endpoint.
    """
    # Empty-payload response descriptor.
    _service = "debug"
    _action = "echo"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class ExRequest(Request):
    """
    Auto-generated descriptor of the parameterless debug.ex endpoint.
    """
    _service = "debug"
    _action = "ex"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
class ExResponse(Response):
    """
    Response of debug.ex endpoint.
    """
    # Empty-payload response descriptor.
    _service = "debug"
    _action = "ex"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class PingRequest(Request):
    """
    Return a message. Does not require authorization.
    """
    # Parameterless endpoint descriptor (empty payload schema).
    _service = "debug"
    _action = "ping"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class PingResponse(Response):
    """
    Response of debug.ping endpoint.

    :param msg: A friendly message
    :type msg: str
    """
    _service = "debug"
    _action = "ping"
    _version = "1.5"
    # JSON schema of the response payload (auto-generated from the API spec).
    _schema = {
        'definitions': {},
        'properties': {
            'msg': {
                'description': 'A friendly message',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, msg=None, **kwargs):
        super(PingResponse, self).__init__(**kwargs)
        self.msg = msg

    @schema_property('msg')
    def msg(self):
        """The 'msg' response field (str), or None."""
        return self._property_msg

    @msg.setter
    def msg(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is None:
            self._property_msg = None
            return

        self.assert_isinstance(value, "msg", six.string_types)
        self._property_msg = value
class PingAuthRequest(Request):
    """
    Return a message. Requires authorization.
    """
    # Parameterless endpoint descriptor (empty payload schema).
    _service = "debug"
    _action = "ping_auth"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class PingAuthResponse(Response):
    """
    Response of debug.ping_auth endpoint.

    :param msg: A friendly message
    :type msg: str
    """
    _service = "debug"
    _action = "ping_auth"
    _version = "1.5"
    # JSON schema of the response payload (auto-generated from the API spec).
    _schema = {
        'definitions': {},
        'properties': {
            'msg': {
                'description': 'A friendly message',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, msg=None, **kwargs):
        super(PingAuthResponse, self).__init__(**kwargs)
        self.msg = msg

    @schema_property('msg')
    def msg(self):
        """The 'msg' response field (str), or None."""
        return self._property_msg

    @msg.setter
    def msg(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is None:
            self._property_msg = None
            return

        self.assert_isinstance(value, "msg", six.string_types)
        self._property_msg = value
# Maps each request class of this service to its matching response class;
# used to resolve the response type when parsing call results.
response_mapping = {
    EchoRequest: EchoResponse,
    PingRequest: PingResponse,
    PingAuthRequest: PingAuthResponse,
    ApiexRequest: ApiexResponse,
    ExRequest: ExResponse,
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
from .session import Session
from .datamodel import DataModel, NonStrictDataModel, schema_property, StringEnum
from .request import Request, BatchRequest, CompoundRequest
from .response import Response
from .token_manager import TokenManager
from .errors import TimeoutExpiredError, ResultNotReadyError
from .callresult import CallResult

View File

@@ -0,0 +1,8 @@
from .datamodel import DataModel
class ApiModel(DataModel):
    """ API-related data model """
    # Endpoint identification; filled in by concrete request/response classes.
    _service = None
    _action = None
    _version = None

View File

@@ -0,0 +1,131 @@
import sys
import time
from ...backend_api.utils import get_response_cls
from .response import ResponseMeta, Response
from .errors import ResultNotReadyError, TimeoutExpiredError
class CallResult(object):
    """
    Result of a single API call: response meta, parsed response object and the
    raw response data. Also supports asynchronous calls (HTTP 202) through
    ready()/wait()/result().
    """

    @property
    def meta(self):
        """Response meta information (ResponseMeta): result code/subcode/message, call id."""
        return self.__meta

    @property
    def response(self):
        """Parsed response object (a Response subclass instance), or None."""
        return self.__response

    @property
    def response_data(self):
        """Raw response payload as a dict, or None."""
        return self.__response_data

    @property
    def async_accepted(self):
        """True when the server accepted the call for asynchronous processing (HTTP 202)."""
        return self.meta.result_code == 202

    @property
    def request_cls(self):
        """Request class that produced this result (may be None)."""
        return self.__request_cls

    def __init__(self, meta, response=None, response_data=None, request_cls=None, session=None):
        """
        :param meta: response meta information (required)
        :type meta: ResponseMeta
        :param response: parsed response object (Response instance)
        :param response_data: raw payload dict; derived from response if omitted
        :param request_cls: originating request class (used for async polling)
        :param session: session used for polling asynchronous results
        :raises ValueError: when response is not a Response instance
        :raises TypeError: when response_data is not a dict
        """
        assert isinstance(meta, ResponseMeta)
        if response and not isinstance(response, Response):
            raise ValueError('response should be an instance of %s' % Response.__name__)
        elif response_data and not isinstance(response_data, dict):
            raise TypeError('data should be an instance of {}'.format(dict.__name__))
        self.__meta = meta
        self.__response = response
        self.__request_cls = request_cls
        self.__session = session
        self.__async_result = None
        if response_data is not None:
            self.__response_data = response_data
        elif response is not None:
            try:
                self.__response_data = response.to_dict()
            except AttributeError:
                raise TypeError('response should be an instance of {}'.format(Response.__name__))
        else:
            self.__response_data = None

    @classmethod
    def from_result(cls, res, request_cls=None, logger=None, service=None, action=None, session=None):
        """ From requests result """
        # Resolve the response class matching the request (None if unknown).
        response_cls = get_response_cls(request_cls)
        try:
            data = res.json()
        except ValueError:
            # Body is not JSON - synthesize a meta from the raw HTTP status/text.
            service = service or (request_cls._service if request_cls else 'unknown')
            action = action or (request_cls._action if request_cls else 'unknown')
            return cls(request_cls=request_cls, meta=ResponseMeta.from_raw_data(
                status_code=res.status_code, text=res.text, endpoint='%(service)s.%(action)s' % locals()))
        if 'meta' not in data:
            raise ValueError('Missing meta section in response payload')
        try:
            meta = ResponseMeta(**data['meta'])
            # TODO: validate meta?
            # meta.validate()
        except Exception as ex:
            raise ValueError('Failed parsing meta section in response payload (data=%s, error=%s)' % (data, ex))
        response = None
        response_data = None
        try:
            response_data = data.get('data', {})
            if response_cls:
                response = response_cls(**response_data)
                # TODO: validate response?
                # response.validate()
        except Exception as e:
            # Parsing failure is non-fatal: the raw data is still returned.
            if logger:
                logger.warning('Failed parsing response: %s' % str(e))
        return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)

    def ok(self):
        """True when the call returned HTTP 200."""
        return self.meta.result_code == 200

    def ready(self):
        """
        Poll whether an asynchronous call has completed.
        Returns True when done (caching the final result); for a call that is
        still pending the method falls through, implicitly returning None (falsy).
        """
        if not self.async_accepted:
            return True
        session = self.__session
        res = session.send_request(service='async', action='result', json=dict(id=self.meta.id), async_enable=False)
        if res.status_code != session._async_status_code:
            self.__async_result = CallResult.from_result(res=res, request_cls=self.request_cls, logger=session._logger)
            return True

    def result(self):
        """Return the final result of an asynchronous call.

        :raises ResultNotReadyError: when ready() has not yet observed completion
        """
        if not self.async_accepted:
            return self
        if self.__async_result is None:
            raise ResultNotReadyError(self._format_msg('Timeout expired'), call_id=self.meta.id)
        return self.__async_result

    def wait(self, timeout=None, poll_interval=5, verbose=False):
        """
        Block until an asynchronous call completes.

        :param timeout: seconds to wait; a falsy value waits (practically) forever
        :param poll_interval: seconds between polls (clamped to a minimum of 1)
        :param verbose: log progress on each poll
        :raises TimeoutExpiredError: when the timeout expires before completion
        """
        if not self.async_accepted:
            return self
        session = self.__session
        poll_interval = max(1, poll_interval)
        remaining = max(0, timeout) if timeout else sys.maxsize
        while remaining > 0:
            if not self.ready():
                # Still pending, log and continue
                if verbose and session._logger:
                    # NOTE(review): "timeout is False" can never be True for the
                    # documented default of None -- presumably intended as the
                    # "no timeout" sentinel; confirm and align with the default.
                    progress = ('waiting forever'
                                if timeout is False
                                else '%.1f/%.1f seconds remaining' % (remaining, float(timeout or 0)))
                    session._logger.info('Waiting for asynchronous call %s (%s)'
                                         % (self.request_cls.__name__, progress))
                time.sleep(poll_interval)
                remaining -= poll_interval
                continue
            # We've got something (good or bad, we don't know), create a call result and return
            return self.result()
        # Timeout expired, return the asynchronous call's result (we've got nothing better to report)
        raise TimeoutExpiredError(self._format_msg('Timeout expired'), call_id=self.meta.id)

    def _format_msg(self, msg):
        # Append call identification (request class and call id) for messages.
        return msg + ' for call %s (%s)' % (self.request_cls.__name__, self.meta.id)

View File

@@ -0,0 +1 @@
from .client import APIClient, StrictSession, APIError

View File

@@ -0,0 +1,530 @@
from __future__ import unicode_literals
import abc
import os
from argparse import Namespace
from collections import OrderedDict
from enum import Enum
from functools import reduce, wraps, WRAPPER_ASSIGNMENTS
from importlib import import_module
from itertools import chain
from operator import itemgetter
from types import ModuleType
from typing import Dict, Text, Tuple, Type, Any, Sequence
import six
from ... import services as api_services
from ....backend_api.session import CallResult
from ....backend_api.session import Session, Request as APIRequest
from ....backend_api.session.response import ResponseMeta
from ....backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR
# Services whose entity class name differs from the name derived automatically
# from the service module name.
SERVICE_TO_ENTITY_CLASS_NAMES = {"storage": "StorageItem"}


def entity_class_name(service):
    # type: (ModuleType) -> Text
    """Return the entity class name for a service module, honoring explicit
    overrides in SERVICE_TO_ENTITY_CLASS_NAMES."""
    service_name = api_entity_name(service)
    return SERVICE_TO_ENTITY_CLASS_NAMES.get(service_name.lower(), service_name)
def api_entity_name(service):
    # type: (ModuleType) -> Text
    """
    Return the singular entity name for a service module (e.g. "Tasks" -> "Task").

    Removes exactly one trailing "s" (if present). The previous
    ``rstrip("s")`` stripped *all* trailing "s" characters, which would
    mangle names ending in a double "s" (e.g. "Address" -> "Addre").
    """
    name = module_name(service)
    return name[:-1] if name.endswith("s") else name
@six.python_2_unicode_compatible
class APIError(Exception):
    """
    Class for representing an API error.

    self.data - ``dict`` of all returned JSON data
    self.code - HTTP response code
    self.subcode - server response subcode
    self.codes - (self.code, self.subcode) tuple
    self.message - result message sent from server
    """

    def __init__(self, response, extra_info=None):
        """
        Create a new APIError from a server response

        :param response: the failed call's result
        :type response: CallResult
        :param extra_info: optional context string prepended to str(self)
        """
        super(APIError, self).__init__()
        self._response = response  # type: CallResult
        self.extra_info = extra_info
        self.data = response.response_data  # type: Dict
        self.meta = response.meta  # type: ResponseMeta
        self.code = response.meta.result_code  # type: int
        self.subcode = response.meta.result_subcode  # type: int
        self.message = response.meta.result_msg  # type: Text
        self.codes = (self.code, self.subcode)  # type: Tuple[int, int]

    def get_traceback(self):
        """
        Return server traceback for error, or None if doesn't exist.
        """
        try:
            return self.meta.error_stack
        except AttributeError:
            return None

    def __str__(self):
        # Build "APIError: [extra]: code <code>[/<subcode>][: <message>]",
        # degrading gracefully when meta or code information is missing.
        message = "{}: ".format(type(self).__name__)
        if self.extra_info:
            message += "{}: ".format(self.extra_info)
        if not self.meta:
            message += "no meta available"
            return message
        if not self.code:
            message += "no error code available"
            return message
        message += "code {0.code}".format(self)
        if self.subcode:
            message += "/{.subcode}".format(self)
        if self.message:
            message += ": {.message}".format(self)
        return message
class StrictSession(Session):
    """
    Session that raises exceptions on errors, and be configured with explicit ``config_file`` path.
    """

    def __init__(self, config_file=None, initialize_logging=False, *args, **kwargs):
        """
        :param config_file: configuration file to use, else use the default
        :type config_file: Path | Text
        """
        def init():
            super(StrictSession, self).__init__(
                initialize_logging=initialize_logging, *args, **kwargs
            )
        if not config_file:
            init()
            return
        # The base Session locates its config file via an environment variable:
        # temporarily override it for the duration of initialization, then
        # restore the previous value (or remove the variable).
        original = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None)
        try:
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = str(config_file)
            init()
        finally:
            if original is None:
                os.environ.pop(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None)
            else:
                os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = original

    def send(self, request, *args, **kwargs):
        # Raise APIError instead of returning a failed or unparsable result.
        result = super(StrictSession, self).send(request, *args, **kwargs)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result
class Response(object):
    """
    Proxy object for API result data.
    Exposes "meta" of the original result.
    """

    def __init__(self, result, dest=None):
        """
        :param result: result of endpoint call
        :type result: CallResult
        :param dest: if all of a response's data is contained in one field, use that field
        :type dest: Text
        """
        self._result = result
        response = getattr(result, "response", result)
        if dest:
            response = getattr(response, dest)
        self.response = response

    def __getattr__(self, attr):
        # Delegate unknown attribute lookups to the wrapped response object.
        return getattr(self.response, attr)

    @property
    def meta(self):
        """Meta information of the original call result."""
        return self._result.meta

    def __repr__(self):
        return repr(self.response)

    def __dir__(self):
        # Merge this proxy's names with the wrapped response's properties (for
        # tab completion), hiding the internal "response" attribute.
        fields = [
            name
            for name in dir(self.response)
            if isinstance(getattr(type(self.response), name, None), property)
        ]
        return list(set(chain(super(Response, self).__dir__(), fields)) - {"response"})
@six.python_2_unicode_compatible
class TableResponse(Response):
    """
    Representation of result containing an array of entities
    """

    def __init__(
        self,
        service,  # type: Service
        entity,  # type: Type[entity]
        fields=None,  # type: Sequence[Text]
        *args,
        **kwargs
    ):
        """
        :param service: service of entity
        :param entity: class representing entity
        :param fields: entity attributes requested by client
        """
        super(TableResponse, self).__init__(*args, **kwargs)
        self.service = service
        self.entity = entity
        self.fields = fields or ("id", "name")
        # Wrap every raw item in the entity class.
        self.response = [entity(service, item) for item in self]

    def __repr__(self, fields=None):
        return self._format_table(fields=fields)

    __str__ = __repr__

    def _format_table(self, fields=None):
        """
        Display <fields> attributes of each element in a table
        :param fields:
        """
        def getter(obj, attr):
            # Resolve a dotted attribute path ("a.b.c"); missing -> "".
            result = reduce(
                lambda x, name: x if x is None else getattr(x, name, None),
                attr.split("."),
                obj,
            )
            return "" if result is None else result
        fields = fields or self.fields
        # Imported lazily to avoid a circular import at module load time.
        from trains_agent.helper.base import create_table
        return create_table(
            (tuple(getter(item, attr) for attr in fields) for item in self),
            titles=fields, headers=True,
        )

    def display(self, fields=None):
        """Print the entities as a formatted table."""
        print(self._format_table(fields=fields))

    def where(self, predicate=None, **kwargs):
        """
        Filter items.
        <predicate> is a callable from a single item to a boolean. Items for which <predicate> is True will be returned.
        Keyword arguments are interpreted as attribute equivalence, meaning:
        >>> datasets.where(name='foo')
        will return only datasets whose name is "foo".
        Giving more than one condition (predicate and keyword arguments) establishes an "and" relation.
        """
        def compare_enum(x, y):
            # Equal directly, or an Enum member whose value equals y.
            return x == y or isinstance(x, Enum) and x.value == y
        return TableResponse(
            self.service,
            self.entity,
            self.fields,
            [
                item
                for item in self
                if (not predicate or predicate(item))
                and all(
                    compare_enum(getattr(item, key), value)
                    for key, value in kwargs.items()
                )
            ],
        )

    def __getitem__(self, item):
        return self.response[item]

    def __iter__(self):
        return iter(self.response)

    def __len__(self):
        return len(self.response)
@six.add_metaclass(abc.ABCMeta)
class Entity(object):
    """
    Represent a server object.
    Enables calls like:
    >>> entity = client.service.get_by_id(entity_id)
    >>> entity.action(**kwargs)
    instead of:
    >>> client.service.action(id=entity_id, **kwargs)
    """

    @abc.abstractproperty
    def entity_name(self):  # type: () -> Text
        """
        Singular name of entity
        """
        pass

    @abc.abstractproperty
    def get_by_id_request(self):  # type: () -> Type[APIRequest]
        """
        get_by_id request class
        """
        pass

    def __init__(self, service, data):
        """
        :param service: owning Service instance
        :param data: entity data object, or a response carrying it under the
            entity's name
        """
        self._service = service
        self.data = getattr(data, self.entity_name, data)
        self.__doc__ = self.data.__doc__

    def fetch(self):
        """
        Update the entity data from the server.
        """
        result = self._service.session.send(self.get_by_id_request(self.data.id))
        self.data = getattr(result.response, self.entity_name)

    def _get_default_kwargs(self):
        # Keyword arguments injected into every delegated action call.
        return {self.entity_name: self.data.id}

    def __getattr__(self, attr):
        """
        Inject the entity's ID to the method call.
        All missing properties are assumed to be functions.
        """
        try:
            return getattr(self.data, attr)
        except AttributeError:
            pass
        func = getattr(self._service, attr)

        @wrap_request_class(func)
        def new_func(*args, **kwargs):
            # Explicit kwargs win over the injected defaults.
            kwargs = dict(self._get_default_kwargs(), **kwargs)
            return func(*args, **kwargs)
        return new_func

    def __dir__(self):
        """
        Add ``self._service``'s methods to ``dir()`` result.
        """
        try:
            dir_ = super(Entity, self).__dir__
        except AttributeError:
            # Python 2: object has no __dir__; fall back to the instance dict.
            base = self.__dict__
        else:
            base = dir_()
        return list(set(base).union(dir(self._service), dir(self.data)))

    def __repr__(self):
        """
        Display entity type, ID, and - if available - name.
        """
        parts = (type(self).__name__, ": ", "id={}".format(self.data.id))
        try:
            parts += (", ", 'name="{}"'.format(self.data.name))
        except AttributeError:
            pass
        return "<{}>".format("".join(parts))
def wrap_request_class(cls):
    """functools.wraps that also copies ``from_dict``, so wrapped action
    functions keep the request class' alternate-constructor attribute."""
    return wraps(cls, assigned=WRAPPER_ASSIGNMENTS + ("from_dict",))
def make_action(service, request_cls):
    """
    Build one service-class method for a single request class.

    Generic actions return a Response proxy; get_by_id/create return an Entity
    instance; get_all/get_all_ex return a TableResponse of entities.
    """
    action = request_cls._action
    try:
        get_by_id_request = service.GetByIdRequest
    except AttributeError:
        get_by_id_request = None
    wrap = wrap_request_class(request_cls)
    if action not in ["get_all", "get_all_ex", "get_by_id", "create"]:
        # Generic action: send the request, wrap the result in a Response proxy.
        @wrap
        def new_func(self, *args, **kwargs):
            return Response(self.session.send(request_cls(*args, **kwargs)))
        new_func.__name__ = new_func.__qualname__ = action
        return new_func
    # Entity-producing actions: create a dedicated Entity subclass per service.
    entity_name = api_entity_name(service)
    class_name = entity_class_name(service).capitalize()
    properties = {
        "__module__": __name__,
        "entity_name": entity_name.lower(),
        "get_by_id_request": get_by_id_request,
    }
    entity = type(str(class_name), (Entity,), properties)
    if action == "get_by_id":
        @wrap
        def get(self, *args, **kwargs):
            return entity(
                self, self.session.send(request_cls(*args, **kwargs)).response
            )
    elif action == "create":
        # The server returns only the new entity's ID.
        @wrap
        def get(self, *args, **kwargs):
            return entity(
                self,
                Namespace(
                    id=self.session.send(request_cls(*args, **kwargs)).response.id
                ),
            )
    elif action in ["get_all", "get_all_ex"]:
        # The response data lives in a single field - resolve its schema name.
        dest = service.response_mapping[request_cls]._get_data_props().popitem()[0]

        @wrap
        def get(self, *args, **kwargs):
            return TableResponse(
                service=self,
                entity=entity,
                result=self.session.send(request_cls(*args, **kwargs)),
                dest=dest,
                fields=kwargs.pop("only_fields", None),
            )
    else:
        assert False
    get.__name__ = get.__qualname__ = action
    return get
@six.add_metaclass(abc.ABCMeta)
class Service(object):
    """
    Superclass for action-grouping classes.
    """
    # Concrete service classes (built by make_service_class) provide both.
    name = abc.abstractproperty()
    __doc__ = abc.abstractproperty()

    def __init__(self, session):
        self.session = session
def get_requests(service):
    """Return an ordered {name: Request subclass} mapping of all actionable
    request classes defined in a service module, sorted by class name."""
    return OrderedDict(
        (key, value)
        for key, value in sorted(vars(service).items(), key=itemgetter(0))
        if isinstance(value, type) and issubclass(value, APIRequest) and value._action
    )
def make_service_class(module):
    # type: (...) -> Type[Service]
    """
    Create a service class from service module.
    """
    properties = OrderedDict(
        [
            ("__module__", __name__),
            ("__doc__", module.__doc__),
            ("name", module_name(module)),
        ]
    )
    # One action method per request class found in the module.
    properties.update(
        (f.__name__, f)
        for f in (
            make_action(module, value) for key, value in get_requests(module).items()
        )
    )
    # noinspection PyTypeChecker
    return type(str(module_name(module)), (Service,), properties)
def module_name(module):
    """Return the CamelCase name of *module* (a module object or dotted name
    string): the last dotted component with each underscore-separated word
    capitalized."""
    dotted = getattr(module, "__name__", module)
    last_component = dotted.split(".")[-1]
    words = last_component.split("_")
    return "".join(word.capitalize() for word in words)
class Version(Entity):
    """Version entity; its actions also require the owning dataset's ID."""
    entity_name = "version"
    get_by_id_request = None

    def fetch(self):
        # No get_by_id endpoint for versions - refresh through get_versions,
        # requesting published data only when this version is published.
        try:
            published = self.data.status == "published"
        except AttributeError:
            published = False
        self.data = self._service.get_versions(
            dataset=self.dataset, only_published=published, versions=[self.id]
        )[0].data

    def _get_default_kwargs(self):
        # Version calls need both the version ID and its dataset ID.
        return dict(
            super(Version, self)._get_default_kwargs(), **{"dataset": self.data.dataset}
        )
class APIClient(object):
    """
    High-level API client exposing one Service attribute per API service
    (e.g. ``client.tasks``), each built dynamically from its service module.
    """
    # Populated dynamically in __init__; declared here for IDEs/type checkers.
    auth = None     # type: Any
    debug = None    # type: Any
    queues = None   # type: Any
    tasks = None    # type: Any
    workers = None  # type: Any

    def __init__(self, session=None, api_version=None):
        """
        :param session: session to use (a default StrictSession is created if omitted)
        :param api_version: explicit API version (e.g. "2.1" selects the
            services.v2_1.* modules); default uses the top-level service modules
        """
        self.session = session or StrictSession()

        def import_(*args, **kwargs):
            # Best-effort import: missing per-version service modules are skipped.
            try:
                return import_module(*args, **kwargs)
            except ImportError:
                return None

        if api_version:
            api_version = "v{}".format(str(api_version).replace(".", "_"))
            services = OrderedDict(
                (name, mod)
                for name, mod in (
                    (
                        name,
                        import_(".".join((api_services.__name__, api_version, name))),
                    )
                    for name in api_services.__all__
                )
                if mod
            )
        else:
            services = OrderedDict(
                (name, getattr(api_services, name)) for name in api_services.__all__
            )
        self.__dict__.update(
            dict(
                {
                    name: make_service_class(module)(self.session)
                    for name, module in services.items()
                },
            )
        )

View File

@@ -0,0 +1,148 @@
import keyword
import enum
import json
import warnings
from datetime import datetime
import jsonschema
from enum import Enum
import six
def format_date(obj):
    """JSON serialization helper: stringify datetime values; any other type
    yields None (matching the implicit return of the original code path)."""
    return str(obj) if isinstance(obj, datetime) else None
class SchemaProperty(property):
    """Property that remembers the schema field name it maps to, so models can
    translate between python attribute names and schema keys."""

    def __init__(self, name=None, *args, **kwargs):
        super(SchemaProperty, self).__init__(*args, **kwargs)
        self.name = name

    def setter(self, fset):
        # Re-create (rather than mutate) so the schema name survives the
        # ``@prop.setter`` decorator chain.
        return type(self)(self.name, self.fget, fset, self.fdel, self.__doc__)
def schema_property(name):
    """Decorator factory: behaves like ``@property`` but records the schema
    field *name* on the resulting SchemaProperty."""
    def init(*args, **kwargs):
        return SchemaProperty(name, *args, **kwargs)
    return init
class DataModel(object):
    """ Data Model"""
    # JSON schema used by validate(); set by concrete model classes.
    _schema = None
    # Per-class cache of {python attribute name: schema field name}.
    _data_props_list = None

    @classmethod
    def _get_data_props(cls):
        """Return {attribute name: schema field name} for every property
        declared on this class and its bases (cached per class)."""
        props = cls._data_props_list
        if props is None:
            props = {}
            for c in cls.__mro__:
                props.update({k: getattr(v, 'name', k) for k, v in vars(c).items()
                              if isinstance(v, property)})
            cls._data_props_list = props
        return props.copy()

    @classmethod
    def _to_base_type(cls, value):
        # Recursively convert models/enums/lists into plain JSON-friendly types.
        if isinstance(value, DataModel):
            return value.to_dict()
        elif isinstance(value, enum.Enum):
            return value.value
        elif isinstance(value, list):
            return [cls._to_base_type(model) for model in value]
        return value

    def to_dict(self, only=None, except_=None):
        """Serialize to a dict keyed by schema field names, skipping None values.

        :param only: optional whitelist of schema field names
        :param except_: optional blacklist of schema field names
        """
        prop_values = {v: getattr(self, k) for k, v in self._get_data_props().items()}
        return {
            k: self._to_base_type(v)
            for k, v in prop_values.items()
            if v is not None and (not only or k in only) and (not except_ or k not in except_)
        }

    def validate(self, schema=None):
        """Validate the serialized model against its JSON schema.

        :raises jsonschema.ValidationError: on schema violation
        """
        # NOTE(review): the "types" argument was removed in jsonschema 4.x;
        # this code assumes an older jsonschema release - confirm pinning.
        jsonschema.validate(
            self.to_dict(),
            schema or self._schema,
            types=dict(array=(list, tuple), integer=six.integer_types),
        )

    def __repr__(self):
        return '<{}.{}: {}>'.format(
            self.__module__.split('.')[-1],
            type(self).__name__,
            json.dumps(
                self.to_dict(),
                indent=4,
                default=format_date,
            )
        )

    @staticmethod
    def assert_isinstance(value, field_name, expected, is_array=False):
        """Raise TypeError unless value (or, with is_array=True, each of its
        items) is an instance of the expected type(s)."""
        if not is_array:
            if not isinstance(value, expected):
                raise TypeError("Expected %s of type %s, got %s" % (field_name, expected, type(value).__name__))
            return
        if not all(isinstance(x, expected) for x in value):
            raise TypeError(
                "Expected %s of type list[%s], got %s" % (
                    field_name,
                    expected,
                    ", ".join(set(type(x).__name__ for x in value)),
                )
            )

    @staticmethod
    def normalize_key(prop_key):
        # Make a schema key usable as a python identifier: append "_" to
        # reserved words and replace dots with double underscores.
        if keyword.iskeyword(prop_key):
            prop_key += '_'
        return prop_key.replace('.', '__')

    @classmethod
    def from_dict(cls, dct, strict=False):
        """
        Create an instance from a dictionary while ignoring unnecessary keys

        :param strict: when True, raise ValueError on unknown keys instead of
            silently dropping them
        """
        allowed_keys = cls._get_data_props().values()
        invalid_keys = set(dct).difference(allowed_keys)
        if strict and invalid_keys:
            raise ValueError("Invalid keys %s" % tuple(invalid_keys))
        return cls(**{cls.normalize_key(key): value for key, value in dct.items() if key not in invalid_keys})
class UnusedKwargsWarning(UserWarning):
    """Warning category for keyword arguments a model did not consume."""
    pass
class NonStrictDataModelMixin(object):
    """
    NonStrictDataModelMixin

    :summary: supplies an __init__ method that warns about unused keywords
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are deliberately ignored; the warning below
        # was disabled so unknown server-side fields are tolerated silently.
        # unexpected = [key for key in kwargs if not key.startswith('_')]
        # if unexpected:
        #     message = '{}: unused keyword argument(s) {}' \
        #         .format(type(self).__name__, unexpected)
        #     warnings.warn(message, UnusedKwargsWarning)
        # ignore extra data warnings
        pass
class NonStrictDataModel(DataModel, NonStrictDataModelMixin):
    """DataModel that silently ignores unknown constructor keywords."""
    pass
class StringEnum(Enum):
    """Enum whose members stringify to their raw value (str(member) == member.value)."""

    def __str__(self):
        # Render the member as its underlying value rather than "Class.name".
        raw_value = self.value
        return raw_value

View File

@@ -0,0 +1,10 @@
from ...backend_config import EnvEntry
# Environment variable overrides for API connection settings; each EnvEntry
# lists its accepted variable name(s) (primary and alias are identical here).
ENV_HOST = EnvEntry("TRAINS_API_HOST", "TRAINS_API_HOST")
ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "TRAINS_WEB_HOST")
ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "TRAINS_FILES_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
# Verbose API logging (bool, default False).
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
# Verify the API server SSL certificate (bool, default True).
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", type=bool, default=True)

View File

@@ -0,0 +1,17 @@
class SessionError(Exception):
    """Base class for API-session related errors."""
    pass


class AsyncError(SessionError):
    """Asynchronous-call error; extra keyword arguments (e.g. call_id) become
    attributes on the exception for the caller to inspect."""

    def __init__(self, msg, *args, **kwargs):
        super(AsyncError, self).__init__(msg, *args)
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)


class TimeoutExpiredError(SessionError):
    """Raised when waiting for an asynchronous result exceeds the timeout."""
    pass


class ResultNotReadyError(SessionError):
    """Raised when an asynchronous result is requested before it is ready."""
    pass

View File

@@ -0,0 +1,76 @@
import abc
import jsonschema
import six
from .apimodel import ApiModel
from .datamodel import DataModel
class Request(ApiModel):
    """Base class for API requests; rejects any unconsumed keyword arguments."""
    # Default HTTP method for requests.
    _method = 'get'

    def __init__(self, **kwargs):
        # Concrete request classes consume their own keyword arguments before
        # delegating here, so anything left over is unsupported.
        if kwargs:
            raise ValueError('Unsupported keyword arguments: %s' % ', '.join(kwargs.keys()))
@six.add_metaclass(abc.ABCMeta)
class BatchRequest(Request):
    """
    Request composed of multiple homogeneous batched sub-requests.

    Sub-requests may be given as instances of ``_batched_request_cls`` or, when
    ``allow_raw_requests`` is True, as raw dicts that are sent as-is.
    """

    # Concrete subclasses must set the class of the batched sub-requests.
    _batched_request_cls = abc.abstractproperty()

    # All validation error types jsonschema may raise (used by validate()).
    _schema_errors = (jsonschema.SchemaError, jsonschema.ValidationError, jsonschema.FormatError,
                      jsonschema.RefResolutionError)

    def __init__(self, requests, validate_requests=False, allow_raw_requests=True, **kwargs):
        """
        :param requests: list/tuple of sub-request instances (or raw dicts)
        :param validate_requests: when True, validate() checks each sub-request
        :param allow_raw_requests: when True, raw dicts are kept unparsed
        """
        super(BatchRequest, self).__init__(**kwargs)
        self._validate_requests = validate_requests
        self._allow_raw_requests = allow_raw_requests
        self._property_requests = None
        self.requests = requests

    @property
    def requests(self):
        return self._property_requests

    @requests.setter
    def requests(self, value):
        assert issubclass(self._batched_request_cls, Request)
        assert isinstance(value, (list, tuple))
        if not self._allow_raw_requests:
            # Parse raw dicts into request instances and enforce the batch type.
            if any(isinstance(x, dict) for x in value):
                value = [self._batched_request_cls(**x) if isinstance(x, dict) else x for x in value]
            assert all(isinstance(x, self._batched_request_cls) for x in value)
        self._property_requests = value

    def validate(self):
        """Validate each sub-request; no-op for raw or non-validating batches.

        :raises Exception: wrapping the first jsonschema error, including the
            index of the offending batch item
        """
        if not self._validate_requests or self._allow_raw_requests:
            return
        for i, req in enumerate(self.requests):
            try:
                req.validate()
            except self._schema_errors as e:
                # Use the class-level error tuple instead of re-listing the
                # same jsonschema exception classes inline (was duplicated).
                raise Exception('Validation error in batch item #%d: %s' % (i, str(e)))

    def get_json(self):
        """Return the batch payload as a list of plain dicts."""
        return [r if isinstance(r, dict) else r.to_dict() for r in self.requests]
class CompoundRequest(Request):
    """Request whose payload is a single embedded DataModel item."""
    # Name of the attribute holding the embedded item (set by subclasses).
    _item_prop_name = 'item'

    def _get_item(self):
        """Return the embedded DataModel; raise ValueError when missing."""
        item = getattr(self, self._item_prop_name, None)
        if item is None:
            raise ValueError('Item property is empty or missing')
        assert isinstance(item, DataModel)
        return item

    def to_dict(self):
        return self._get_item().to_dict()

    def validate(self):
        # _schema is expected to be provided by the concrete subclass.
        return self._get_item().validate(self._schema)

View File

@@ -0,0 +1,57 @@
import requests
import six
import jsonmodels.models
import jsonmodels.fields
import jsonmodels.errors
from .apimodel import ApiModel
from .datamodel import NonStrictDataModelMixin
class FloatOrStringField(jsonmodels.fields.BaseField):
    """Field accepting either a float or a string value (e.g. API versions)."""
    types = (float, six.string_types,)
class Response(ApiModel, NonStrictDataModelMixin):
    """Base class for API responses; tolerates unknown constructor keywords."""
    pass
class _ResponseEndpoint(jsonmodels.models.Base):
    """Endpoint section of a response meta: name plus requested/actual versions."""
    name = jsonmodels.fields.StringField()
    requested_version = FloatOrStringField()
    actual_version = FloatOrStringField()
class ResponseMeta(jsonmodels.models.Base):
    """Meta section of an API response: call id, endpoint, result code/message."""

    @property
    def is_valid(self):
        """False when this meta was synthesized from a raw (non-JSON) HTTP response."""
        return self._is_valid

    @classmethod
    def from_raw_data(cls, status_code, text, endpoint=None):
        # Build a minimal, invalid meta from a raw HTTP status and body text.
        return cls(is_valid=False, result_code=status_code, result_subcode=0, result_msg=text,
                   endpoint=_ResponseEndpoint(name=(endpoint or 'unknown')))

    def __init__(self, is_valid=True, **kwargs):
        super(ResponseMeta, self).__init__(**kwargs)
        self._is_valid = is_valid

    # jsonmodels field declarations (class-level; placement after methods is
    # deliberate and has no effect on field collection).
    id = jsonmodels.fields.StringField(required=True)
    trx = jsonmodels.fields.StringField(required=True)
    endpoint = jsonmodels.fields.EmbeddedField([_ResponseEndpoint], required=True)
    result_code = jsonmodels.fields.IntField(required=True)
    result_subcode = jsonmodels.fields.IntField()
    result_msg = jsonmodels.fields.StringField(required=True)
    error_stack = jsonmodels.fields.StringField()

    def __str__(self):
        # e.g. "<200: tasks.get_all/v2.1>"; failures include subcode/message.
        if self.result_code == requests.codes.ok:
            return "<%d: %s/v%s>" % (self.result_code, self.endpoint.name, self.endpoint.actual_version)
        elif self._is_valid:
            return "<%d/%d: %s/v%s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name,
                                             self.endpoint.actual_version, self.result_msg)
        return "<%d/%d: %s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, self.result_msg)

View File

@@ -0,0 +1,551 @@
import json as json_lib
import sys
import types
from socket import gethostname
from six.moves.urllib.parse import urlparse, urlunparse
import jwt
import requests
import six
from pyhocon import ConfigTree
from requests.auth import HTTPBasicAuth
from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ...version import __version__
class LoginError(Exception):
    """Raised when obtaining an authentication token from the server fails."""
    pass
class MaxRequestSizeError(Exception):
    """Raised when a single batch item exceeds the configured max request size."""
    pass
class Session(TokenManager):
    """ TRAINS API Session class.

    Handles authentication (token refresh via TokenManager), request retries,
    batched json-lines requests and server address resolution.
    """

    _AUTHORIZATION_HEADER = "Authorization"
    _WORKER_HEADER = "X-Trains-Worker"
    _ASYNC_HEADER = "X-Trains-Async"
    _CLIENT_HEADER = "X-Trains-Agent"

    _async_status_code = 202
    _session_requests = 0
    _session_initial_timeout = (3.0, 10.)
    _session_timeout = (10.0, 300.)
    _write_session_data_size = 15000
    _write_session_timeout = (300.0, 300.)

    api_version = '2.1'
    default_host = "https://demoapi.trainsai.io"
    default_web = "https://demoapp.trainsai.io"
    default_files = "https://demofiles.trainsai.io"
    default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
    default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"

    # TODO: add requests.codes.gateway_timeout once we support async commits
    _retry_codes = [
        requests.codes.bad_gateway,
        requests.codes.service_unavailable,
        requests.codes.bandwidth_limit_exceeded,
        requests.codes.too_many_requests,
    ]

    @property
    def access_key(self):
        return self.__access_key

    @property
    def secret_key(self):
        return self.__secret_key

    @property
    def host(self):
        return self.__host

    @property
    def worker(self):
        return self.__worker

    def __init__(
        self,
        worker=None,
        api_key=None,
        secret_key=None,
        host=None,
        logger=None,
        verbose=None,
        initialize_logging=True,
        client=None,
        config=None,
        **kwargs
    ):
        """Create a session, resolve credentials/host, and obtain an initial token.

        :param worker: unique worker name (defaults to the machine's hostname)
        :param api_key: access key (overrides environment and configuration)
        :param secret_key: secret key (overrides environment and configuration)
        :param host: API server address (overrides environment and configuration)
        :param logger: optional logger instance (may be None)
        :param verbose: verbosity flag (defaults to the TRAINS_API_VERBOSE env var)
        :param initialize_logging: initialize logging from config when we loaded it ourselves
        :param client: client identification header value
        :param config: pre-loaded configuration object (loaded from disk if None)
        :raises ValueError: on missing credentials, host or max request size
        :raises LoginError: if obtaining the initial token fails
        """
        if config is not None:
            self.config = config
        else:
            self.config = load()
            if initialize_logging:
                self.config.initialize_logging()

        token_expiration_threshold_sec = self.config.get(
            "auth.token_expiration_threshold_sec", 60
        )
        super(Session, self).__init__(
            token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
        )

        self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
        self._logger = logger

        self.__access_key = api_key or ENV_ACCESS_KEY.get(
            default=(self.config.get("api.credentials.access_key", None) or self.default_key)
        )
        if not self.access_key:
            raise ValueError(
                "Missing access_key. Please set in configuration file or pass in session init."
            )

        self.__secret_key = secret_key or ENV_SECRET_KEY.get(
            default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
        )
        if not self.secret_key:
            raise ValueError(
                "Missing secret_key. Please set in configuration file or pass in session init."
            )

        host = host or self.get_api_server_host(config=self.config)
        if not host:
            raise ValueError("host is required in init or config")
        self.__host = host.strip("/")

        http_retries_config = self.config.get(
            "api.http.retries", ConfigTree()
        ).as_plain_ordered_dict()
        http_retries_config["status_forcelist"] = self._retry_codes
        self.__http_session = get_http_session_with_retry(**http_retries_config)

        self.__worker = worker or gethostname()

        self.__max_req_size = self.config.get("api.http.max_req_size", None)
        if not self.__max_req_size:
            raise ValueError("missing max request size")

        self.client = client or "api-{}".format(__version__)

        self.refresh_token()

        # update api version from server response (embedded in the token payload)
        try:
            token_dict = jwt.decode(self.token, verify=False)
            api_version = token_dict.get('api_version')
            if not api_version:
                api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
            Session.api_version = str(api_version)
        except (jwt.DecodeError, ValueError):
            pass

        # now setup the session reporting, so one consecutive retries will show warning
        # we do that here, so if we have problems authenticating, we see them immediately
        # notice: this is across the board warning omission
        urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3)

    def _send_request(
        self,
        service,
        action,
        version=None,
        method="get",
        headers=None,
        auth=None,
        data=None,
        json=None,
        refresh_token_if_unauthorized=True,
    ):
        """ Internal implementation for making a raw API request.
            - Constructs the api endpoint name
            - Injects the worker id into the headers
            - Allows custom authorization using a requests auth object
            - Intercepts `Unauthorized` responses and automatically attempts to refresh the session token once in this
              case (only once). This is done since permissions are embedded in the token, and addresses a case where
              server-side permissions have changed but are not reflected in the current token. Refreshing the token
              will generate a token with the updated permissions.
        """
        host = self.host
        headers = headers.copy() if headers else {}
        headers[self._WORKER_HEADER] = self.worker
        headers[self._CLIENT_HEADER] = self.client

        token_refreshed_on_error = False
        url = (
            "{host}/v{version}/{service}.{action}"
            if version
            else "{host}/{service}.{action}"
        ).format(**locals())

        while True:
            # use a longer timeout for large payloads and a short one for the first call
            if data and len(data) > self._write_session_data_size:
                timeout = self._write_session_timeout
            elif self._session_requests < 1:
                timeout = self._session_initial_timeout
            else:
                timeout = self._session_timeout
            res = self.__http_session.request(
                method, url, headers=headers, auth=auth, data=data, json=json, timeout=timeout)
            if (
                refresh_token_if_unauthorized
                and res.status_code == requests.codes.unauthorized
                and not token_refreshed_on_error
            ):
                # it seems we're unauthorized, so we'll try to refresh our token once in case permissions changed
                # since the last time we got the token, and try again
                self.refresh_token()
                token_refreshed_on_error = True
                # try again
                continue
            if (
                res.status_code == requests.codes.service_unavailable
                and self.config.get("api.http.wait_on_maintenance_forever", True)
            ):
                # fixed: self._logger may be None (constructor default) - guard before
                # logging, otherwise this retry loop crashed with AttributeError
                if self._logger:
                    self._logger.warning(
                        "Service unavailable: {} is undergoing maintenance, retrying...".format(
                            host
                        )
                    )
                continue
            break
        self._session_requests += 1
        return res

    def add_auth_headers(self, headers):
        """Inject the bearer-token authorization header and return the headers dict."""
        headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
        return headers

    def send_request(
        self,
        service,
        action,
        version=None,
        method="get",
        headers=None,
        data=None,
        json=None,
        async_enable=False,
    ):
        """
        Send a raw API request.
        :param service: service name
        :param action: action name
        :param version: version number (default is the preconfigured api version)
        :param method: method type (default is 'get')
        :param headers: request headers (authorization and content type headers will be automatically added)
        :param json: json to send in the request body (jsonable object or builtin types construct. if used,
                     content type will be application/json)
        :param data: Dictionary, bytes, or file-like object to send in the request body
        :param async_enable: whether request is asynchronous
        :return: requests Response instance
        """
        headers = self.add_auth_headers(
            headers.copy() if headers else {}
        )
        if async_enable:
            headers[self._ASYNC_HEADER] = "1"
        return self._send_request(
            service=service,
            action=action,
            version=version,
            method=method,
            headers=headers,
            data=data,
            json=json,
        )

    def send_request_batch(
        self,
        service,
        action,
        version=None,
        headers=None,
        data=None,
        json=None,
        method="get",
    ):
        """
        Send a raw batch API request. Batch requests always use application/json-lines content type.
        :param service: service name
        :param action: action name
        :param version: version number (default is the preconfigured api version)
        :param headers: request headers (authorization and content type headers will be automatically added)
        :param json: iterable of json items (batched items, jsonable objects or builtin types constructs). These will
                     be sent as a multi-line payload in the request body.
        :param data: iterable of bytes objects (batched items). These will be sent as a multi-line payload in the
                     request body.
        :param method: HTTP method
        :return: list of requests Response instances (one per sent chunk)
        :raises ValueError: if neither data nor json is provided, or they are of the wrong type
        :raises MaxRequestSizeError: if a single batch item exceeds the max request size
        """
        if not all(
            isinstance(x, (list, tuple, type(None), types.GeneratorType))
            for x in (data, json)
        ):
            raise ValueError("Expecting list, tuple or generator in 'data' or 'json'")

        if not data and not json:
            raise ValueError(
                "Missing data (data or json), batch requests are meaningless without it."
            )

        headers = headers.copy() if headers else {}
        headers["Content-Type"] = "application/json-lines"
        if data:
            req_data = "\n".join(data)
        else:
            req_data = "\n".join(json_lib.dumps(x) for x in json)

        cur = 0
        results = []
        while True:
            size = self.__max_req_size
            # renamed from `slice` to avoid shadowing the builtin
            chunk = req_data[cur: cur + size]
            if not chunk:
                break
            if len(chunk) < size:
                # this is the remainder, no need to search for newline
                pass
            elif chunk[-1] != "\n":
                # search for the last newline in order to send a coherent request
                size = chunk.rfind("\n") + 1
                # readjust the chunk
                chunk = req_data[cur: cur + size]
                if not chunk:
                    raise MaxRequestSizeError('Error: {}.{} request exceeds limit {} > {} bytes'.format(
                        service, action, len(req_data), self.__max_req_size))
            res = self.send_request(
                method=method,
                service=service,
                action=action,
                data=chunk,
                headers=headers,
                version=version,
            )
            results.append(res)
            if res.status_code != requests.codes.ok:
                break
            cur += size
        return results

    def validate_request(self, req_obj):
        """ Validate an API request against the current version and the request's schema """
        try:
            # make sure we're using a compatible version for this request
            # validate the request (checks required fields and specific field version restrictions)
            validate = req_obj.validate
        except AttributeError:
            raise TypeError(
                '"req_obj" parameter must be an backend_api.session.Request object'
            )
        validate()

    def send_async(self, req_obj):
        """
        Asynchronously sends an API request using a request object.
        :param req_obj: The request object
        :type req_obj: Request
        :return: CallResult object containing the raw response, response metadata and parsed response object.
        """
        return self.send(req_obj=req_obj, async_enable=True)

    def send(self, req_obj, async_enable=False, headers=None):
        """
        Sends an API request using a request object.
        :param req_obj: The request object
        :type req_obj: Request
        :param async_enable: Request this method be executed in an asynchronous manner
        :param headers: Additional headers to send with request
        :return: CallResult object containing the raw response, response metadata and parsed response object.
        """
        self.validate_request(req_obj)

        if isinstance(req_obj, BatchRequest):
            # TODO: support async for batch requests as well
            if async_enable:
                raise NotImplementedError(
                    "Async behavior is currently not implemented for batch requests"
                )

            json_data = req_obj.get_json()
            res = self.send_request_batch(
                service=req_obj._service,
                action=req_obj._action,
                version=req_obj._version,
                json=json_data,
                method=req_obj._method,
                headers=headers,
            )
            # TODO: handle multiple results in this case
            try:
                res = next(r for r in res if r.status_code != 200)
            except StopIteration:
                # all are 200
                res = res[0]
        else:
            res = self.send_request(
                service=req_obj._service,
                action=req_obj._action,
                version=req_obj._version,
                json=req_obj.to_dict(),
                method=req_obj._method,
                async_enable=async_enable,
                headers=headers,
            )

        call_result = CallResult.from_result(
            res=res,
            request_cls=req_obj.__class__,
            logger=self._logger,
            service=req_obj._service,
            action=req_obj._action,
            session=self,
        )
        return call_result

    @classmethod
    def get_api_server_host(cls, config=None):
        """Resolve the API server address from environment, configuration or default."""
        if not config:
            from ...config import config_obj
            config = config_obj
        return ENV_HOST.get(default=(config.get("api.api_server", None) or
                                     config.get("api.host", None) or cls.default_host))

    @classmethod
    def get_app_server_host(cls, config=None):
        """Resolve the web application address, deriving it from the API host if needed."""
        if not config:
            from ...config import config_obj
            config = config_obj

        # get from config/environment
        web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
        if web_host:
            return web_host

        # return default
        host = cls.get_api_server_host(config)
        if host == cls.default_host:
            return cls.default_web

        # compose ourselves
        if '://demoapi.' in host:
            return host.replace('://demoapi.', '://demoapp.', 1)
        if '://api.' in host:
            return host.replace('://api.', '://app.', 1)

        parsed = urlparse(host)
        if parsed.port == 8008:
            return host.replace(':8008', ':8080', 1)

        raise ValueError('Could not detect TRAINS web application server')

    @classmethod
    def get_files_server_host(cls, config=None):
        """Resolve the files server address, deriving it from the web host if needed."""
        if not config:
            from ...config import config_obj
            config = config_obj

        # get from config/environment
        files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
        if files_host:
            return files_host

        # return default
        host = cls.get_api_server_host(config)
        if host == cls.default_host:
            return cls.default_files

        # compose ourselves
        app_host = cls.get_app_server_host(config)
        parsed = urlparse(app_host)
        if parsed.port:
            parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
        elif parsed.netloc.startswith('demoapp.'):
            parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
        elif parsed.netloc.startswith('app.'):
            parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
        else:
            parsed = parsed._replace(netloc=parsed.netloc + ':8081')
        return urlunparse(parsed)

    @classmethod
    def check_min_api_version(cls, min_api_version):
        """
        Return True if Session.api_version is greater than or equal to min_api_version
        """
        def version_tuple(v):
            # pad to three components so "2.1" compares correctly with "2.1.0"
            v = tuple(map(int, (v.split("."))))
            return v + (0,) * max(0, 3 - len(v))
        return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))

    def _do_refresh_token(self, old_token, exp=None):
        """ TokenManager abstract method implementation.
            Here we ignore the old token and simply obtain a new token.
        """
        verbose = self._verbose and self._logger
        if verbose:
            self._logger.info(
                "Refreshing token from {} (access_key={}, exp={})".format(
                    self.host, self.access_key, exp
                )
            )

        auth = HTTPBasicAuth(self.access_key, self.secret_key)
        res = None
        try:
            data = {"expiration_sec": exp} if exp else {}
            res = self._send_request(
                service="auth",
                action="login",
                auth=auth,
                json=data,
                refresh_token_if_unauthorized=False,
            )
            try:
                resp = res.json()
            except ValueError:
                resp = {}
            if res.status_code != 200:
                msg = resp.get("meta", {}).get("result_msg", res.reason)
                raise LoginError(
                    "Failed getting token (error {} from {}): {}".format(
                        res.status_code, self.host, msg
                    )
                )
            if verbose:
                self._logger.info("Received new token")
            return resp["data"]["token"]
        except LoginError:
            six.reraise(*sys.exc_info())
        except KeyError as ex:
            # check if this is a misconfigured api server (getting 200 without the data section)
            if res and res.status_code == 200:
                raise ValueError('It seems *api_server* is misconfigured. '
                                 'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
            else:
                raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
                                 "exception: {}".format(res, ex))
        except Exception as ex:
            raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))

    def __str__(self):
        # mask all but the first 5 characters of the secret key
        return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
            self=self, secret_key=self.secret_key[:5] + "*" * (len(self.secret_key) - 5)
        )

View File

@@ -0,0 +1,95 @@
import sys
from abc import ABCMeta, abstractmethod
from time import time
import jwt
import six
@six.add_metaclass(ABCMeta)
class TokenManager(object):
    """Caches an authentication token and refreshes it shortly before expiry."""

    @property
    def token_expiration_threshold_sec(self):
        return self.__token_expiration_threshold_sec

    @token_expiration_threshold_sec.setter
    def token_expiration_threshold_sec(self, value):
        self.__token_expiration_threshold_sec = value

    @property
    def req_token_expiration_sec(self):
        """ Token expiration sec requested when refreshing token """
        return self.__req_token_expiration_sec

    @req_token_expiration_sec.setter
    def req_token_expiration_sec(self, value):
        assert isinstance(value, (type(None), int))
        self.__req_token_expiration_sec = value

    @property
    def token_expiration_sec(self):
        return self.__token_expiration_sec

    @property
    def token(self):
        # resolving the token may trigger a refresh
        return self._get_token()

    @property
    def raw_token(self):
        # cached token, returned without triggering a refresh
        return self.__token

    def __init__(
        self,
        token=None,
        req_token_expiration_sec=None,
        token_history=None,
        token_expiration_threshold_sec=60,
        **kwargs
    ):
        super(TokenManager, self).__init__()
        assert isinstance(token_history, (type(None), dict))
        self.token_expiration_threshold_sec = token_expiration_threshold_sec
        self.req_token_expiration_sec = req_token_expiration_sec
        self._set_token(token)

    def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
        """Return the seconds the token remains usable; 0 for missing/expired/bad tokens."""
        if not token:
            return 0
        try:
            expiration = exp or self._get_token_exp(token)
            margin = self.token_expiration_threshold_sec
            if at_least_sec:
                margin = max(at_least_sec, margin)
            return max(0, expiration - time() - margin)
        except Exception:
            return 0

    @classmethod
    def _get_token_exp(cls, token):
        """ Get token expiration time. If not present, assume forever """
        return jwt.decode(token, verify=False).get('exp', sys.maxsize)

    def _set_token(self, token):
        """Store the token and cache its expiration time (cleared for falsy tokens)."""
        if not token:
            self.__token = None
            self.__token_expiration_sec = 0
        else:
            self.__token = token
            self.__token_expiration_sec = self._get_token_exp(token)

    def get_token_valid_period_sec(self):
        return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec)

    def _get_token(self):
        # refresh once the remaining validity window is exhausted
        if self.get_token_valid_period_sec() <= 0:
            self.refresh_token()
        return self.__token

    @abstractmethod
    def _do_refresh_token(self, old_token, exp=None):
        pass

    def refresh_token(self):
        self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec))

View File

@@ -0,0 +1,132 @@
import logging
import ssl
import sys
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from urllib3 import PoolManager
import six
from .session.defs import ENV_HOST_VERIFY_CERT
if six.PY3:
from functools import lru_cache
elif six.PY2:
# python 2 support
from backports.functools_lru_cache import lru_cache
# module-level counter limiting how many times the "certificate verification
# disabled" warning is emitted (see get_http_session_with_retry)
__disable_certificate_verification_warning = 0
class _RetryFilter(logging.Filter):
last_instance = None
def __init__(self, total, warning_after=5):
super(_RetryFilter, self).__init__()
self.total = total
self.display_warning_after = warning_after
_RetryFilter.last_instance = self
def filter(self, record):
if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry):
left = (record.args[0].total, record.args[0].connect, record.args[0].read,
record.args[0].redirect, record.args[0].status)
left = [l for l in left if isinstance(l, int)]
if left:
retry_left = max(left) - min(left)
return retry_left >= self.display_warning_after
return True
def urllib_log_warning_setup(total_retries=10, display_warning_after=5):
    """Install a fresh _RetryFilter on urllib3's connection-pool loggers."""
    logger_names = ('urllib3.connectionpool', 'requests.packages.urllib3.connectionpool')
    for name in logger_names:
        pool_logger = logging.getLogger(name)
        if pool_logger:
            # drop the previously installed filter (no-op if it was never added)
            pool_logger.removeFilter(_RetryFilter.last_instance)
            pool_logger.addFilter(_RetryFilter(total_retries, display_warning_after))
class TLSv1HTTPAdapter(HTTPAdapter):
    """HTTP adapter that pins pooled connections to TLS v1.2."""

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        # NOTE(review): pool_kwargs is accepted but not forwarded to PoolManager --
        # confirm this is intentional
        self.poolmanager = PoolManager(num_pools=connections,
                                       maxsize=maxsize,
                                       block=block,
                                       ssl_version=ssl.PROTOCOL_TLSv1_2)
def get_http_session_with_retry(
        total=0,
        connect=None,
        read=None,
        redirect=None,
        status=None,
        status_forcelist=None,
        backoff_factor=0,
        backoff_max=None,
        pool_connections=None,
        pool_maxsize=None,
        config=None):
    """Create a requests.Session with a TLS1.2 adapter, retry policy and pool sizing.

    :param total: total retry count (other counters: connect/read/redirect/status)
    :param status_forcelist: list of HTTP status codes that force a retry
    :param backoff_factor: urllib3 Retry backoff factor
    :param backoff_max: if given, overrides the class-level Retry.BACKOFF_MAX
    :param pool_connections: number of connection pools (default from config, 512)
    :param pool_maxsize: max connections per pool (default from config, 512)
    :param config: configuration object used for defaults and certificate policy
    :raises ValueError: on non-int retry counts or status codes
    :return: configured requests.Session
    """
    if not config:
        config = {}
    global __disable_certificate_verification_warning
    if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
        raise ValueError('Bad configuration. All retry count values must be null or int')

    if status_forcelist and not all(isinstance(x, int) for x in status_forcelist):
        raise ValueError('Bad configuration. Retry status_forcelist must be null or list of ints')

    pool_maxsize = (
        pool_maxsize
        if pool_maxsize is not None
        else config.get('api.http.pool_maxsize', 512)
    )

    pool_connections = (
        pool_connections
        if pool_connections is not None
        else config.get('api.http.pool_connections', 512)
    )

    session = requests.Session()

    # NOTE(review): this mutates the Retry *class* attribute, affecting all sessions
    if backoff_max is not None:
        Retry.BACKOFF_MAX = backoff_max

    retry = Retry(
        total=total, connect=connect, read=read, redirect=redirect, status=status,
        status_forcelist=status_forcelist, backoff_factor=backoff_factor)

    adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
    session.mount('http://', adapter)
    session.mount('https://', adapter)

    # update verify host certificate
    session.verify = ENV_HOST_VERIFY_CERT.get(default=config.get('api.verify_certificate', True))
    if not session.verify and __disable_certificate_verification_warning < 2:
        # show warning (emitted at most twice per process, see module-level counter)
        __disable_certificate_verification_warning += 1
        logging.getLogger('TRAINS').warning(
            msg='InsecureRequestWarning: Certificate verification is disabled! Adding '
                'certificate verification is strongly advised. See: '
                'https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings')
        # also silence urllib3's own InsecureRequestWarning so it is not
        # printed on every single request
        import urllib3
        # noinspection PyBroadException
        try:
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        except Exception:
            pass
    return session
def get_response_cls(request_cls):
    """ Extract a request's response class using the mapping found in the module defining the request's service """
    for ancestor in request_cls.mro():
        defining_module = sys.modules[ancestor.__module__]
        if hasattr(defining_module, 'action_mapping'):
            return defining_module.action_mapping[(request_cls._action, request_cls._version)][1]
        if hasattr(defining_module, 'response_mapping'):
            return defining_module.response_mapping[ancestor]
    raise TypeError('no response class!')

View File

@@ -0,0 +1,4 @@
from .defs import Environment
from .config import Config, ConfigEntry
from .errors import ConfigurationError
from .environment import EnvEntry

View File

@@ -0,0 +1,340 @@
from __future__ import print_function
import functools
import json
import os
import sys
import warnings
from fnmatch import fnmatch
from os.path import expanduser
from typing import Any
import pyhocon
import six
from pathlib2 import Path
from pyhocon import ConfigTree
from pyparsing import (
ParseFatalException,
ParseException,
RecursiveGrammarException,
ParseSyntaxException,
)
from .defs import (
Environment,
DEFAULT_CONFIG_FOLDER,
LOCAL_CONFIG_PATHS,
ENV_CONFIG_PATHS,
LOCAL_CONFIG_FILES,
LOCAL_CONFIG_FILE_OVERRIDE_VAR,
ENV_CONFIG_PATH_OVERRIDE_VAR,
)
from .defs import is_config_file
from .entry import Entry, NotSet
from .errors import ConfigurationError
from .log import initialize as initialize_log, logger
from .utils import get_options
try:
from typing import Text
except ImportError:
# windows conda-less hack
Text = Any
# module-level logger for configuration loading messages
log = logger(__file__)
class ConfigEntry(Entry):
    """Entry subclass that resolves its keys against a Config instance."""
    logger = None

    def __init__(self, config, *keys, **kwargs):
        # type: (Config, Text, Any) -> None
        super(ConfigEntry, self).__init__(*keys, **kwargs)
        self.config = config

    def _get(self, key):
        # type: (Text) -> Any
        # NotSet (rather than None) distinguishes "missing" from a stored None value
        return self.config.get(key, NotSet)

    def error(self, message):
        # type: (Text) -> None
        log.error(message.capitalize())
class Config(object):
    """
    Represents a server configuration.
    If watch=True, will watch configuration folders for changes and reload itself.
    NOTE: will not watch folders that were created after initialization.
    """

    # used in place of None in Config.get as default value because None is a valid value
    _MISSING = object()

    def __init__(
        self,
        config_folder=None,
        env=None,
        verbose=True,
        relative_to=None,
        app=None,
        is_server=False,
        **_
    ):
        """Initialize an (initially empty) configuration.

        :param config_folder: config folder name searched relative to given roots
        :param env: environment name (defaults to the TRAINS_ENV env var or 'default')
        :param verbose: print configuration loading progress
        :param relative_to: module path(s) to load configuration relative to
        :param app: application name injected into logging records
        :param is_server: server mode enables environment-related config paths
        :raises ValueError: on a missing or unsupported environment name
        """
        self._app = app
        self._verbose = verbose
        self._folder_name = config_folder or DEFAULT_CONFIG_FOLDER
        self._roots = []
        self._config = ConfigTree()
        self._env = env or os.environ.get("TRAINS_ENV", Environment.default)
        self.config_paths = set()
        self.is_server = is_server

        if self._verbose:
            print("Config env:%s" % str(self._env))

        if not self._env:
            raise ValueError(
                # fixed typo in message: "init of" -> "init or"
                "Missing environment in either init or environment variable"
            )
        if self._env not in get_options(Environment):
            raise ValueError("Invalid environment %s" % env)
        if relative_to is not None:
            self.load_relative_to(relative_to)

    @property
    def root(self):
        return self.roots[0] if self.roots else None

    @property
    def roots(self):
        return self._roots

    @roots.setter
    def roots(self, value):
        self._roots = value

    @property
    def env(self):
        return self._env

    def logger(self, path=None):
        """Return a logger for the given path (module-level logger factory)."""
        return logger(path)

    def load_relative_to(self, *module_paths):
        """Set the config roots next to the given module paths and reload."""
        def normalize(p):
            return Path(os.path.abspath(str(p))).with_name(self._folder_name)
        self.roots = list(map(normalize, module_paths))
        self.reload()

    def _reload(self):
        """Merge configuration from roots, env paths, local paths and local files."""
        env = self._env
        config = self._config.copy()

        if self.is_server:
            env_config_paths = ENV_CONFIG_PATHS
        else:
            env_config_paths = []

        env_config_path_override = os.environ.get(ENV_CONFIG_PATH_OVERRIDE_VAR)
        if env_config_path_override:
            env_config_paths = [expanduser(env_config_path_override)]

        # merge configuration from root and other environment config paths
        if self.roots or env_config_paths:
            config = functools.reduce(
                lambda cfg, path: ConfigTree.merge_configs(
                    cfg,
                    self._read_recursive_for_env(path, env, verbose=self._verbose),
                    copy_trees=True,
                ),
                self.roots + env_config_paths,
                config,
            )

        # merge configuration from local configuration paths
        if LOCAL_CONFIG_PATHS:
            config = functools.reduce(
                lambda cfg, path: ConfigTree.merge_configs(
                    cfg, self._read_recursive(path, verbose=self._verbose), copy_trees=True
                ),
                LOCAL_CONFIG_PATHS,
                config,
            )

        local_config_files = LOCAL_CONFIG_FILES
        local_config_override = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR)
        if local_config_override:
            local_config_files = [expanduser(local_config_override)]

        # merge configuration from local configuration files
        if local_config_files:
            config = functools.reduce(
                lambda cfg, file_path: ConfigTree.merge_configs(
                    cfg,
                    self._read_single_file(file_path, verbose=self._verbose),
                    copy_trees=True,
                ),
                local_config_files,
                config,
            )

        config["env"] = env
        return config

    def replace(self, config):
        self._config = config

    def reload(self):
        self.replace(self._reload())

    def initialize_logging(self):
        """Initialize the logging subsystem from the 'logging' config section.

        Removes incomplete file handlers (and references to them), creating
        log files/directories as needed. Returns True if logging was configured.
        """
        logging_config = self._config.get("logging", None)
        if not logging_config:
            return False

        # handle incomplete file handlers
        deleted = []
        handlers = logging_config.get("handlers", {})
        for name, handler in list(handlers.items()):
            cls = handler.get("class", None)
            is_file = cls and "FileHandler" in cls
            if cls is None or (is_file and "filename" not in handler):
                deleted.append(name)
                del handlers[name]
            elif is_file:
                # renamed from `file` to avoid shadowing the py2 builtin
                log_file = Path(handler.get("filename"))
                if not log_file.is_file():
                    log_file.parent.mkdir(parents=True, exist_ok=True)
                    log_file.touch()

        # remove dependency in deleted handlers
        root_logger = logging_config.get("root", None)
        loggers = list(logging_config.get("loggers", {}).values()) + (
            [root_logger] if root_logger else []
        )
        # renamed loop var from `logger` to avoid shadowing the imported logger factory
        for logger_config in loggers:
            handlers = logger_config.get("handlers", None)
            if not handlers:
                continue
            logger_config["handlers"] = [h for h in handlers if h not in deleted]

        extra = None
        if self._app:
            extra = {"app": self._app}
        initialize_log(logging_config, extra=extra)
        return True

    def __getitem__(self, key):
        try:
            return self._config[key]
        except Exception:
            # narrowed from a bare except (which also swallowed KeyboardInterrupt)
            return None

    def __getattr__(self, key):
        c = self.__getattribute__('_config')
        if key.split('.')[0] in c:
            try:
                return c[key]
            except Exception:
                return None
        return getattr(c, key)

    def get(self, key, default=_MISSING):
        """Return the value for *key*; raise KeyError if missing and no default was given."""
        value = self._config.get(key, default)
        if value is self._MISSING:
            # fixed: the original `value is self._MISSING and not default` could never
            # be true (default is the truthy _MISSING sentinel when omitted), so the
            # sentinel object leaked to callers instead of raising as documented
            raise KeyError(
                "Unable to find value for key '{}' and default value was not provided.".format(
                    key
                )
            )
        return value

    def to_dict(self):
        return self._config.as_plain_ordered_dict()

    def as_json(self):
        return json.dumps(self.to_dict(), indent=2)

    def _read_recursive_for_env(self, root_path_str, env, verbose=True):
        """Read the default environment config, overlaid by the env-specific one."""
        root_path = Path(root_path_str)
        if root_path.exists():
            default_config = self._read_recursive(
                root_path / Environment.default, verbose=verbose
            )
            if (root_path / env) != (root_path / Environment.default):
                env_config = self._read_recursive(
                    root_path / env, verbose=verbose
                )  # None is ok, will return empty config
                config = ConfigTree.merge_configs(default_config, env_config, True)
            else:
                config = default_config
        else:
            config = ConfigTree()

        return config

    def _read_recursive(self, conf_root, verbose=True):
        """Read all config files under *conf_root*, keyed by relative path and stem."""
        conf = ConfigTree()
        if not conf_root:
            return conf
        conf_root = Path(conf_root)

        if not conf_root.exists():
            if verbose:
                print("No config in %s" % str(conf_root))
            return conf

        if verbose:
            print("Loading config from %s" % str(conf_root))
        for root, dirs, files in os.walk(str(conf_root)):
            rel_dir = str(Path(root).relative_to(conf_root))
            if rel_dir == ".":
                rel_dir = ""
            prefix = rel_dir.replace("/", ".")

            for filename in files:
                if not is_config_file(filename):
                    continue
                if prefix != "":
                    key = prefix + "." + Path(filename).stem
                else:
                    key = Path(filename).stem
                file_path = str(Path(root) / filename)
                conf.put(key, self._read_single_file(file_path, verbose=verbose))

        return conf

    @staticmethod
    def _read_single_file(file_path, verbose=True):
        """Parse a single HOCON file, re-raising parse errors as ConfigurationError."""
        if not file_path or not Path(file_path).is_file():
            return ConfigTree()

        if verbose:
            print("Loading config from file %s" % file_path)

        try:
            return pyhocon.ConfigFactory.parse_file(file_path)
        except ParseSyntaxException as ex:
            msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
                file_path, ex
            )
            six.reraise(
                ConfigurationError,
                ConfigurationError(msg, file_path=file_path),
                sys.exc_info()[2],
            )
        except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
            msg = "Failed parsing {0} ({1.__class__.__name__}): {1}".format(
                file_path, ex
            )
            six.reraise(ConfigurationError, ConfigurationError(msg), sys.exc_info()[2])
        except Exception as ex:
            print("Failed loading %s: %s" % (file_path, ex))
            raise

View File

@@ -0,0 +1,53 @@
import base64
from distutils.util import strtobool
from typing import Union, Optional, Any, TypeVar, Callable, Tuple
import six
try:
from typing import Text
except ImportError:
# windows conda-less hack
Text = Any
# type of a converter callable: accepts any value and returns the converted value
ConverterType = TypeVar("ConverterType", bound=Callable[[Any], Any])
def base64_to_text(value):
    # type: (Any) -> Text
    """Decode a base64-encoded value into a UTF-8 text string."""
    decoded_bytes = base64.b64decode(value)
    return decoded_bytes.decode("utf-8")
def text_to_bool(value):
    # type: (Text) -> bool
    """Convert a textual truth value to bool.

    Accepts (case-insensitive) 'y', 'yes', 't', 'true', 'on', '1' as True and
    'n', 'no', 'f', 'false', 'off', '0' as False, mirroring the semantics of
    distutils.util.strtobool, which was deprecated and removed in Python 3.12
    (PEP 632) -- hence the inline implementation.

    :raises ValueError: if the value is not a recognized truth string
    """
    val = value.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value %r" % (val,))
def any_to_bool(value):
    # type: (Optional[Union[int, float, Text]]) -> bool
    """Coerce any scalar to bool; text values are parsed as truth strings."""
    return text_to_bool(value) if isinstance(value, six.text_type) else bool(value)
def or_(*converters, **kwargs):
    # type: (ConverterType, Tuple[Exception, ...]) -> ConverterType
    """
    Wrapper implementing an "optional converter" pattern: each converter is
    tried in turn, ignoring a configurable set of exceptions; when every
    converter fails, the original value is returned unchanged.
    :param converters: converter callables, tried in order
    :param exceptions: tuple of exception types to ignore (default: ValueError, TypeError)
    """
    # noinspection PyUnresolvedReferences
    ignored_exceptions = kwargs.get("exceptions", (ValueError, TypeError))

    def wrapper(value):
        for convert in converters:
            try:
                return convert(value)
            except ignored_exceptions:
                continue
        return value

    return wrapper

View File

@@ -0,0 +1,53 @@
from os.path import expanduser
from pathlib2 import Path
# Configuration discovery constants: environment variable names and the
# default search locations for trains configuration files.
ENV_VAR = 'TRAINS_ENV'
""" Name of system environment variable that can be used to specify the config environment name """
DEFAULT_CONFIG_FOLDER = 'config'
""" Default config folder to search for when loading relative to a given path """
ENV_CONFIG_PATHS = [
]
""" Environment-related config paths """
LOCAL_CONFIG_PATHS = [
    # '/etc/opt/trains',  # used by servers for docker-generated configuration
    # expanduser('~/.trains/config'),
]
""" Local config paths, not related to environment """
LOCAL_CONFIG_FILES = [
    expanduser('~/trains.conf'),  # used for workstation configuration (end-users, workers)
]
""" Local config files (not paths) """
LOCAL_CONFIG_FILE_OVERRIDE_VAR = 'TRAINS_CONFIG_FILE'
""" Local config file override environment variable. If this is set, no other local config files will be used. """
ENV_CONFIG_PATH_OVERRIDE_VAR = 'TRAINS_CONFIG_PATH'
"""
Environment-related config path override environment variable. If this is set, no other env config path will be used.
"""
class Environment(object):
    """ Supported environment names """
    # the names below are the values looked up via the TRAINS_ENV variable
    default = 'default'
    demo = 'demo'
    local = 'local'
# file extension identifying trains configuration files
CONFIG_FILE_EXTENSION = '.conf'


def is_config_file(path):
    """Return True when *path* carries the configuration file extension ('.conf')."""
    suffix = Path(path).suffix
    return suffix == CONFIG_FILE_EXTENSION

View File

@@ -0,0 +1,103 @@
import abc
from typing import Optional, Any, Tuple, Callable, Dict
import six
from .converters import any_to_bool
try:
from typing import Text
except ImportError:
# windows conda-less hack
Text = Any
# Sentinel meaning "no value found / no default supplied" (None is a legal value,
# so it cannot be used for this purpose)
NotSet = object()

# A converter accepts a raw value and returns the converted/validated value
Converter = Callable[[Any], Any]


@six.add_metaclass(abc.ABCMeta)
class Entry(object):
    """
    Configuration entry definition

    Maps one-or-more lookup keys to a value, with optional type-based or explicit
    conversion and a default. Subclasses implement the actual storage via
    ``_get``/``_set`` (e.g. environment variables) and ``error`` reporting.
    """

    @classmethod
    def default_conversions(cls):
        # type: () -> Dict[Any, Converter]
        # built-in converters keyed by target type, used when no explicit
        # converter was supplied to the entry
        return {
            bool: any_to_bool,
            six.text_type: lambda s: six.text_type(s).strip(),
        }

    def __init__(self, key, *more_keys, **kwargs):
        # type: (Text, Text, Any) -> None
        """
        :param key: Entry's key (at least one).
        :param more_keys: More alternate keys for this entry.
        :param type: Value type. If provided, will be used choosing a default conversion or
        (if none exists) for casting the environment value.
        :param converter: Value converter. If provided, will be used to convert the environment value.
        :param default: Default value. If provided, will be used as the default value on calls to get() and get_pair()
        in case no value is found for any key and no specific default value was provided in the call.
        Default value is None.
        :param help: Help text describing this entry
        """
        self.keys = (key,) + more_keys
        self.type = kwargs.pop("type", six.text_type)
        self.converter = kwargs.pop("converter", None)
        self.default = kwargs.pop("default", None)
        self.help = kwargs.pop("help", None)

    def __str__(self):
        return str(self.key)

    @property
    def key(self):
        # the primary (first) lookup key
        return self.keys[0]

    def convert(self, value, converter=None):
        # type: (Any, Converter) -> Optional[Any]
        """Convert *value* using (in priority order) the call-site converter, the
        entry's converter, a default conversion for the entry type, or the type itself."""
        converter = converter or self.converter
        if not converter:
            converter = self.default_conversions().get(self.type, self.type)
        return converter(value)

    def get_pair(self, default=NotSet, converter=None):
        # type: (Any, Converter) -> Optional[Tuple[Text, Any]]
        """Return ``(key, value)`` for the first key that yields a value.

        Falls back to ``(primary key, default)`` when no key has a value or when
        conversion of the found value fails (after reporting via ``error``).
        """
        for key in self.keys:
            value = self._get(key)
            if value is NotSet:
                continue
            try:
                value = self.convert(value, converter)
            except Exception as ex:
                # conversion failed: report and stop trying further keys
                self.error("invalid value {key}={value}: {ex}".format(**locals()))
                break
            return key, value
        # nothing usable found: prefer the call-site default over the entry default
        result = self.default if default is NotSet else default
        return self.key, result

    def get(self, default=NotSet, converter=None):
        # type: (Any, Converter) -> Optional[Any]
        """Return just the value part of :meth:`get_pair`."""
        return self.get_pair(default=default, converter=converter)[1]

    def set(self, value):
        # type: (Any, Any) -> (Text, Any)
        """Store the string form of *value* under the currently-resolved key."""
        key, _ = self.get_pair(default=None, converter=None)
        self._set(key, str(value))

    def _set(self, key, value):
        # type: (Text, Text) -> None
        # storage hook: default implementation is a no-op; subclasses override
        pass

    @abc.abstractmethod
    def _get(self, key):
        # type: (Text) -> Any
        # storage hook: return the raw value for *key*, or NotSet when absent
        pass

    @abc.abstractmethod
    def error(self, message):
        # type: (Text) -> None
        # report a lookup/conversion problem (subclass decides the channel)
        pass

View File

@@ -0,0 +1,25 @@
from os import getenv, environ
from .converters import text_to_bool
from .entry import Entry, NotSet
class EnvEntry(Entry):
    """Configuration entry backed by OS environment variables."""

    @classmethod
    def default_conversions(cls):
        # environment values are always text, so booleans must be parsed strictly
        mapping = super(EnvEntry, cls).default_conversions().copy()
        mapping[bool] = text_to_bool
        return mapping

    def _get(self, key):
        raw = getenv(key, "").strip()
        if raw:
            return raw
        # empty/missing environment variables count as "no value"
        return NotSet

    def _set(self, key, value):
        environ[key] = value

    def __str__(self):
        return "env:{}".format(super(EnvEntry, self).__str__())

    def error(self, message):
        print("Environment configuration: {}".format(message))

View File

@@ -0,0 +1,5 @@
class ConfigurationError(Exception):
    """Raised when a configuration file cannot be parsed or loaded.

    Carries the offending file path (when known) as ``file_path``.
    """

    def __init__(self, msg, file_path=None, *args):
        super(ConfigurationError, self).__init__(msg, *args)
        # path of the configuration file that triggered the error, if any
        self.file_path = file_path

View File

@@ -0,0 +1,30 @@
import logging.config
from pathlib2 import Path
def logger(path=None):
    """Return a logger in the "trains" namespace.

    Without *path* the root "trains" logger is returned. With a file path, the
    logger is named "trains.<module>", where <module> is the file's stem — or its
    parent directory's stem when the file is private (stem starts with '_').
    """
    if not path:
        return logging.getLogger("trains")
    p = Path(path)
    module = p.stem
    if module.startswith('_'):
        module = p.parent.stem
    return logging.getLogger("trains.%s" % module)
def initialize(logging_config=None, extra=None):
    """Initialize process-wide logging.

    :param logging_config: optional dict passed to ``logging.config.dictConfig``
    :param extra: optional dict of fields merged into every log record produced
        by loggers created after this call
    """
    if extra is not None:
        from logging import Logger

        class _Logger(Logger):
            # snapshot of the extra fields, captured at initialize() time
            __extra = extra.copy()

            def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
                # NOTE: mutates the caller-supplied `extra` dict (if any); the
                # initialize-time fields override per-call ones
                extra = extra or {}
                extra.update(self.__extra)
                super(_Logger, self)._log(level, msg, args, exc_info=exc_info, extra=extra, **kwargs)

        # make all loggers created from now on use the extra-injecting subclass
        Logger.manager.loggerClass = _Logger
    if logging_config is not None:
        logging.config.dictConfig(dict(logging_config))

View File

@@ -0,0 +1,9 @@
def get_items(cls):
    """ get key/value items from an enum-like class (members represent enumeration key/value) """
    items = {}
    for name, value in vars(cls).items():
        # skip dunders/privates such as __module__, __doc__, _internal
        if name.startswith('_'):
            continue
        items[name] = value
    return items


def get_options(cls):
    """ get options from an enum-like class (members represent enumeration key/value) """
    return get_items(cls).values()

View File

@@ -0,0 +1,3 @@
from __future__ import print_function
from .worker import Worker

View File

@@ -0,0 +1,401 @@
from __future__ import unicode_literals, print_function
import copy
import re
import sys
from abc import abstractmethod
from functools import wraps
from operator import attrgetter
from traceback import print_exc
from typing import Text
from trains_agent.helper.console import ListFormatter, print_text
from trains_agent.helper.dicts import filter_keys
try:
from typing import NoReturn
except ImportError:
from typing_extensions import NoReturn
import six
from trains_agent.backend_api import services
from trains_agent.errors import APIError, CommandFailedError
from trains_agent.helper.base import Singleton, return_list, print_parameters, dump_yaml, load_yaml, error, warning
from trains_agent.interface.base import ObjectID
from trains_agent.session import Session
class NameResolutionError(CommandFailedError):
    """Raised when a human-readable object name cannot be resolved to a backend ID."""

    def __init__(self, message, suggestions=''):
        super(NameResolutionError, self).__init__(message)
        # keep both parts so callers can log either the bare message or the full text
        self.message = message
        self.suggestions = suggestions

    def __str__(self):
        return '{}{}'.format(self.message, self.suggestions)
def resolve_names(func):
    """Decorator: resolve ObjectID arguments (and lists/tuples of them) to backend
    IDs before invoking the wrapped command method. Non-ObjectID arguments pass
    through unchanged."""

    def safe_resolve(command, arg):
        # resolve one name; on failure return the original name plus the exc_info triple
        try:
            result = command._resolve_name(arg.name, arg.service)
            return result, None
        except NameResolutionError:
            return arg.name, sys.exc_info()

    def _resolve_single(command, arg):
        # single ObjectID -> resolved ID (failure raises);
        # list/tuple of ObjectIDs -> list of resolved IDs;
        # anything else -> returned untouched
        if isinstance(arg, ObjectID):
            return command._resolve_name(arg.name, arg.service)
        elif isinstance(arg, (list, tuple)) and all(isinstance(x, ObjectID) for x in arg):
            result = [safe_resolve(command, x) for x in arg]
            if len(result) == 1:
                # a lone element: resolution failure is fatal — re-raise with the
                # original traceback preserved
                name, ex = result[0]
                if ex:
                    six.reraise(*ex)
                return [name]
            # multiple elements: failures only warn; unresolved names pass through as-is
            for _, ex in result:
                if ex:
                    command.warning(ex[1].message)
            return [name for (name, _) in result]
        return arg

    @wraps(func)
    def newfunc(self, *args, **kwargs):
        args = [_resolve_single(self, arg) for arg in args]
        kwargs = {key: _resolve_single(self, value) for key, value in kwargs.items()}
        return func(self, *args, **kwargs)

    return newfunc
class BaseCommandSection(object):
    """
    Base class for command sections which do not interact with the allegro API.
    Has basic utilities for user interaction.
    """

    # expose the shared helper printers as static members, so instances and
    # subclasses can call self.warning(...) / self.error(...)
    warning = staticmethod(warning)
    error = staticmethod(error)

    @staticmethod
    def log(message, *args):
        """Print a %-formatted informational message with the program prefix."""
        text = message % args
        print("trains-agent: {}".format(text))

    @classmethod
    def exit(cls, message, code=1):  # type: (Text, int) -> NoReturn
        """Report *message* as an error, then terminate the process with *code*."""
        cls.error(message)
        sys.exit(code)
@six.add_metaclass(Singleton)
class ServiceCommandSection(BaseCommandSection):
    """
    Base class for command sections which interact with the allegro API.
    Contains API functionality which is common across services.
    """

    _worker_name = None
    # maximum number of entries shown when listing name-resolution suggestions
    MAX_SUGGESTIONS = 10

    def __init__(self, *args, **kwargs):
        super(ServiceCommandSection, self).__init__()
        # the session owns credentials/connection; all REST calls below go through it
        self._session = self._get_session(*args, **kwargs)
        self._list_formatter = ListFormatter(self.service)

    @staticmethod
    def _get_session(*args, **kwargs):
        # factory hook: subclasses/tests may substitute the session construction
        return Session(*args, **kwargs)

    @property
    @abstractmethod
    def service(self):
        """ The name of the REST service used by this command """
        pass

    def get(self, endpoint, *args, **kwargs):
        """Issue a GET request against this section's service."""
        return self._session.get(service=self.service, action=endpoint, *args, **kwargs)

    def post(self, endpoint, *args, **kwargs):
        """Issue a POST request against this section's service."""
        return self._session.post(service=self.service, action=endpoint, *args, **kwargs)

    def get_with_act_as(self, endpoint, *args, **kwargs):
        """Issue a GET request while impersonating another user (act-as)."""
        return self._session.get_with_act_as(service=self.service, action=endpoint, *args, **kwargs)

    @property
    def name(self):
        # display name, e.g. service "tasks" -> "Tasks"
        return self.service.title()

    @property
    def name_single(self):
        # singular display name ("Tasks" -> "Task"); assumes a plural trailing 's'
        return self.name.rstrip('s')

    @property
    def service_single(self):
        # singular service name ("tasks" -> "task")
        return self.service.rstrip('s')

    @resolve_names
    def __info(self, id=None, yaml=None, **kwargs):
        """Fetch full info for one or more object IDs, print and/or dump to yaml.

        :param id: single ID or list of IDs
        :param yaml: optional path of a yaml file to store the results in
        :return: dict mapping ID -> object info (or None when no IDs given)
        """
        ids = return_list(id)
        if not ids:
            return
        yaml_dump = {}
        for i in ids:
            get_fields = {self.service_single: i}
            try:
                info = self.get('get_by_id', **get_fields)
                yaml_dump[i] = info[self.service_single]
            except APIError:
                # a failed lookup is reported but does not abort the remaining IDs
                self.error('Failed retrieving info for {} {}'.format(self.service_single, i))
        self.output_info(yaml_dump, yaml_path=yaml, **kwargs)
        return yaml_dump

    @resolve_names
    def _info(self, *args, **kwargs):
        # thin wrapper over the name-mangled __info for subclass use
        self.__info(*args, **kwargs)

    @staticmethod
    def output_info(entries, quiet=False, yaml_path=None, **_):
        """Print the collected entries (unless quiet) and/or store them as yaml."""
        if not quiet and entries:
            print_parameters(entries, indent=4)
        if yaml_path:
            print('Storing entries to [{}]'.format(yaml_path))
            dump_yaml(entries, yaml_path)

    @staticmethod
    def _make_query(json, table, sort=None, projection_from_table=False, extra_fields=None):
        """Build a (query dict, field list) pair for a get_all-style request.

        :param json: base query dict (copied, not mutated)
        :param table: list of fields, or a '#'-separated string of fields
        """
        json = json.copy()
        if isinstance(table, six.string_types):
            table = table.split('#')
        if extra_fields:
            table.extend(extra_fields)
        if projection_from_table:
            json['only_fields'] = table
        if sort:
            # does nothing if 'order_by' is not in get_fields
            json['order_by'] = sort.split('#')[0]
        return json, table

    def _get_all(self, endpoint, json, retpoint=None):
        # return the list under `retpoint` (default: the service name) from the response
        return self.get(endpoint, **json).get(retpoint or self.service, [])

    @resolve_names
    def _update(self,
                endpoint='update',
                send_diff=False,
                quiet=False,
                primary_key='id',
                override=None,
                model_desc=None,
                yaml=None,
                **kwargs):
        """Update one or more objects from kwargs and/or a yaml dump.

        :param send_diff: when True, only the fields that changed (vs. the current
            server-side state) are sent
        :param override: iterable of "dotted.key=value" strings (or (key, value)
            pairs) applied on top of the entry being updated
        :param model_desc: path of a prototxt file stored under
            execution.model_desc.prototxt
        :param yaml: yaml file of {id: fields} entries to update from
        """
        if not yaml and primary_key not in kwargs:
            raise ValueError('Update must supply either yaml file or %s-id' % self.service_single)
        data_entries = {}
        original_data_entries = {}
        if yaml:
            data_entries = load_yaml(yaml)
        if send_diff or (not yaml and primary_key in kwargs):
            # the target id: explicit kwarg wins, otherwise first entry in the yaml
            i = kwargs.get(primary_key) or next(iter(data_entries))
            original_info = self.__info(id=i, quiet=True)[i]
            if send_diff:
                # snapshot the current server-side state for later diffing
                original_data_entries[i] = copy.deepcopy(original_info)
            if yaml and i not in data_entries:
                if len(data_entries) > 1:
                    raise ValueError(
                        'Error: yaml file [%s] contains more than one task id' % yaml)
                first_key = next(iter(data_entries))
                if first_key != i:
                    if kwargs.get('force'):
                        if not quiet:
                            print('Warning: overriding yaml task id [%s] with id=%s' % (first_key, i))
                    else:
                        raise ValueError(
                            'Error: yaml task id [%s] != id [%s], use --force to override' % (first_key, i))
                    # re-key the yaml entry under the requested id
                    data_entries = {i: data_entries[first_key]}
                    data_entries[i][primary_key] = i
            elif not yaml:
                # no yaml: the update payload is the remaining kwargs
                data_entries[i] = kwargs
        if model_desc:
            first_key = next(iter(data_entries))
            with open(model_desc) as f:
                proto_data = f.read()
            info = data_entries[first_key]
            info['execution']['model_desc']['prototxt'] = proto_data
        if override:
            first_key = next(iter(data_entries))
            info = data_entries[first_key]
            for p in override:
                key, val = p.split('=') if isinstance(p, six.string_types) else p
                # walk/create the nested dict path for a dotted key
                info_key = info
                keys = key.split('.')
                for k in keys[:-1]:
                    if not info_key.get(k):
                        info_key[k] = dict()
                    info_key = info_key[k]
                info_key[keys[-1]] = val
            # always make sure tags is a list of strings
            # split string to tokens ':'
            # examples tags='auto_generated:draft'
            if info.get('tags') is not None:
                # remove empty strings from list
                info['tags'] = [t for t in info['tags'].split(':') if t]
        # send only change set
        if send_diff:
            # only send the values that changed
            for i, info in data_entries.items():
                org_info = original_data_entries[i]
                out_info = {}
                recursive_diff(org_info, info, out_info)
                data_entries[i] = out_info
        for i, info in data_entries.items():
            if not info or len(info) == 0 or list(info) == [primary_key]:
                if not quiet:
                    print('Skipping: nothing to update for %s id [%s]' % (self.service_single, i))
                continue
            if not quiet:
                print('Updating %s id [%s]' % (self.service_single, i))
            info[self.service_single] = i
            result = self.get(endpoint, **info)
            if not result['updated']:
                raise ValueError('Failed updating %s id [%s]' % (self.service_single, i))
            if not quiet:
                print('%s [%s] updated fields: %s' % (self.name_single, i, result.get('fields', '')))

    @resolve_names
    def remove(self, ids, **kwargs):
        """Delete the objects with the given IDs via the service's DeleteRequest."""
        return self._apply_command(
            request_cls=getattr(services, self.service).DeleteRequest,
            object_ids=ids,
            response_validation_field='deleted',
            **kwargs
        )

    def _apply_command(self, request_cls, object_ids, response_validation_field=None, **kwargs):
        """Apply a per-object API request to each ID, reporting per-object failures.

        Logs "<n>/<total> succeeded"; exits the process when not all succeeded.
        """
        object_ids = return_list(object_ids)

        def call_one(object_id):
            error_message = '[{object_id}]: failed'.format(**locals())
            try:
                response = self._session.send_api(request_cls(object_id, **kwargs))
            except APIError as e:
                if not self._session.debug_mode:
                    self.error('{}: {}'.format(error_message, e))
                else:
                    # debug mode: show the server-side traceback plus our own
                    traceback = e.format_traceback()
                    if traceback:
                        print(traceback)
                    print('Own traceback:')
                    print_exc()
                return False
            # when a validation field is given, the server must report exactly 1 affected object
            if not response_validation_field or getattr(response, response_validation_field) == 1:
                return True
            else:
                self.error(error_message)
                return False

        succeeded = [call_one(object_id) for object_id in object_ids].count(True)
        message = '{}/{} succeeded'.format(succeeded, len(object_ids))
        (self.log if succeeded == len(object_ids) else self.exit)(message)

    def get_service(self, service_class):
        """Instantiate another service helper sharing this section's configuration."""
        return service_class(config=self._session.config)

    def _resolve_name(self, name, service=None):
        """
        Resolve an object name to an object ID.
        Operation:
        - If the argument "looks like" an ID, return it.
        - Else, get all object with names containing the argument
        - if an object with the argument as its name exists, return the object's ID
        - Else, print a list of suggestions and exit
        :param str name: ID (returned unmodified) or Name to resolve
        :param str service: Service to resolve from (type of object). Defaults to service represented by the class
        :return: ID of object
        :rtype: str
        """
        service = service or self.service
        # 30+ hex/dash characters: treat the argument as an ID, no lookup needed
        if re.match(r'^[-a-f0-9]{30,}$', name):
            return name
        try:
            request_cls = getattr(services, service).GetAllRequest
        except AttributeError:
            raise NameResolutionError('Name resolution unavailable for {}'.format(service))
        request = request_cls.from_dict(dict(name=name, only_fields=['name', 'id']))
        # from_dict will ignore unrecognised keyword arguments - not all GetAll's have only_fields
        response = getattr(self._session.send_api(request), service)
        matches = [db_object for db_object in response if name.lower() == db_object.name.lower()]

        def truncated_bullet_list(format_string, elements, callback, **kwargs):
            # render up to MAX_SUGGESTIONS elements as a bullet list, noting truncation
            if len(elements) > self.MAX_SUGGESTIONS:
                kwargs.update(
                    dict(details=' (showing {}/{})'.format(self.MAX_SUGGESTIONS, len(elements)), suffix='\n...'))
            else:
                kwargs.update(dict(details='', suffix=''))
            bullet_list = '\n'.join('* {}'.format(callback(item)) for item in elements[:self.MAX_SUGGESTIONS])
            return format_string.format(bullet_list, **kwargs)

        if len(matches) == 1:
            return matches.pop().id
        elif len(matches) > 1:
            # ambiguous name: list the matching IDs and abort
            message = truncated_bullet_list(
                'Found multiple {service} with name "{name}"{details}:\n{}{suffix}',
                matches,
                callback=attrgetter('id'),
                **locals())
            self.exit(message)
        message = 'Could not find {} with name "{}"'.format(service.rstrip('s'), name)
        if not response:
            raise NameResolutionError(message)
        # partial matches exist: offer them as suggestions
        suggestions = truncated_bullet_list(
            '. Did you mean this?{details}\n{}{suffix}',
            sorted(response, key=attrgetter('name')),
            lambda db_object: '({}) {}'.format(db_object.id, db_object.name)
        )
        raise NameResolutionError(message, suggestions)
def recursive_diff(org, upd, out):
    """Collect into *out* the change-set of *upd* relative to *org*.

    Keys present in *upd* whose value differs from *org* are copied into *out*
    (descending into nested dicts); values of None are treated as "unset" and
    skipped. Returns True when a nested dict was descended into, a falsy value
    otherwise.
    """
    if isinstance(upd, dict) and isinstance(org, dict):
        changed = [key for key in upd if key not in org or upd[key] != org[key]]
        if changed:
            found_nested = False
            for key in changed:
                value = upd[key]
                if isinstance(value, dict):
                    out[key] = {}
                    found_nested = True
                    # when the nested diff bottoms out, send the whole sub-dict
                    if not recursive_diff(org.get(key, {}), value, out[key]):
                        out[key] = value
                elif value is not None:
                    out[key] = value
            return found_nested
    elif isinstance(upd, list) and isinstance(org, list):
        out.extend(item for item in upd if item not in org)
        return False

View File

@@ -0,0 +1,215 @@
from __future__ import print_function
from six.moves import input
from pyhocon import ConfigFactory
from pathlib2 import Path
from six.moves.urllib.parse import urlparse
from trains_agent.backend_api.session.defs import ENV_HOST
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
# Interactive prompts shown by the `trains-agent init` wizard.
description = """
Please create new credentials using the web app: {}/profile
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
Paste credentials here: """

# default web-server URL; can be overridden through the host environment variable
def_host = 'http://localhost:8080'
try:
    def_host = ENV_HOST.get(default=def_host) or def_host
except Exception:
    pass

host_description = """
Editing configuration file: {CONFIG_FILE}
Enter the url of the trains-server's Web service, for example: {HOST}
""".format(
    CONFIG_FILE=LOCAL_CONFIG_FILES[0],
    HOST=def_host,
)
def main():
    """Interactive setup wizard: derive api/web/files server URLs from one host,
    collect and verify credentials, then write ~/trains.conf."""
    print('TRAINS-AGENT setup process')
    conf_file = Path(LOCAL_CONFIG_FILES[0]).absolute()
    # never overwrite an existing, non-empty configuration
    if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0:
        print('Configuration file already exists: {}'.format(str(conf_file)))
        print('Leaving setup, feel free to edit the configuration file.')
        return
    print(host_description)
    web_host = input_url('Web Application Host', '')
    # NOTE(review): verify_url may return None on a parse failure, in which case
    # the attribute accesses below would raise AttributeError — confirm intended.
    parsed_host = verify_url(web_host)
    # standard port layout: 8008 = api, 8080 = web app, 8081 = file store
    if parsed_host.port == 8008:
        # NOTE(review): message text appears reversed — the code actually replaces
        # 8008 with 8080 for the web host.
        print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
    elif parsed_host.port == 8080:
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapp.'):
        # this is our demo server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('app.'):
        # this is our application server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapi.'):
        print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('api.'):
        print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
    else:
        # unrecognized layout: ask about the default ports, else reuse the host as-is
        api_host = ''
        web_host = ''
        files_host = ''
        if not parsed_host.port:
            print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
            replace_port = input().lower()
            if not replace_port or replace_port == 'y' or replace_port == 'yes':
                api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
                web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
                files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
        if not api_host:
            api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
    # allow the user to confirm/override the derived hosts
    api_host = input_url('API Host', api_host)
    files_host = input_url('File Store Host', files_host)
    print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
        api_host, web_host, files_host))
    # loop until a credentials pair is successfully verified against the server
    while True:
        print(description.format(web_host), end='')
        parse_input = input()
        # check if these are valid credentials
        credentials = None
        # noinspection PyBroadException
        try:
            parsed = ConfigFactory.parse_string(parse_input)
            if parsed:
                credentials = parsed.get("credentials", None)
        except Exception:
            credentials = None
        if not credentials or set(credentials) != {"access_key", "secret_key"}:
            print('Could not parse user credentials, try again one after the other.')
            credentials = {}
            # parse individual
            print('Enter user access key: ', end='')
            credentials['access_key'] = input()
            print('Enter user secret: ', end='')
            credentials['secret_key'] = input()
        print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
                                                                     credentials['secret_key'], ))
        from trains_agent.backend_api.session import Session
        # noinspection PyBroadException
        try:
            print('Verifying credentials ...')
            # constructing a Session performs the authentication round-trip
            Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
            print('Credentials verified!')
            break
        except Exception:
            print('Error: could not verify credentials: host={} access={} secret={}'.format(
                api_host, credentials['access_key'], credentials['secret_key']))
    # get GIT User/Pass for cloning
    print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
    git_user = input()
    if git_user.strip():
        print('Enter password for user \'{}\': '.format(git_user), end='')
        git_pass = input()
        print('Git repository cloning will be using user={} password={}'.format(git_user, git_pass))
    else:
        git_user = None
        git_pass = None
    # read the packaged default agent/sdk configuration sections to append below
    # noinspection PyBroadException
    try:
        conf_folder = Path(__file__).parent.absolute() / '..' / 'backend_api' / 'config' / 'default'
        default_conf = ''
        for conf in ('agent.conf', 'sdk.conf', ):
            conf_file_section = conf_folder / conf
            with open(str(conf_file_section), 'rt') as f:
                # prefix each file's content with its section name ("agent"/"sdk")
                default_conf += conf.split('.')[0] + ' '
                default_conf += f.read()
            default_conf += '\n'
    except Exception:
        print('Error! Could not read default configuration file')
        return
    # write the final configuration: api header + git credentials + defaults
    # noinspection PyBroadException
    try:
        with open(str(conf_file), 'wt') as f:
            header = '# TRAINS-AGENT configuration file\n' \
                     'api {\n' \
                     ' api_server: %s\n' \
                     ' web_server: %s\n' \
                     ' files_server: %s\n' \
                     ' # Credentials are generated in the webapp, %s/profile\n' \
                     ' # Override with os environment: TRAINS_API_ACCESS_KEY / TRAINS_API_SECRET_KEY\n' \
                     ' credentials {"access_key": "%s", "secret_key": "%s"}\n' \
                     '}\n\n' % (api_host, web_host, files_host,
                                web_host, credentials['access_key'], credentials['secret_key'])
            f.write(header)
            git_credentials = '# Set GIT user/pass credentials\n' \
                              '# leave blank for GIT SSH credentials\n' \
                              'agent.git_user=\"{}\"\n' \
                              'agent.git_pass=\"{}\"\n' \
                              '\n'.format(git_user or '', git_pass or '')
            f.write(git_credentials)
            f.write(default_conf)
    except Exception:
        print('Error! Could not write configuration file at: {}'.format(str(conf_file)))
        return
    print('\nNew configuration stored in {}'.format(str(conf_file)))
    print('TRAINS-AGENT setup completed successfully.')
def input_url(host_type, host=None):
    """Prompt until the user confirms the suggested *host* (empty / "y" / "yes")
    or enters a different URL that passes verify_url(); return the chosen host."""
    while True:
        print('{} configured to: [{}] '.format(host_type, host), end='')
        response = input()
        # accept the suggestion when one exists and the user confirms
        if host and (not response or response.lower() in ('yes', 'y')):
            break
        # otherwise take the new value, provided it parses as a URL
        if response and verify_url(response):
            host = response
            break
    return host
def verify_url(parse_input):
    """Normalize a user-entered server address and parse it.

    Prepends a scheme when missing ('http://' when an explicit port is given,
    'https://' otherwise). Returns the parse result, or None when the input
    cannot be parsed or uses a non-http(s) scheme.
    """
    try:
        if not parse_input.startswith(('http://', 'https://')):
            # if we have a specific port, use http prefix, otherwise assume https
            prefix = 'http://' if ':' in parse_input else 'https://'
            parse_input = prefix + parse_input
        parsed_host = urlparse(parse_input)
        if parsed_host.scheme not in ('http', 'https'):
            parsed_host = None
    except Exception:
        parsed_host = None
        print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
    return parsed_host
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,97 @@
from __future__ import print_function
import json
import time
from future.builtins import super
from trains_agent.commands.base import ServiceCommandSection
from trains_agent.helper.base import return_list
class Events(ServiceCommandSection):
    """Client for the ``events`` service: batches task log/metric events and sends
    them in size-bounded json-lines packets."""

    # server-side size limits: whole add_batch payload, and a single event
    max_packet_size = 1024 * 1024
    max_event_size = 64 * 1024

    def __init__(self, *args, **kwargs):
        # NOTE(review): adds nothing over the base-class __init__
        super(Events, self).__init__(*args, **kwargs)

    @property
    def service(self):
        """ Events command service endpoint """
        return 'events'

    def send_events(self, list_events):
        """JSON-encode *list_events* and send them in packets of up to
        max_packet_size bytes; return the number of events reported sent."""

        def send_packet(jsonlines):
            # post one json-lines packet; returns the count the server accepted
            if not jsonlines:
                return 0
            num_lines = len(jsonlines)
            jsonlines = '\n'.join(jsonlines)
            new_events = self.post('add_batch', data=jsonlines, headers={'Content-type': 'application/json-lines'})
            if new_events['added'] != num_lines:
                print('Error (%s) sending events only %d of %d registered' %
                      (new_events['errors'], new_events['added'], num_lines))
                return int(new_events['added'])
            # print('Sent %d events' % num_lines)
            return num_lines

        # json every line and push into list of json strings
        count_bytes = 0
        lines = []
        sent_events = 0
        for i, event in enumerate(list_events):
            line = json.dumps(event)
            # +1 accounts for the joining newline
            line_len = len(line) + 1
            if count_bytes + line_len > self.max_packet_size:
                # flush packet, and restart
                sent_events += send_packet(lines)
                count_bytes = 0
                lines = []
            count_bytes += line_len
            lines.append(line)
        # flush leftovers
        sent_events += send_packet(lines)
        # print('Sending events done: %d / %d events sent' % (sent_events, len(list_events)))
        return sent_events

    def send_log_events(self, worker_id, task_id, lines, level='DEBUG'):
        """Convert raw log *lines* into log events (splitting text that exceeds
        max_event_size) and send them; returns the number of events sent."""
        log_events = []
        base_timestamp = int(time.time() * 1000)
        # fields shared by every event; timestamp is offset per event to keep ordering
        base_log_items = {
            'type': 'log',
            'level': level,
            'task': task_id,
            'worker': worker_id,
        }

        def get_event(c):
            # build one event dict from the accumulated msg at offset c
            d = base_log_items.copy()
            d.update(msg=msg, timestamp=base_timestamp + c)
            return d

        # break log lines into event packets
        msg = ''
        count = 0
        for l in return_list(lines):
            # HACK ignore terminal reset ANSI code
            if l == '\x1b[0m':
                continue
            while l:
                if len(msg) + len(l) < self.max_event_size:
                    msg += l
                    l = None
                else:
                    # current line overflows the event: split it at the size limit
                    left_over = self.max_event_size - len(msg)
                    msg += l[:left_over]
                    l = l[left_over:]
                    log_events.append(get_event(count))
                    msg = ''
                    count += 1
        if msg:
            log_events.append(get_event(count))
        # now send the events
        return self.send_events(list_events=log_events)

File diff suppressed because it is too large Load Diff

87
trains_agent/complete.py Normal file
View File

@@ -0,0 +1,87 @@
"""
Script for generating command-line completion.
Called by trains_agent/utilities/complete.sh (or a copy of it) like so:
python -m trains_agent.complete "current command line"
And writes line-separated completion targets to stdout.
Results are line-separated in order to enable other whitespace in results.
"""
from __future__ import print_function
import argparse
import sys
from trains_agent.interface import get_parser
def is_argument_required(action):
    """Return True when *action* is a plain store action, i.e. a flag that consumes
    a value argument."""
    return isinstance(action, argparse._StoreAction)


def format_option(option, argument_required):
    """
    Return appropriate string for flags requiring arguments and flags that do not
    :param option: flag to format
    :param argument_required: whether argument is required
    """
    if argument_required:
        return option + '='
    return option + ' '


def get_options(parser):
    """
    Return all possible flags for parser
    :param parser: argparse.ArgumentParser instance
    :return: list of options
    """
    options = []
    for action in parser._actions:
        for option in action.option_strings:
            options.append(format_option(option, is_argument_required(action)))
    return options
def main():
    """Entry point: print newline-separated completion candidates for the partial
    command line passed as sys.argv[1]; returns a shell exit code."""
    if len(sys.argv) != 2:
        return 1
    # drop the program name itself; iterate the remaining words
    comp_words = iter(sys.argv[1].split()[1:])
    parser = get_parser()
    seen = []
    for word in comp_words:
        # descend into subcommand parsers
        # NOTE(review): relies on the parser exposing `choices` and item access —
        # presumably a project-specific parser wrapper, not a stock
        # argparse.ArgumentParser; confirm against trains_agent.interface.
        if word in parser.choices:
            parser = parser[word]
            continue
        actions = {name: action for action in parser._actions for name in action.option_strings}
        # handle the single-word "--flag=value" form
        first, _, rest = word.partition('=')
        is_one_word_store_action = rest and first in actions
        if is_one_word_store_action:
            word = first
        seen.append(word)
        try:
            action = actions[word]
        except KeyError:
            break
        if isinstance(action, argparse._StoreAction) and not isinstance(action, argparse._StoreConstAction):
            # a value-taking flag consumes the following word (unless "=value" was inline)
            if not is_one_word_store_action:
                try:
                    next(comp_words)
                except StopIteration:
                    break
    # candidates: subcommands first, then the parser's flags
    options = list(parser.choices)
    options = [format_option(option, argument_required=False) for option in options]
    options.extend(get_options(parser))
    # drop flags already present on the command line
    options = [option for option in options if option.rstrip('= ') not in seen]
    print('\n'.join(options))
    return 0
if __name__ == "__main__":
sys.exit(main())

34
trains_agent/config.py Normal file
View File

@@ -0,0 +1,34 @@
from pyhocon import ConfigTree
import six
from trains_agent.helper.base import Singleton
@six.add_metaclass(Singleton)
class Config(object):
    """Process-wide configuration tree (singleton); attribute access is aliased
    to item access on the underlying ConfigTree."""

    def __init__(self, tree=None):
        # write through __dict__ directly: __setattr__ below is aliased to
        # __setitem__, so a plain "self._tree = ..." would write into the tree
        self.__dict__['_tree'] = tree or ConfigTree()

    def __getitem__(self, item):
        return self._tree[item]

    def __setitem__(self, key, value):
        return self._tree.__setitem__(key, value)

    def new(self, name):
        """Return the sub-tree *name*, creating an empty one when missing."""
        return self._tree.setdefault(name, ConfigTree())

    # attribute get/set delegate to the underlying ConfigTree
    __getattr__ = __getitem__
    __setattr__ = __setitem__
def get_config(name=None):
    """Return the global Config singleton, or one of its named sub-trees."""
    config = Config()
    return getattr(config, name) if name else config


def make_config(name):
    """Create (or fetch) the sub-tree *name* on the global configuration."""
    return get_config().new(name)

131
trains_agent/definitions.py Normal file
View File

@@ -0,0 +1,131 @@
from datetime import timedelta
from distutils.util import strtobool
from enum import IntEnum
from os import getenv
from typing import Text, Optional, Union, Tuple, Any
from furl import furl
from pathlib2 import Path
import six
from trains_agent.helper.base import normalize_path
# Program identity and filesystem locations used throughout the agent
PROGRAM_NAME = "trains-agent"
# presumably argparse's fromfile_prefix_chars ("@file" reads args from a file) — confirm
FROM_FILE_PREFIX_CHARS = "@"
CONFIG_DIR = normalize_path("~/.trains")
TOKEN_CACHE_FILE = normalize_path("~/.trains.trains_agent.tmp")
# candidate locations for the user configuration file, in priority order
CONFIG_FILE_CANDIDATES = ["~/trains.conf"]
def find_config_path():
    """Return the first candidate config path that exists on disk, falling back
    to the first (preferred) candidate when none do."""
    existing = (
        candidate for candidate in CONFIG_FILE_CANDIDATES
        if Path(candidate).expanduser().exists()
    )
    return next(existing, CONFIG_FILE_CANDIDATES[0])
# resolved once at import time: first existing candidate (or the default location)
CONFIG_FILE = normalize_path(find_config_path())
class EnvironmentConfig(object):
    """A configuration value sourced from one of several environment variables,
    with type-based conversion of the raw text."""

    # converters keyed by target type; anything else is called directly as a cast
    conversions = {
        bool: lambda value: bool(strtobool(value)),
        six.text_type: lambda s: six.text_type(s).strip(),
    }

    def __init__(self, *names, **kwargs):
        self.vars = names
        self.type = kwargs.pop("type", six.text_type)

    def convert(self, value):
        converter = self.conversions.get(self.type, self.type)
        return converter(value)

    def get(self, key=False):  # type: (bool) -> Optional[Union[Any, Tuple[Text, Any]]]
        """Return the first non-empty variable's converted value (optionally with
        the variable name), or None when none is set."""
        for name in self.vars:
            raw = getenv(name)
            if not raw:
                continue
            converted = self.convert(raw)
            return (name, converted) if key else converted
        return None
# Mapping of configuration-tree keys to the environment variables that
# may override them at runtime.
ENVIRONMENT_CONFIG = {
    "api.api_server": EnvironmentConfig("TRAINS_API_HOST", "ALG_API_HOST"),
    "api.credentials.access_key": EnvironmentConfig(
        "TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY"
    ),
    "api.credentials.secret_key": EnvironmentConfig(
        "TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY"
    ),
    "agent.worker_name": EnvironmentConfig("TRAINS_WORKER_NAME", "ALG_WORKER_NAME"),
    "agent.worker_id": EnvironmentConfig("TRAINS_WORKER_ID", "ALG_WORKER_ID"),
    "agent.cuda_version": EnvironmentConfig(
        "TRAINS_CUDA_VERSION", "ALG_CUDA_VERSION", "CUDA_VERSION"
    ),
    "agent.cudnn_version": EnvironmentConfig(
        "TRAINS_CUDNN_VERSION", "ALG_CUDNN_VERSION", "CUDNN_VERSION"
    ),
    "agent.cpu_only": EnvironmentConfig(
        "TRAINS_CPU_ONLY", "ALG_CPU_ONLY", "CPU_ONLY", type=bool
    ),
}

# Environment variable pointing at an alternate configuration file.
CONFIG_FILE_ENV = EnvironmentConfig("ALG_CONFIG_FILE")

# SDK parameter -> environment variable names (checked in order).
# NOTE(review): "TRAINS_CONFIG_FILE" appears twice for config_file; the
# duplicate is harmless but probably unintended.
ENVIRONMENT_SDK_PARAMS = {
    "task_id": ("TRAINS_TASK_ID", "ALG_TASK_ID"),
    "config_file": ("TRAINS_CONFIG_FILE", "ALG_CONFIG_FILE", "TRAINS_CONFIG_FILE"),
    "log_level": ("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL"),
    "log_to_backend": ("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND"),
}

# Default virtual-environment locations, per python major version.
VIRTUAL_ENVIRONMENT_PATH = {
    "python2": normalize_path(CONFIG_DIR, "py2venv"),
    "python3": normalize_path(CONFIG_DIR, "py3venv"),
}

DEFAULT_BASE_DIR = normalize_path(CONFIG_DIR, "data_cache")
DEFAULT_HOST = "https://demoai.trainsai.io"
MAX_DATASET_SOURCES_COUNT = 50000

# (HTTP status, server subcode) pairs identifying known API errors.
INVALID_WORKER_ID = (400, 1001)
WORKER_ALREADY_REGISTERED = (400, 1003)

API_VERSION = "v1.5"
TOKEN_EXPIRATION_SECONDS = int(timedelta(days=2).total_seconds())

# HTTP header names sent to the trains server.
HTTP_HEADERS = {
    "worker": "X-Trains-Worker",
    "act-as": "X-Trains-Act-As",
    "client": "X-Trains-Agent",
}
METADATA_EXTENSION = ".json"

DEFAULT_VENV_UPDATE_URL = (
    "https://raw.githubusercontent.com/Yelp/venv-update/v3.2.2/venv_update.py"
)
WORKING_REPOSITORY_DIR = "task_repository"
DEFAULT_VCS_CACHE = normalize_path(CONFIG_DIR, "vcs-cache")
PIP_EXTRA_INDICES = [
]
DEFAULT_PIP_DOWNLOAD_CACHE = normalize_path(CONFIG_DIR, "pip-download-cache")
class FileBuffering(IntEnum):
    """
    File buffering options:
    - UNSET: follows the defaults for the type of file,
      line-buffered for interactive (tty) text files and with a default chunk size otherwise
    - UNBUFFERED: no buffering at all
    - LINE_BUFFERING: per-line buffering, only valid for text files
    - values bigger than 1 indicate the size of the buffer in bytes and are not represented by the enum
    """
    UNSET = -1
    UNBUFFERED = 0
    LINE_BUFFERING = 1

86
trains_agent/errors.py Normal file
View File

@@ -0,0 +1,86 @@
from typing import Union, Optional, Text
import requests
import six
from .backend_api import CallResult
from .backend_api.session.client import APIError as ClientAPIError
from .backend_api.session.response import ResponseMeta
INTERNAL_SERVER_ERROR = 500
# TODO: hack: should NOT inherit from ValueError
class APIError(ClientAPIError, ValueError):
    """
    Class for representing an API error.
    self.data - ``dict`` of all returned JSON data
    self.code - HTTP response code
    self.subcode - server response subcode
    self.codes - (self.code, self.subcode) tuple
    self.message - result message sent from server
    """

    def __init__(self, response, extra_info=None):
        # type: (Union[requests.Response, CallResult], Optional[Text]) -> None
        """
        Create a new APIError from a server response
        """
        # Accept either a parsed CallResult or a raw requests.Response;
        # normalize the latter into a CallResult before delegating.
        if not isinstance(response, CallResult):
            try:
                data = response.json()
            except ValueError:
                # response body was not valid JSON
                data = {}
            meta = data.get('meta')
            if meta:
                # is_valid=False: this meta describes an error response
                response_meta = ResponseMeta(is_valid=False, **meta)
            else:
                # no meta section: synthesize one from the HTTP status/text
                response_meta = ResponseMeta.from_raw_data(response.status_code, response.text)
            response = CallResult(
                meta=response_meta,
                response=response,
                response_data=data,
            )
        super(APIError, self).__init__(response, extra_info=extra_info)

    def format_traceback(self):
        # Only internal server errors (HTTP 500) carry a server-side traceback.
        if self.code != INTERNAL_SERVER_ERROR:
            return ''
        traceback = self.get_traceback()
        if traceback:
            return 'Server traceback:\n{}'.format(traceback)
        else:
            return 'Could not print server traceback'
class CommandFailedError(Exception):
    """Base error for trains-agent command failures."""

    def __init__(self, message=None, *args, **kwargs):
        super(CommandFailedError, self).__init__(message, *args, **kwargs)
        # Keep the primary message accessible as an attribute for formatters.
        self.message = message


class UsageError(CommandFailedError):
    """
    Used for usage errors that are checked post-argparsing
    """
    pass


class ConfigFileNotFound(CommandFailedError):
    """Raised when the configuration file cannot be located."""
    pass


class Sigterm(BaseException):
    """
    Raised to unwind the stack on SIGTERM. Inherits BaseException (not
    Exception) so generic ``except Exception`` handlers do not swallow it.
    """
    pass


@six.python_2_unicode_compatible
class MissingPackageError(CommandFailedError):
    """Raised when a required python package is not installed."""

    def __init__(self, name):
        super(MissingPackageError, self).__init__(name)
        # name: the missing pip package
        self.name = name

    def __str__(self):
        return '{self.__class__.__name__}: ' \
               '"{self.name}" package is required. Please run "pip install {self.name}"'.format(self=self)

View File

509
trains_agent/helper/base.py Normal file
View File

@@ -0,0 +1,509 @@
""" TRAINS-AGENT Stdout Helper Functions """
from __future__ import print_function, unicode_literals
import io
import json
import logging
import os
import platform
import re
import shutil
import stat
import subprocess
import sys
import tempfile
from abc import ABCMeta
from collections import OrderedDict
from distutils.spawn import find_executable
from functools import total_ordering
from typing import Text, Dict, Any, Optional, AnyStr, IO, Union
import attr
import furl
import pyhocon
import yaml
from attr import fields_dict
from pathlib2 import Path
from tqdm import tqdm
import six
from six.moves import reduce
from trains_agent.errors import CommandFailedError
from trains_agent.helper.dicts import filter_keys
pretty_lines = False
log = logging.getLogger(__name__)
def which(cmd, path=None):
    """Return the full path of executable `cmd`; raise ValueError when not found."""
    located = find_executable(cmd, path)
    if located:
        return located
    raise ValueError('command "{}" not found'.format(cmd))
def select_for_platform(linux, windows):
    """
    Select between multiple values according to the OS
    :param linux: value to return if OS is linux
    :param windows: value to return if OS is Windows
    """
    if is_windows_platform():
        return windows
    return linux


def bash_c():
    """Return the shell-command prefix appropriate for the current OS."""
    if is_windows_platform():
        return 'cmd /c'
    return 'bash -c'


def return_list(arg):
    """Wrap a scalar in a list; pass lists/tuples and falsy values through."""
    if not arg or isinstance(arg, (tuple, list)):
        return arg
    return [arg]
def print_table(entries, columns=(), titles=(), csv=None, headers=True):
    """Render `entries` as a table and print it to stdout, or to the `csv` file."""
    rendered = create_table(entries, columns=columns, titles=titles, csv=csv, headers=headers)
    if not csv:
        print(rendered)
        return
    with open(csv, 'w') as output:
        print(rendered, file=output)
def create_table(entries, columns=(), titles=(), csv=None, headers=True):
    """
    Build a textual table (CSV lines or an ASCII grid) from dict entries.

    :param entries: list of (possibly nested) dicts
    :param columns: dotted paths selecting a value from each entry
    :param titles: optional column titles (fall back to the column path)
    :param csv: when truthy, emit comma-separated lines instead of a grid
    :param headers: include a header row
    :return: the formatted table as a single string
    """
    # Resolve each dotted column path against each entry; missing keys
    # resolve to {} along the chain and finally render as ''.
    table = [
        [
            reduce(
                lambda obj, key: obj.get(key, {}),
                column.split('.'),
                entry
            ) or ''
            for column in columns
        ]
        for entry in entries
    ]
    if headers:
        headers = [titles[i] if i < len(titles) and titles[i] else c for i, c in enumerate(columns)]
    else:
        headers = []
    output = ''
    if csv:
        if headers:
            output += ','.join(headers) + '\n'
        for entry in table:
            output += ','.join(map(str, entry)) + '\n'
    else:
        # ASCII grid: each column is as wide as its widest cell (min 3).
        min_col_width = 3
        col_widths = [max(min_col_width, len(h)+1) for h in (headers or table[0])]
        for e in table:
            col_widths = list(map(max, zip(col_widths, [len(h)+1 for h in e])))
        output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
        if headers:
            output += '| ' + '| '.join(['{: <%d}' % c for c in col_widths]).format(*headers) + ' |' + '\n'
            output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
        for entry in table:
            line = map(str, entry)
            output += '| ' + '| '.join(['{: <%d}' % c for c in col_widths]).format(*line) + ' |' + '\n'
        output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
    return output
def create_tree(entries, id='id', parent='parent', node_title='%(id)s'):
    """
    Build a nested OrderedDict tree from a flat list of node dicts.

    :param entries: list of dicts, each describing one node
    :param id: key holding a node's unique identifier
    :param parent: key holding the parent's identifier (roots omit it)
    :param node_title: %-format string rendered against each entry dict to
        produce the node's display title. Default fixed from '%(id)' --
        an incomplete format specifier that raised ValueError when used.
    :return: {'': tree} where tree maps node titles to child sub-trees
    """
    tree = OrderedDict()
    all_nodes = dict()
    for t in entries:
        i = t.get(id, None)
        p = t.get(parent, None)
        if not p and i not in tree:
            # Root node: attach directly under the top-level tree.
            myd = all_nodes.get(i, OrderedDict())
            tree[node_title % t] = myd
            all_nodes[i] = myd
        elif p:
            # Child node: attach under its (possibly not-yet-seen) parent.
            d = all_nodes.get(p, OrderedDict())
            myd = all_nodes.get(i, OrderedDict())
            d[node_title % t] = myd
            all_nodes[p] = d
            all_nodes[i] = myd
    return {'': tree}
def print_parameters(param_struct, indent=1):
    """Pretty-print a parameter dict/struct to stdout as YAML."""
    text = yaml.safe_dump(param_struct, allow_unicode=True, indent=indent, default_flow_style=False)
    print(text)
def get_list_files(basefolder, filext=('.jpg',)):
    """
    Recursively iterate files under `basefolder` whose extension matches.

    :param basefolder: directory tree to walk
    :param filext: extension string or collection of extensions, matched
        case-insensitively. NOTE: the previous default ('.jpg') was a plain
        string (missing trailing comma), so the membership test compared
        extensions against single characters and matched nothing.
    :return: generator of full file paths
    """
    # Accept a single extension passed as a bare string.
    if isinstance(filext, str):
        filext = (filext,)
    filext = [e.lower() for e in filext]
    fileiter = (os.path.join(root, f)
                for root, _, files in os.walk(basefolder)
                for f in files if os.path.splitext(f)[1].lower() in filext)
    return fileiter
def is_windows_platform():
    """True when running on Windows."""
    return any(platform.win32_ver())


def normalize_path(*paths):
    """
    normalize_path
    Joins ``*paths``, expands ``~`` and ``$VARS``, and normalizes separators.
    :param paths: path components to create path from
    """
    joined = os.path.join(*map(str, paths))
    expanded = os.path.expanduser(joined)
    expanded = os.path.expandvars(expanded)
    return os.path.normpath(expanded)
def safe_remove_file(filename, error_message=None):
    """Delete `filename`; on any failure, print `error_message` (if given)."""
    try:
        os.remove(filename)
    except Exception:
        if error_message:
            print(error_message)
class Singleton(ABCMeta):
    """
    Metaclass enforcing one shared instance per class: the first
    instantiation is cached and returned by every later call.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super(Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
@total_ordering
class CompareAnything(object):
    """
    CompareAnything
    Creates an object which is always the smallest when compared to other objects.
    """

    def __eq__(self, other):
        # Never equal to anything (including another CompareAnything).
        return False

    def __lt__(self, other):
        # Always strictly smaller than anything.
        return True
def nonstrict_in_place_sort(lst, reverse, *keys):
    """
    nonstrict_in_place_sort
    Sorts a list of dictionaries in-place by ``keys``.
    An element without a certain ``key`` will be considered the smallest in respect to that key.
    :param lst: list to sort
    :type lst: ``[dict]``
    :param reverse: whether to reverse sorting; any falsy value (including
        None) sorts ascending
    :type keys: ``[str]``
    :param keys: Keys to sort by.
        Elements will be sorted pseudo-lexicographically by the values corresponding to ``*keys``, i.e:
        the list will be first sorted by the first element of ``*keys``,
        elements which are equal by the first sort will be internally sorted by
        the second element of ``*keys`` and so on.
    """
    lst.sort(
        key=lambda item: tuple(item.get(key, CompareAnything()) for key in keys),
        # Coerce to bool: callers (e.g. ListFormatter.sort_in_place) may pass
        # None, which list.sort() rejects on Python 3.
        reverse=bool(reverse),
    )
def load_yaml(path):
    """
    Parse the YAML file at `path` (str or Path) and return its contents.
    An empty document yields {}; malformed YAML raises ValueError.
    """
    if isinstance(path, Path):
        path = str(path)
    try:
        with open(path) as data_file:
            return yaml.safe_load(data_file) or {}
    except yaml.YAMLError as e:
        raise ValueError('Failed parsing yaml file [{}]: {}'.format(path, e))
def dump_yaml(obj, path=None, dump_all=False, **kwargs):
    """
    Serialize `obj` to YAML.

    :param obj: object to serialize
    :param path: when given, write to this file; otherwise return the YAML text
    :param dump_all: use AllDumper (stringifies unknown objects) instead of safe_dump
    :param kwargs: extra keyword arguments forwarded to the yaml dump function
    """
    base_kwargs = dict(indent=4, allow_unicode=True, default_flow_style=False)
    base_kwargs.update(kwargs)
    if dump_all:
        base_kwargs['Dumper'] = AllDumper
        dump_func = yaml.dump
    else:
        dump_func = yaml.safe_dump
    if not path:
        return dump_func(obj, **base_kwargs)
    path = str(path)
    with open(path, 'w') as output:
        dump_func(obj, output, **base_kwargs)
def one_value(dct):
    """Return an arbitrary (first-iterated) value of ``dct``."""
    values = six.itervalues(dct)
    return next(iter(values))
@attr.s
class RepoInfo(object):
    """Version-control details of a working copy (built by get_repo_info)."""
    type = attr.ib(type=str)     # 'git' or 'hg'
    url = attr.ib(type=str)      # remote origin URL
    branch = attr.ib(type=str)   # current branch name
    commit = attr.ib(type=str)   # current revision hash
    root = attr.ib(type=str)     # repository root directory
def get_repo_info(repo_type, path):
    """
    Collect VCS information (remote url, branch, commit, root) for `path`.

    :param repo_type: 'git' or 'hg'
    :param path: a directory inside the working copy
    :return: RepoInfo
    :raises RuntimeError: on an unsupported repo_type
    :raises subprocess.CalledProcessError: when a VCS command fails
    """
    # Validate through the explicit raise in the else branch below: the
    # previous `assert repo_type in [...]` made that branch unreachable and
    # disappeared entirely under `python -O`.
    if repo_type == 'git':
        commands = dict(
            url='git remote get-url origin',
            branch='git rev-parse --abbrev-ref HEAD',
            commit='git rev-parse HEAD',
            root='git rev-parse --show-toplevel'
        )
    elif repo_type == 'hg':
        commands = dict(
            url='hg paths --verbose',
            branch='hg --debug id -b',
            commit='hg --debug id -i',
            root='hg root'
        )
    else:
        raise RuntimeError("Unknown repository type '{}'".format(repo_type))
    commands_result = {
        name: subprocess.check_output(command.split(), cwd=path).decode().strip()
        for name, command in commands.items()
    }
    return RepoInfo(type=repo_type, **commands_result)
def reverse_home_folder_expansion(path):
    """Replace a leading home-directory prefix with '~/' (no-op on Windows)."""
    path = str(path)
    if is_windows_platform():
        return path
    return re.sub('^{}/'.format(re.escape(str(Path.home()))), '~/', path)
def represent_ordered_dict(dumper, data):
    """
    Serializes ``OrderedDict`` to YAML by its proper order.
    Registering this function to ``yaml.SafeDumper`` enables using ``yaml.safe_dump`` with ``OrderedDict``s.
    """
    return dumper.represent_mapping(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items())


def construct_mapping(loader, node):
    """
    Deserialize YAML mappings as ``OrderedDict``s.
    """
    loader.flatten_mapping(node)
    return OrderedDict(loader.construct_pairs(node))


# Register the handlers above so safe_dump/safe_load round-trip OrderedDicts.
yaml.SafeDumper.add_representer(OrderedDict, represent_ordered_dict)
yaml.SafeLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)


class AllDumper(yaml.SafeDumper):
    """SafeDumper that falls back to str() for otherwise unserializable objects."""
    pass


AllDumper.add_multi_representer(object, lambda dumper, data: dumper.represent_str(str(data)))
def error(message):
    """Print a formatted trains_agent error message to stdout."""
    text = '\ntrains_agent: ERROR: {}\n'.format(message)
    print(text)


def warning(message):
    """Print a formatted trains_agent warning message to stdout."""
    text = 'trains_agent: Warning: {}'.format(message)
    print(text)
class TqdmStream(object):
    """File-like adapter that strips tqdm's whitespace so bars log cleanly."""

    def __init__(self, file_object):
        self.buffer = file_object

    def write(self, data):
        # Drop the carriage returns / padding tqdm emits.
        self.buffer.write(data.strip())

    def flush(self):
        self.buffer.write('\n')


class TqdmLog(tqdm):
    """tqdm progress bar writing through TqdmStream (default target: stderr)."""

    def __init__(self, iterable=None, file=None, **kwargs):
        super(TqdmLog, self).__init__(iterable, file=TqdmStream(file or sys.stderr), **kwargs)
def url_join(first, *rest):
    """
    Join url parts similarly to Path.join
    """
    # Only the path component of `first` is kept (scheme/host are dropped
    # by taking furl(first).path) and the leading '/' is stripped.
    return str(furl.furl(first).path.add(rest)).lstrip('/')
class LowercaseFormatter(logging.Formatter):
    """logging.Formatter variant that lowercases the record's level name."""

    def format(self, record, *args, **kwargs):
        # Mutate the record's level name before delegating to the base class.
        record.levelname = record.levelname.lower()
        formatted = super(LowercaseFormatter, self).format(record, *args, **kwargs)
        return formatted
def mkstemp(
    open_kwargs=None,  # type: Optional[Dict[Text, Any]]
    text=True,  # type: bool
    name_only=False,  # type: bool
    *args,
    **kwargs):
    # type: (...) -> Union[(IO[AnyStr], Text), Text]
    """
    Create a temporary file via ``tempfile.mkstemp`` and open it.

    WARNING: the returned file object is strict about its input type,
    make sure to feed it binary/text input in correspondence to the ``text`` argument

    :param open_kwargs: keyword arguments for ``io.open``
    :param text: open in text mode
    :param name_only: close the file and return its name
    :param args: tempfile.mkstemp args
    :param kwargs: tempfile.mkstemp kwargs
    :return: an (open file, path) pair, or just the path when ``name_only``
    """
    fd, name = tempfile.mkstemp(text=text, *args, **kwargs)
    if name_only:
        # Caller only wants the path; release the descriptor immediately.
        os.close(fd)
        return name
    mode = 'w+' if text else 'w+b'
    return io.open(fd, mode, **open_kwargs or {}), name
def named_temporary_file(*args, **kwargs):
    """
    Create a NamedTemporaryFile, translating the python 3 ``buffering``
    keyword to python 2's ``bufsize`` when running under Python 2.
    """
    if six.PY2:
        buffering = kwargs.pop('buffering', None)
        if buffering:
            kwargs['bufsize'] = buffering
    return tempfile.NamedTemporaryFile(*args, **kwargs)
def parse_override(string):
    """Parse a HOCON override snippet (e.g. 'a.b=1') into a plain ordered dict."""
    return pyhocon.ConfigFactory.parse_string(string).as_plain_ordered_dict()
def chain_map(*args):
    """Merge mappings left-to-right into a new dict (later keys win)."""
    merged = {}
    for mapping in args:
        merged.update(mapping)
    return merged
def check_directory_path(path):
    """
    Ensure `path` is usable as a working directory: contains no whitespace
    (on POSIX platforms) and can be created (including parents).

    :raises CommandFailedError: when the path contains whitespace or the
        directory cannot be created
    """
    message = 'Could not create directory "{}": {}'
    if not is_windows_platform():
        match = re.search(r'\s', path)
        if match:
            # Report the offending character's position with match.start();
            # the previous match.endpos was just the end of the searched
            # string (i.e. len(path)), not the whitespace position.
            raise CommandFailedError(
                'directories may not contain whitespace (char: {!r}, position: {})'.format(match.group(0),
                                                                                           match.start()))
    try:
        Path(os.path.expandvars(path)).expanduser().mkdir(parents=True, exist_ok=True)
    except OSError as e:
        raise CommandFailedError(message.format(path, e.strerror))
    except Exception as e:
        raise CommandFailedError(message.format(path, e))
def create_file_if_not_exists(path):
    """Create an empty file at `path` (with ~ and $VARS expanded) if missing."""
    expanded = os.path.expanduser(os.path.expandvars(path))
    if not os.path.exists(expanded):
        # Open the *expanded* path: previously the raw path was opened, so a
        # path containing '~' or '$VAR' was created (or failed) literally.
        open(expanded, "w").close()
def rm_tree(root):  # type: (Union[Path, Text]) -> None
    """
    A version of shutil.rmtree that handles access errors, specifically hidden files on Windows
    """
    def on_error(func, path, _):
        # Best effort: make read-only entries writable, then retry once.
        try:
            if os.path.exists(path) and not os.access(path, os.W_OK):
                os.chmod(path, stat.S_IWUSR)
            func(path)
        except Exception:
            pass

    expanded = os.path.expanduser(os.path.expandvars(Text(root)))
    return shutil.rmtree(expanded, onerror=on_error)
def is_conda(config):
    """True when the configured package manager type is conda."""
    manager_type = config['agent.package_manager.type']
    return manager_type.lower() == 'conda'
class NonStrictAttrs(object):
    """Mixin for attrs classes: construct from a dict, ignoring unknown keys."""

    @classmethod
    def from_dict(cls, kwargs):
        fields = fields_dict(cls)
        # Drop keys that are not declared attrs fields before instantiating.
        return cls(**filter_keys(lambda key: key in fields, kwargs))
def python_version_string():
    """Return the running interpreter's 'major.minor' version string."""
    version = sys.version_info
    return '{}.{}'.format(version.major, version.minor)


# Join an iterable of strings with newlines.
join_lines = '\n'.join
class HOCONEncoder(json.JSONEncoder):
    """
    pyhocon bugs:
    1. "\\t" is dumped as "\t" instead of "\\t", which is read as the character "\t".
    2. parsed config trees have dummy `pyhocon.config_tree.NoneValue` in them.
    (see: https://github.com/chimpler/pyhocon/issues/111)
    Workaround: dump HOCON to JSON, of which it is a subset, taking care of `NoneValue`s.
    """

    def default(self, o):
        """
        If o is `pyhocon.config_tree.NoneValue`, encode it the same way as `None`.
        """
        if isinstance(o, pyhocon.config_tree.NoneValue):
            return super(HOCONEncoder, self).encode(None)
        return super(HOCONEncoder, self).default(o)
# attrs field templates: a stripped string and a normalized filesystem path.
nullable_string = attr.ib(default="", converter=lambda x: x.strip())
normal_path = attr.ib(default="", converter=lambda p: p and normalize_path(p))


@attr.s
class ExecutionInfo(NonStrictAttrs):
    """Script/repository execution details extracted from a task."""
    repository = nullable_string
    entry_point = normal_path
    working_dir = normal_path
    branch = nullable_string
    version_num = nullable_string
    tag = nullable_string

    @classmethod
    def from_task(cls, task_info):
        # type: (...) -> ExecutionInfo
        """
        extract ExecutionInfo tuple from task parameters
        """
        if not task_info.script:
            raise CommandFailedError("can not run task without script information")
        execution = cls.from_dict(task_info.script.to_dict())
        if not execution.entry_point:
            log.warning("notice: `script.entry_point` is empty")
        if not execution.working_dir:
            # Fall back to an "entry_point:working_dir" encoded entry point.
            # NOTE(review): partition(":") would also split a Windows drive
            # path like "C:\..." -- confirm entry points are always relative.
            entry_point, _, working_dir = execution.entry_point.partition(":")
            execution.entry_point = entry_point
            execution.working_dir = working_dir or ""
        return execution

View File

@@ -0,0 +1,60 @@
import os
from time import sleep
import requests
import json
from threading import Thread
from semantic_version import Version
from ..version import __version__
__check_update_thread = None
def start_check_update_daemon():
    """Start the background version-check thread (at most once per process)."""
    global __check_update_thread
    if __check_update_thread:
        # already running
        return
    __check_update_thread = Thread(target=_check_update_daemon)
    # daemon thread: do not block interpreter shutdown
    __check_update_thread.daemon = True
    __check_update_thread.start()
def _check_new_version_available():
    """
    Query the updates server for a newer trains-agent release.

    :return: None when up to date or on any server problem, otherwise a
        (latest_version_string, is_patch_upgrade, release_notes_lines) tuple
    """
    cur_version = __version__
    update_server_releases = requests.get(
        'https://updates.trainsai.io/updates',
        data=json.dumps({"versions": {"trains-agent": str(cur_version)}}),
        timeout=3.0)
    if not update_server_releases.ok:
        return None
    update_server_releases = update_server_releases.json()
    trains_answer = update_server_releases.get("trains-agent", {})
    latest_version = trains_answer.get("version")
    if not latest_version:
        # Malformed/empty server answer: previously crashed in Version(None).
        return None
    cur_version = Version(cur_version)
    latest_version = Version(latest_version)
    if cur_version >= latest_version:
        return None
    # Patch upgrade: same major.minor, only the patch number advanced.
    patch_upgrade = latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
    # Guard against a missing description (previously AttributeError on None).
    description = (trains_answer.get("description") or "").split("\r\n")
    return str(latest_version), patch_upgrade, description
def _check_update_daemon():
    """Daemon loop: check for a newer version once a day and print a notice."""
    counter = 0
    while True:
        # noinspection PyBroadException
        try:
            latest_version = _check_new_version_available()
            # only print when we begin
            # NOTE(review): despite the comment above, this prints on every
            # daily iteration while an upgrade is available -- confirm intent.
            if latest_version:
                if latest_version[1]:
                    # patch upgrade: include the release notes
                    sep = os.linesep
                    print('TRAINS-AGENT new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format(
                        latest_version[0], sep.join(latest_version[2])))
                else:
                    print('TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
                        latest_version[0]))
        except Exception:
            pass
        # sleep until the next day
        sleep(60 * 60 * 24)
        counter += 1  # NOTE(review): counter is never read

View File

@@ -0,0 +1,100 @@
from __future__ import unicode_literals, print_function
import csv
import sys
from collections import Iterable
from typing import List, Dict, Text, Any
from attr import attrs, attrib
import six
from six import binary_type, text_type
from trains_agent.helper.base import nonstrict_in_place_sort, create_tree
def print_text(text, newline=True):
    """Write `text` to stdout, encoding with 'replace' to avoid UnicodeError."""
    if newline:
        text += '\n'
    data = text.encode(sys.stdout.encoding or 'utf8', errors='replace')
    try:
        # python 3: write bytes to the underlying binary buffer
        sys.stdout.buffer.write(data)
    except AttributeError:
        # python 2 (or a wrapped stdout without .buffer)
        sys.stdout.write(data)
def ensure_text(s, encoding='utf-8', errors='strict'):
    """Coerce *s* to six.text_type.
    For Python 2:
      - `unicode` -> `unicode`
      - `str` -> `unicode`
    For Python 3:
      - `str` -> `str`
      - `bytes` -> decoded to `str`
    """
    # (same contract as six.ensure_text -- inlined here)
    if isinstance(s, binary_type):
        return s.decode(encoding, errors)
    elif isinstance(s, text_type):
        return s
    else:
        raise TypeError("not expecting type '%s'" % type(s))


def ensure_binary(s, encoding='utf-8', errors='strict'):
    """Coerce **s** to six.binary_type.
    For Python 2:
      - `unicode` -> encoded to `str`
      - `str` -> `str`
    For Python 3:
      - `str` -> encoded to `bytes`
      - `bytes` -> `bytes`
    """
    # (same contract as six.ensure_binary -- inlined here)
    if isinstance(s, text_type):
        return s.encode(encoding, errors)
    elif isinstance(s, binary_type):
        return s
    else:
        raise TypeError("not expecting type '%s'" % type(s))
class ListFormatter(object):
    """Helpers for displaying and serializing lists of entity dicts."""

    @attrs(init=False)
    class Table(object):
        entries = attrib(type=List[Dict])
        columns = attrib(type=List[Text])

        def __init__(self, entries, columns):
            self.entries = entries
            # A '#'-separated string is accepted as a column list.
            if isinstance(columns, str):
                columns = columns.split('#')
            self.columns = columns

        def as_rows(self):  # type: () -> Iterable[Iterable[Any]]
            # Yield each entry's values in column order (None when missing).
            return (
                map(entry.get, self.columns)
                for entry in self.entries
            )

    def __init__(self, service_name):
        # service_name: entity name used in the "Total ..." summary line
        self.service_name = service_name

    def get_total(self, entries):
        """Return a summary line with the entity name and entry count."""
        return '\nTotal {} {}'.format(self.service_name, len(entries))

    @classmethod
    def write_csv(cls, entries, columns, dest, headers=True):
        """Write `entries` to CSV file `dest`, restricted to `columns`."""
        table = cls.Table(entries, columns)
        with open(dest, 'w') as output:
            writer = csv.DictWriter(output, fieldnames=table.columns, extrasaction='ignore')
            if headers:
                writer.writeheader()
            writer.writerows(table.entries)

    @staticmethod
    def sort_in_place(entries, key, reverse=None):
        """Sort `entries` in place by a '#'-separated key string or a callable."""
        # list.sort() rejects reverse=None on python 3 -- coerce the default.
        reverse = bool(reverse)
        if isinstance(key, six.string_types):
            nonstrict_in_place_sort(entries, reverse, *key.split('#'))
        elif callable(key):
            entries.sort(key=key, reverse=reverse)
        else:
            raise ValueError('"sort" argument must be either a string or a callable object')

View File

@@ -0,0 +1,5 @@
from typing import Callable, Dict, Any
def filter_keys(filter_, dct):  # type: (Callable[[Any], bool], Dict) -> Dict
    """Return a copy of `dct` containing only the keys accepted by `filter_`."""
    result = {}
    for key, value in dct.items():
        if filter_(key):
            result[key] = value
    return result

View File

View File

@@ -0,0 +1,385 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Implementation of gpustat
@author Jongwook Choi
@url https://github.com/wookayin/gpustat
@ copied from gpu-stat 0.6.0
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import json
import os.path
import platform
import sys
import time
from datetime import datetime
import psutil
from ..gpu import pynvml as N
NOT_SUPPORTED = 'Not Supported'
MB = 1024 * 1024
class GPUStat(object):
    """A single GPU's status snapshot: a dict wrapper with typed accessors."""

    def __init__(self, entry):
        if not isinstance(entry, dict):
            raise TypeError(
                'entry should be a dict, {} given'.format(type(entry))
            )
        self.entry = entry

    def keys(self):
        return self.entry.keys()

    def __getitem__(self, key):
        return self.entry[key]

    @property
    def index(self):
        """
        Returns the index of GPU (as in nvidia-smi).
        """
        return self.entry['index']

    @property
    def uuid(self):
        """
        Returns the uuid returned by nvidia-smi,
        e.g. GPU-12345678-abcd-abcd-uuid-123456abcdef
        """
        return self.entry['uuid']

    @property
    def name(self):
        """
        Returns the name of GPU card (e.g. Geforce Titan X)
        """
        return self.entry['name']

    @property
    def memory_total(self):
        """
        Returns the total memory (in MB) as an integer.
        """
        return int(self.entry['memory.total'])

    @property
    def memory_used(self):
        """
        Returns the occupied memory (in MB) as an integer.
        """
        return int(self.entry['memory.used'])

    @property
    def memory_free(self):
        """
        Returns the free (available) memory (in MB) as an integer.
        """
        v = self.memory_total - self.memory_used
        # clamp at zero in case used > total is ever reported
        return max(v, 0)

    @property
    def memory_available(self):
        """
        Returns the available memory (in MB) as an integer.
        Alias of memory_free.
        """
        return self.memory_free

    @property
    def temperature(self):
        """
        Returns the temperature (in celcius) of GPU as an integer,
        or None if the information is not available.
        """
        v = self.entry['temperature.gpu']
        return int(v) if v is not None else None

    @property
    def fan_speed(self):
        """
        Returns the fan speed percentage (0-100) of maximum intended speed
        as an integer, or None if the information is not available.
        """
        v = self.entry['fan.speed']
        return int(v) if v is not None else None

    @property
    def utilization(self):
        """
        Returns the GPU utilization (in percentile),
        or None if the information is not available.
        """
        v = self.entry['utilization.gpu']
        return int(v) if v is not None else None

    @property
    def power_draw(self):
        """
        Returns the GPU power usage in Watts,
        or None if the information is not available.
        """
        v = self.entry['power.draw']
        return int(v) if v is not None else None

    @property
    def power_limit(self):
        """
        Returns the (enforced) GPU power limit in Watts,
        or None if the information is not available.
        """
        v = self.entry['enforced.power.limit']
        return int(v) if v is not None else None

    @property
    def processes(self):
        """
        Get the list of running processes on the GPU.
        """
        return self.entry['processes']

    def jsonify(self):
        """Return a JSON-serializable dict, dropping each process's gpu_uuid."""
        o = dict(self.entry)
        if self.entry['processes'] is not None:
            o['processes'] = [{k: v for (k, v) in p.items() if k != 'gpu_uuid'}
                              for p in self.entry['processes']]
        else:
            o['processes'] = '({})'.format(NOT_SUPPORTED)
        return o
class GPUStatCollection(object):
    """Snapshot of all local GPUs, queried through NVML (pynvml)."""

    # pid -> psutil.Process cache, shared across queries
    global_processes = {}
    # process-wide NVML initialization flag
    _initialized = False
    _device_count = None
    # index -> (name, uuid) cache; queried once per device
    _gpu_device_info = {}

    def __init__(self, gpu_list, driver_version=None):
        self.gpus = gpu_list
        # attach additional system information
        self.hostname = platform.node()
        self.query_time = datetime.now()
        self.driver_version = driver_version

    @staticmethod
    def clean_processes():
        # Drop cached psutil handles of processes that have exited.
        for pid in list(GPUStatCollection.global_processes.keys()):
            if not psutil.pid_exists(pid):
                del GPUStatCollection.global_processes[pid]

    @staticmethod
    def new_query(shutdown=False, per_process_stats=False, get_driver_info=False):
        """Query the information of all the GPUs on local machine"""

        if not GPUStatCollection._initialized:
            N.nvmlInit()
            GPUStatCollection._initialized = True

        def _decode(b):
            if isinstance(b, bytes):
                return b.decode()  # for python3, to unicode
            return b

        def get_gpu_info(index, handle):
            """Get one GPU information specified by nvml handle"""

            def get_process_info(nv_process):
                """Get the process information of specific pid"""
                process = {}
                if nv_process.pid not in GPUStatCollection.global_processes:
                    GPUStatCollection.global_processes[nv_process.pid] = \
                        psutil.Process(pid=nv_process.pid)
                ps_process = GPUStatCollection.global_processes[nv_process.pid]
                process['username'] = ps_process.username()
                # cmdline returns full path;
                # as in `ps -o comm`, get short cmdnames.
                _cmdline = ps_process.cmdline()
                if not _cmdline:
                    # sometimes, zombie or unknown (e.g. [kworker/8:2H])
                    process['command'] = '?'
                    process['full_command'] = ['?']
                else:
                    process['command'] = os.path.basename(_cmdline[0])
                    process['full_command'] = _cmdline
                # Bytes to MBytes
                process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
                process['cpu_percent'] = ps_process.cpu_percent()
                process['cpu_memory_usage'] = \
                    round((ps_process.memory_percent() / 100.0) *
                          psutil.virtual_memory().total)
                process['pid'] = nv_process.pid
                return process

            # Device name/uuid never change; query them once per index.
            if not GPUStatCollection._gpu_device_info.get(index):
                name = _decode(N.nvmlDeviceGetName(handle))
                uuid = _decode(N.nvmlDeviceGetUUID(handle))
                GPUStatCollection._gpu_device_info[index] = (name, uuid)
            name, uuid = GPUStatCollection._gpu_device_info[index]

            # Each metric is optional: NVML raises NVMLError when a sensor
            # is unsupported on this device, which maps to None below.
            try:
                temperature = N.nvmlDeviceGetTemperature(
                    handle, N.NVML_TEMPERATURE_GPU
                )
            except N.NVMLError:
                temperature = None  # Not supported
            try:
                fan_speed = N.nvmlDeviceGetFanSpeed(handle)
            except N.NVMLError:
                fan_speed = None  # Not supported
            try:
                memory = N.nvmlDeviceGetMemoryInfo(handle)  # in Bytes
            except N.NVMLError:
                memory = None  # Not supported
            try:
                utilization = N.nvmlDeviceGetUtilizationRates(handle)
            except N.NVMLError:
                utilization = None  # Not supported
            try:
                power = N.nvmlDeviceGetPowerUsage(handle)
            except N.NVMLError:
                power = None
            try:
                power_limit = N.nvmlDeviceGetEnforcedPowerLimit(handle)
            except N.NVMLError:
                power_limit = None
            try:
                nv_comp_processes = \
                    N.nvmlDeviceGetComputeRunningProcesses(handle)
            except N.NVMLError:
                nv_comp_processes = None  # Not supported
            try:
                nv_graphics_processes = \
                    N.nvmlDeviceGetGraphicsRunningProcesses(handle)
            except N.NVMLError:
                nv_graphics_processes = None  # Not supported

            if not per_process_stats or (nv_comp_processes is None and nv_graphics_processes is None):
                processes = None
            else:
                processes = []
                nv_comp_processes = nv_comp_processes or []
                nv_graphics_processes = nv_graphics_processes or []
                for nv_process in nv_comp_processes + nv_graphics_processes:
                    try:
                        process = get_process_info(nv_process)
                        processes.append(process)
                    except psutil.NoSuchProcess:
                        # TODO: add some reminder for NVML broken context
                        # e.g. nvidia-smi reset or reboot the system
                        pass

                # TODO: Do not block if full process info is not requested
                # cpu_percent() needs a second sample after a short interval
                time.sleep(0.1)
                for process in processes:
                    pid = process['pid']
                    cache_process = GPUStatCollection.global_processes[pid]
                    process['cpu_percent'] = cache_process.cpu_percent()

            index = N.nvmlDeviceGetIndex(handle)
            gpu_info = {
                'index': index,
                'uuid': uuid,
                'name': name,
                'temperature.gpu': temperature,
                'fan.speed': fan_speed,
                'utilization.gpu': utilization.gpu if utilization else None,
                'power.draw': power // 1000 if power is not None else None,
                'enforced.power.limit': power_limit // 1000
                if power_limit is not None else None,
                # Convert bytes into MBytes
                'memory.used': memory.used // MB if memory else None,
                'memory.total': memory.total // MB if memory else None,
                'processes': processes,
            }
            if per_process_stats:
                GPUStatCollection.clean_processes()
            return gpu_info

        # 1. get the list of gpu and status
        gpu_list = []
        if GPUStatCollection._device_count is None:
            GPUStatCollection._device_count = N.nvmlDeviceGetCount()
        for index in range(GPUStatCollection._device_count):
            handle = N.nvmlDeviceGetHandleByIndex(index)
            gpu_info = get_gpu_info(index, handle)
            gpu_stat = GPUStat(gpu_info)
            gpu_list.append(gpu_stat)

        # 2. additional info (driver version, etc).
        if get_driver_info:
            try:
                driver_version = _decode(N.nvmlSystemGetDriverVersion())
            except N.NVMLError:
                driver_version = None  # N/A
        else:
            driver_version = None

        # no need to shutdown:
        if shutdown:
            N.nvmlShutdown()
            GPUStatCollection._initialized = False

        return GPUStatCollection(gpu_list, driver_version=driver_version)

    def __len__(self):
        return len(self.gpus)

    def __iter__(self):
        return iter(self.gpus)

    def __getitem__(self, index):
        return self.gpus[index]

    def __repr__(self):
        s = 'GPUStatCollection(host=%s, [\n' % self.hostname
        s += '\n'.join(' ' + str(g) for g in self.gpus)
        s += '\n])'
        return s

    # --- Printing Functions ---

    def jsonify(self):
        # query_time stays a datetime here; print_json converts via isoformat
        return {
            'hostname': self.hostname,
            'query_time': self.query_time,
            "gpus": [g.jsonify() for g in self]
        }

    def print_json(self, fp=sys.stdout):
        def date_handler(obj):
            if hasattr(obj, 'isoformat'):
                return obj.isoformat()
            else:
                raise TypeError(type(obj))

        o = self.jsonify()
        json.dump(o, fp, indent=4, separators=(',', ': '),
                  default=date_handler)
        fp.write('\n')
        fp.flush()
def new_query(shutdown=False, per_process_stats=False, get_driver_info=False):
    '''
    Obtain a new GPUStatCollection instance by querying nvidia-smi
    to get the list of GPUs and running process information.
    '''
    # Module-level convenience wrapper around the classmethod-style query.
    return GPUStatCollection.new_query(shutdown=shutdown, per_process_stats=per_process_stats,
                                       get_driver_info=get_driver_info)

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,107 @@
from __future__ import unicode_literals
import abc
from contextlib import contextmanager
from typing import Text, Iterable, Union
import six
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines
from trains_agent.helper.process import Executable, Argv, PathLike
@six.add_metaclass(abc.ABCMeta)
class PackageManager(object):
    """
    ABC for classes providing python package management interface
    """

    # The manager instance chosen for the current session; see
    # set_selected_package_manager / out_of_scope_install_package.
    _selected_manager = None

    @abc.abstractproperty
    def bin(self):
        # type: () -> PathLike
        # Path to the manager's python executable.
        pass

    @abc.abstractmethod
    def create(self):
        # Create the managed (virtual) environment.
        pass

    @abc.abstractmethod
    def remove(self):
        # Remove the managed environment.
        pass

    @abc.abstractmethod
    def install_from_file(self, path):
        # Install packages listed in a requirements file at `path`.
        pass

    @abc.abstractmethod
    def freeze(self):
        # Return the currently installed package set.
        pass

    @abc.abstractmethod
    def load_requirements(self, requirements):
        pass

    @abc.abstractmethod
    def install_packages(self, *packages):
        # type: (Iterable[Text]) -> None
        """
        Install packages, upgrading depends on config
        """
        pass

    @abc.abstractmethod
    def _install(self, *packages):
        # type: (Iterable[Text]) -> None
        """
        Run install command
        """
        pass

    @abc.abstractmethod
    def uninstall_packages(self, *packages):
        # type: (Iterable[Text]) -> None
        pass

    def upgrade_pip(self):
        return self._install("pip", "--upgrade")

    def get_python_command(self, extra=()):
        # type: (...) -> Executable
        return Argv(self.bin, *extra)

    @contextmanager
    def temp_file(self, prefix, contents, suffix=".txt"):
        # type: (Union[Text, Iterable[Text]], Iterable[Text], Text) -> Text
        """
        Write contents to a temporary file, yielding its path. Finally, delete it.
        :param prefix: file name prefix
        :param contents: text lines to write
        :param suffix: file name suffix
        """
        f, temp_path = mkstemp(suffix=suffix, prefix=prefix)
        with f:
            f.write(
                contents
                if isinstance(contents, six.text_type)
                else join_lines(contents)
            )
        try:
            yield temp_path
        finally:
            # NOTE(review): relies on a `session` attribute that is not
            # declared on this ABC -- presumably provided by concrete
            # subclasses; confirm before refactoring.
            if not self.session.debug_mode:
                safe_remove_file(temp_path)

    def set_selected_package_manager(self):
        # set this instance as the selected package manager
        # this is helpful when we want out of context requirement installations
        PackageManager._selected_manager = self

    @classmethod
    def out_of_scope_install_package(cls, package_name):
        # Best-effort install through the previously selected manager;
        # failures are deliberately swallowed.
        if PackageManager._selected_manager is not None:
            try:
                return PackageManager._selected_manager._install(package_name)
            except Exception:
                pass
        return

View File

@@ -0,0 +1,365 @@
from __future__ import unicode_literals
import json
import re
import subprocess
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
import yaml
from attr import attrs, attrib, Factory
from pathlib2 import Path
from semantic_version import Version
from requirements import parse
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.session import Session
from .base import PackageManager
from .pip_api.venv import VirtualenvPip
from .requirements import RequirementsManager, MarkerRequirement
# strip conda's [version='...'] decoration, keeping only the inner version spec
_VERSION_DECORATION = re.compile(r"""\[version=['"](.*)['"]\]""")
package_normalize = partial(_VERSION_DECORATION.sub, r"\1")


def package_set(packages):
    """Normalize every package spec in *packages* and return them as a set."""
    return {package_normalize(package) for package in packages}
def _package_diff(path, packages):
    # type: (Union[Path, Text], Iterable[Text]) -> Set[Text]
    """Return normalized package specs listed in the file at *path*, minus *packages*."""
    return package_set(Path(path).read_text().splitlines()) - package_set(packages)
class CondaPip(VirtualenvPip):
    """
    pip interface for a conda environment: when an activation ``source`` is
    available, pip commands are executed after activating the environment.
    """

    def __init__(self, source=None, *args, **kwargs):
        # source: command sequence that activates the conda env (may be None)
        super(CondaPip, self).__init__(*args, **kwargs)
        self.source = source

    def run_with_env(self, command, output=False, **kwargs):
        """Run a pip command, activating the conda environment first if we have a source."""
        if not self.source:
            return super(CondaPip, self).run_with_env(command, output=output, **kwargs)
        command = CommandSequence(self.source, Argv("pip", *command))
        return (command.get_output if output else command.check_call)(
            stdin=DEVNULL, **kwargs
        )
class CondaAPI(PackageManager):
    """
    A programmatic interface for controlling conda
    """

    MINIMUM_VERSION = Version("4.3.30", partial=True)

    def __init__(self, session, path, python, requirements_manager):
        # type: (Session, PathLike, float, RequirementsManager) -> None
        """
        :param session: agent session (configuration, logging)
        :param path: path of env
        :param python: base python version to use (e.g python3.6)
        :param requirements_manager: requirement substitution manager
        :raises CommandFailedError: if conda is missing or too old
        """
        self.session = session
        self.python = python
        self.source = None
        self.requirements_manager = requirements_manager
        self.path = path
        self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
        # pip proxy bound to the same environment; its activation source is filled by create()
        self.pip = CondaPip(
            session=self.session,
            source=self.source,
            python=self.python,
            requirements_manager=self.requirements_manager,
            path=self.path,
        )
        # locate the conda executable, falling back to the shell's which/where
        self.conda = (
            find_executable("conda")
            or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
        )
        try:
            output = Argv(self.conda, "--version").get_output()
        except subprocess.CalledProcessError as ex:
            raise CommandFailedError(
                "Unable to determine conda version: {ex}, output={ex.output}".format(
                    ex=ex
                )
            )
        self.conda_version = self.get_conda_version(output)
        if Version(self.conda_version, partial=True) < self.MINIMUM_VERSION:
            raise CommandFailedError(
                "conda version '{}' is smaller than minimum supported conda version '{}'".format(
                    self.conda_version, self.MINIMUM_VERSION
                )
            )

    @staticmethod
    def get_conda_version(output):
        """Extract a dotted version number from ``conda --version`` output."""
        match = re.search(r"(\d+\.){0,2}\d+", output)
        if not match:
            raise CommandFailedError("Unidentified conda version string:", output)
        return match.group(0)

    @property
    def bin(self):
        # the pip executable inside the conda environment
        return self.pip.bin

    def upgrade_pip(self):
        return self.pip.upgrade_pip()

    def create(self):
        """
        Create a new environment
        """
        output = Argv(
            self.conda,
            "create",
            "--yes",
            "--mkdir",
            "--prefix",
            self.path,
            "python={}".format(self.python),
        ).get_output(stderr=DEVNULL)
        # conda prints the activation command for the new env; reuse it as our "source"
        match = re.search(
            r"\W*(.*activate) ({})".format(re.escape(str(self.path))), output
        )
        self.source = self.pip.source = (
            tuple(match.group(1).split()) + (match.group(2),)
            if match
            else ("activate", self.path)
        )
        conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
        if conda_env.is_file():
            self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
        # install cuda toolkit, best effort (config value e.g. 100 means cuda 10.0)
        try:
            cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
            if cuda_version > 0:
                self._install('cudatoolkit={:.1f}'.format(cuda_version))
        except Exception:
            pass
        return self

    def remove(self):
        """
        Delete a conda environment.
        Use 'conda env remove', then 'rm_tree' to be safe.

        Conda seems to load "vcruntime140.dll" from all its environment on startup.
        This means environment have to be deleted using 'conda env remove'.
        If necessary, conda can be fooled into deleting a partially-deleted environment by creating an empty file
        in '<ENV>\\conda-meta\\history' (value found in 'conda.gateways.disk.test.PREFIX_MAGIC_FILE').
        Otherwise, it complains that said directory is not a conda environment.

        See: https://github.com/conda/conda/issues/7682
        """
        try:
            self._run_command(("env", "remove", "-p", self.path))
        except Exception:
            pass
        rm_tree(self.path)

    def _install_from_file(self, path):
        """
        Install packages from requirement file.
        """
        self._install("--file", path)

    def _install(self, *args):
        # type: (*PathLike) -> ()
        """Run ``conda install`` in this environment with the configured extra channels."""
        channels_args = tuple(
            chain.from_iterable(("-c", channel) for channel in self.extra_channels)
        )
        self._run_command(("install", "-p", self.path) + channels_args + args)

    def _get_pip_packages(self, packages):
        # type: (Iterable[Text]) -> Sequence[Text]
        """
        Return subset of ``packages`` which are not available on conda
        """
        pips = []
        while True:
            with self.temp_file("conda_reqs", packages) as path:
                try:
                    self._install_from_file(path)
                except PackageNotFoundError as e:
                    # drop the unsolvable package and retry with the rest
                    pips.append(e.pkg)
                    packages = _package_diff(path, {e.pkg})
                else:
                    break
        return pips

    def install_packages(self, *packages):
        # type: (*Text) -> ()
        return self._install(*packages)

    def uninstall_packages(self, *packages):
        # BUGFIX: the package names were previously dropped, making this a no-op
        return self._run_command(("uninstall", "-p", self.path) + packages)

    def install_from_file(self, path):
        """
        Try to install packages from conda. Install packages which are not available from conda with pip.
        """
        try:
            self._install_from_file(path)
            return
        except PackageNotFoundError as e:
            pip_packages = [e.pkg]
        except PackagesNotFoundError as e:
            pip_packages = package_set(e.packages)
        # retry conda without the failed packages, then hand those over to pip
        with self.temp_file("conda_reqs", _package_diff(path, pip_packages)) as reqs:
            self.install_from_file(reqs)
        with self.temp_file("pip_reqs", pip_packages) as reqs:
            self.pip.install_from_file(reqs)

    def freeze(self):
        # NOTE: freezing through `conda env export` was abandoned (kept in git
        # history); the environment's pip freeze is reported instead.
        return self.pip.freeze()

    def load_requirements(self, requirements):
        """
        Install ``requirements['pip']`` through conda, iteratively moving packages
        conda cannot solve over to pip. Returns True when done.
        """
        # create new environment file
        conda_env = dict()
        conda_env['channels'] = self.extra_channels
        reqs = [MarkerRequirement(next(parse(r))) for r in requirements['pip']]
        pip_requirements = []
        while reqs:
            # conda pins versions with a single '=' rather than '=='
            conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs]
            with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as env_file:
                print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
                result = self._run_command(
                    ("env", "update", "-p", self.path, "--file", env_file)
                )
            # check if we need to remove specific packages
            bad_req = self._parse_conda_result_bad_packges(result)
            if not bad_req:
                break
            solved = False
            for bad_r in bad_req:
                name = bad_r.split('[')[0].split('=')[0]
                # look for name in requirements
                for r in reqs:
                    if r.name.lower() == name.lower():
                        pip_requirements.append(r)
                        reqs.remove(r)
                        solved = True
                        break
            # we couldn't remove even one package,
            # nothing we can do but try pip
            if not solved:
                pip_requirements.extend(reqs)
                break
        if pip_requirements:
            try:
                pip_req_str = [r.tostr() for r in pip_requirements]
                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
                self.pip.load_requirements('\n'.join(pip_req_str))
            except Exception as e:
                print(e)
                raise
        self.requirements_manager.post_install()
        return True

    def _parse_conda_result_bad_packges(self, result_dict):
        """Return the list of packages conda reported as unsolvable, or None."""
        if not result_dict:
            return None

        if 'bad_deps' in result_dict and result_dict['bad_deps']:
            return result_dict['bad_deps']

        if result_dict.get('error'):
            error_lines = result_dict['error'].split('\n')
            if error_lines[0].strip().lower().startswith("unsatisfiableerror:"):
                empty_lines = [i for i, l in enumerate(error_lines) if not l.strip()]
                if len(empty_lines) >= 2:
                    deps = error_lines[empty_lines[0] + 1:empty_lines[1]]
                    try:
                        # safe_load: never execute YAML tags coming from conda output
                        return yaml.safe_load('\n'.join(deps))
                    except Exception:
                        return None
        return None

    def _run_command(self, command, raw=False, **kwargs):
        # type: (Iterable[Text], bool, Any) -> Union[Dict, Text]
        """
        Run a conda command, returning JSON output.
        The command is prepended with 'conda' and run with JSON output flags.
        :param command: command to run
        :param raw: return text output and don't change command
        :param kwargs: kwargs for Argv.get_output()
        :return: JSON output or text output
        """
        command = Argv(*command)  # type: Executable
        if not raw:
            command = (self.conda,) + command + ("--quiet", "--json")
        try:
            print('Executing Conda: {}'.format(command.serialize()))
            result = command.get_output(stdin=DEVNULL, **kwargs)
        except Exception as e:
            if raw:
                raise
            # fall back to whatever (JSON) output the failed process produced
            result = e.output if hasattr(e, 'output') else ''
        if raw:
            return result

        result = json.loads(result) if result else {}
        if result.get('success', False):
            print('Pass')
        elif result.get('error'):
            print('Conda error: {}'.format(result.get('error')))
        return result

    def get_python_command(self, extra=()):
        return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
# enable hashing with cmp=False because pdb fails on unhashable exceptions
# (attrs class preset shared by the conda exception classes below)
exception = attrs(str=True, cmp=False)
@exception
class CondaException(Exception, NonStrictAttrs):
    """Base class for errors reported by the conda executable."""

    # conda command that failed
    command = attrib()
    # optional human-readable error message
    message = attrib(default=None)
@exception
class UnknownCondaError(CondaException):
    """Unrecognized conda failure; carries the raw result payload in ``data``."""

    data = attrib(default=Factory(dict))
@exception
class PackagesNotFoundError(CondaException):
    """
    Conda 4.5 exception - this reports all missing packages.
    """

    # iterable of package specs conda could not find
    packages = attrib(default=())
@exception
class PackageNotFoundError(CondaException):
    """
    Conda 4.3 exception - this reports one missing package at a time,
    as a singleton YAML list.
    """

    # the missing package spec, parsed out of conda's YAML output;
    # safe_load avoids executing arbitrary YAML tags from external output
    pkg = attrib(default="", converter=lambda val: yaml.safe_load(val)[0].replace(" ", ""))

View File

@@ -0,0 +1,25 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class CythonRequirement(SimpleSubstitution):
    """Substitution that installs Cython ahead of the regular requirement flow."""

    name = "cython"

    def __init__(self, *args, **kwargs):
        super(CythonRequirement, self).__init__(*args, **kwargs)

    def match(self, req):
        """Case-insensitive match, so both "Cython" and "cython" are caught."""
        return req.name.lower() == self.name

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # install Cython right away, out of the regular install flow,
        # and keep the requirement line itself unchanged
        PackageManager.out_of_scope_install_package(str(req))
        return Text(req)

View File

@@ -0,0 +1,32 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class HorovodRequirement(SimpleSubstitution):
    """Defers horovod installation to the post-install phase."""

    name = "horovod"

    def __init__(self, *args, **kwargs):
        super(HorovodRequirement, self).__init__(*args, **kwargs)
        # requirement postponed by replace(); consumed by post_install()
        self.post_install_req = None

    def match(self, req):
        """Match the horovod package, whatever the letter case."""
        return req.name.lower() == self.name

    def post_install(self):
        """Install the postponed requirement, if any, then clear it."""
        pending = self.post_install_req
        if pending:
            PackageManager.out_of_scope_install_package(pending.tostr(markers=False))
            self.post_install_req = None

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # remember the requirement for the post-install hook and return an
        # empty line so the package is skipped by the regular installer
        self.post_install_req = req
        return Text('')

View File

@@ -0,0 +1,92 @@
import sys
from itertools import chain
from typing import Text
from trains_agent.definitions import PIP_EXTRA_INDICES, PROGRAM_NAME
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, DEVNULL
class SystemPip(PackageManager):
    """pip interface running against the system (current) python interpreter."""

    # cached --extra-index-url flags, built lazily in install_flags()
    indices_args = None

    def __init__(self, interpreter=None):
        # type: (Text) -> ()
        """
        Program interface to the system pip.

        :param interpreter: python interpreter to run pip under (defaults to the running one)
        """
        self._bin = interpreter or sys.executable

    @property
    def bin(self):
        # python interpreter used to invoke `python -m pip`
        return self._bin

    def create(self):
        # system environment already exists - nothing to create
        pass

    def remove(self):
        # the system environment is never removed
        pass

    def install_from_file(self, path):
        """Install all requirements listed in the file at *path*."""
        self.run_with_env(('install', '-r', path) + self.install_flags())

    def install_packages(self, *packages):
        self._install(*(packages + self.install_flags()))

    def _install(self, *args):
        self.run_with_env(('install',) + args)

    def uninstall_packages(self, *packages):
        self.run_with_env(('uninstall', '-y') + packages)

    def download_package(self, package, cache_dir):
        """Download *package* (without its dependencies) into *cache_dir*."""
        self.run_with_env(
            (
                'download',
                package,
                '--dest', cache_dir,
                '--no-deps',
            ) + self.install_flags()
        )

    def load_requirements(self, requirements):
        """Install requirements given either as a dict with a 'pip' key or directly."""
        requirements = requirements.get('pip') if isinstance(requirements, dict) else requirements
        if not requirements:
            return
        with self.temp_file('cached-reqs', requirements) as path:
            self.install_from_file(path)

    def uninstall(self, package):
        """Uninstall a single package without prompting."""
        self.run_with_env(('uninstall', '-y', package))

    def freeze(self):
        """
        pip freeze to all install packages except the running program
        :return: Dict contains pip as key and pip's packages to install
        :rtype: Dict[str: List[str]]
        """
        packages = self.run_with_env(('freeze',), output=True).splitlines()
        packages_without_program = [package for package in packages if PROGRAM_NAME not in package]
        return {'pip': packages_without_program}

    def run_with_env(self, command, output=False, **kwargs):
        """
        Run a shell command using environment from a virtualenv script
        :param command: command to run
        :type command: Iterable[Text]
        :param output: return output
        :param kwargs: kwargs for get_output/check_output command
        """
        command = self._make_command(command)
        return (command.get_output if output else command.check_call)(stdin=DEVNULL, **kwargs)

    def _make_command(self, command):
        # always go through `python -m pip` so we hit this interpreter's pip
        return Argv(self.bin, '-m', 'pip', *command)

    def install_flags(self):
        # build the extra index url flags once and cache them
        if self.indices_args is None:
            self.indices_args = tuple(
                chain.from_iterable(('--extra-index-url', x) for x in PIP_EXTRA_INDICES)
            )
        return self.indices_args

View File

@@ -0,0 +1,77 @@
from pathlib2 import Path
from trains_agent.helper.base import select_for_platform, rm_tree
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session
from ..pip_api.system import SystemPip
from ..requirements import RequirementsManager
class VirtualenvPip(SystemPip, PackageManager):
    """pip interface bound to a virtualenv that this class can create/remove."""

    def __init__(self, session, python, requirements_manager, path, interpreter=None):
        # type: (Session, float, RequirementsManager, PathLike, PathLike) -> ()
        """
        Program interface to virtualenv pip.
        Must be given either path to virtualenv or source command.
        Either way, ``self.source`` is exposed.

        :param session: agent session
        :param python: python version (e.g. 3.6)
        :param requirements_manager: manager applying requirement substitutions
        :param path: path of virtual environment to create/manipulate
        :param interpreter: path of python interpreter
        """
        super(VirtualenvPip, self).__init__(
            interpreter
            or Path(
                path,
                select_for_platform(linux="bin/python", windows="scripts/python.exe"),
            )
        )
        self.session = session
        self.path = path
        self.requirements_manager = requirements_manager
        self.python = "python{}".format(python)

    def _make_command(self, command):
        # run pip through the session's command wrapper
        return self.session.command(self.bin, "-m", "pip", *command)

    def load_requirements(self, requirements):
        """Apply requirement substitutions, install, then run post-install hooks."""
        if isinstance(requirements, dict) and requirements.get("pip"):
            requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
        super(VirtualenvPip, self).load_requirements(requirements)
        self.requirements_manager.post_install()

    def create_flags(self):
        """
        Configurable environment creation arguments
        """
        return Argv.conditional_flag(
            self.session.config["agent.package_manager.system_site_packages"],
            "--system-site-packages",
        )

    def install_flags(self):
        """
        Configurable package installation creation arguments
        """
        return super(VirtualenvPip, self).install_flags() + Argv.conditional_flag(
            self.session.config["agent.package_manager.force_upgrade"], "--upgrade"
        )

    def create(self):
        """
        Create virtualenv.
        Only valid if instantiated with path.
        Use self.python as self.bin does not exist.
        """
        self.session.command(
            self.python, "-m", "virtualenv", self.path, *self.create_flags()
        ).check_call()
        return self

    def remove(self):
        """
        Delete virtualenv.
        Only valid if instantiated with path.
        """
        rm_tree(self.path)

View File

@@ -0,0 +1,98 @@
from functools import wraps
import attr
from pathlib2 import Path
from trains_agent.helper.process import Argv, DEVNULL
from trains_agent.session import Session, POETRY
def prop_guard(prop, log_prop=None):
    """
    Build a decorator that runs the wrapped method only when ``prop`` evaluates
    truthy on the instance, optionally logging the decision.

    :param prop: property acting as the guard condition
    :param log_prop: optional property returning a logger used for debug output
    """
    assert isinstance(prop, property)
    assert not log_prop or isinstance(log_prop, property)

    def decorator(func):
        # rendered once per decorated function; lazily %-formatted by the logger
        message = "%s:%s calling {}, {} = %s".format(func.__name__, prop.fget.__name__)

        @wraps(func)
        def guarded(self, *args, **kwargs):
            condition = prop.fget(self)
            if log_prop:
                logger = log_prop.fget(self)
                logger.debug(
                    message,
                    type(self).__name__,
                    "" if condition else " not",
                    condition,
                )
            if not condition:
                return None
            return func(self, *args, **kwargs)

        return guarded

    return decorator
class PoetryConfig:
    """Thin wrapper around the ``poetry`` CLI, gated on the configured package manager type."""

    def __init__(self, session):
        # type: (Session) -> ()
        self.session = session
        self._log = session.get_logger(__name__)

    @property
    def log(self):
        # session-scoped logger
        return self._log

    @property
    def enabled(self):
        # True when the agent is configured to use poetry as its package manager
        return self.session.config["agent.package_manager.type"] == POETRY

    # decorator: run the method only when poetry is the selected manager
    _guard_enabled = prop_guard(enabled, log)

    def run(self, *args, **kwargs):
        """Run ``poetry -n <args>``; the ``func`` kwarg selects the Argv method (default get_output)."""
        func = kwargs.pop("func", Argv.get_output)
        kwargs.setdefault("stdin", DEVNULL)
        argv = Argv("poetry", "-n", *args)
        self.log.debug("running: %s", argv)
        return func(argv, **kwargs)

    def _config(self, *args, **kwargs):
        # shortcut for `poetry config ...`
        return self.run("config", *args, **kwargs)

    @_guard_enabled
    def initialize(self):
        """Configure poetry to create its virtualenv inside the project (no-op unless enabled)."""
        self._config("settings.virtualenvs.in-project", "true")
        # self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
        # self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)

    def get_api(self, path):
        # type: (Path) -> PoetryAPI
        """Return a project-level poetry API rooted at *path*."""
        return PoetryAPI(self, path)
@attr.s
class PoetryAPI(object):
    """Project-level poetry operations for a repository checkout at ``path``."""

    config = attr.ib(type=PoetryConfig)
    path = attr.ib(type=Path, converter=Path)

    # files whose presence marks a poetry-managed project
    INDICATOR_FILES = "pyproject.toml", "poetry.lock"

    def install(self):
        # type: () -> bool
        """Run ``poetry install`` in the project; return True when it actually ran."""
        if self.enabled:
            self.config.run("install", cwd=str(self.path), func=Argv.check_call)
            return True
        return False

    @property
    def enabled(self):
        # poetry selected AND the project contains one of the indicator files
        return self.config.enabled and (
            any((self.path / indicator).exists() for indicator in self.INDICATOR_FILES)
        )

    def freeze(self):
        """Return installed packages as reported by ``poetry show``."""
        return {"poetry": self.config.run("show", cwd=str(self.path)).splitlines()}

    def get_python_command(self, extra):
        """Command line that runs python inside the poetry environment."""
        return Argv("poetry", "run", "python", *extra)

View File

@@ -0,0 +1,595 @@
from __future__ import unicode_literals
import re
import sys
from furl import furl
import urllib.parse
from operator import itemgetter
from html.parser import HTMLParser
from typing import Text
import attr
import requests
from semantic_version import Version, Spec
import six
from .requirements import SimpleSubstitution, FatalSpecsResolutionError
# platform name -> PyPI wheel platform tag
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}


def os_to_wheel_name(x):
    """Translate an OS name ("linux"/"windows") into its wheel platform tag."""
    wheel_name = OS_TO_WHEEL_NAME[x]
    return wheel_name
def fix_version(version):
    """
    Normalize a version string: any suffix after the (up to 3-part) numeric
    version is re-attached with a "-" instead of a "." (e.g. "1.0.0.dev1" ->
    "1.0.0-dev1").
    """
    pattern = re.compile(r"(\d+(?:\.\d+){,2})(?:\.(.*))?")

    def _swap_separator(match):
        nums, prerelease = match.groups()
        if not prerelease:
            return nums
        return "{}-{}".format(nums, prerelease)

    return pattern.sub(_swap_separator, version)
class LinksHTMLParser(HTMLParser):
    """Collect the non-blank text nodes of an HTML document into ``links``."""

    def __init__(self):
        super(LinksHTMLParser, self).__init__()
        self.links = []

    def handle_data(self, data):
        # keep only text chunks containing something besides whitespace
        if not data or not data.strip():
            return
        self.links.append(data)
@attr.s
class PytorchWheel(object):
    """
    Value object describing a single PyTorch wheel; make_url() renders the
    canonical download.pytorch.org URL for it.
    """

    # platform name ("linux"/"windows") converted to the wheel platform tag
    os_name = attr.ib(type=str, converter=os_to_wheel_name)
    # cuda version (e.g. 92) rendered as "cu92"; falsy values mean "cpu"
    cuda_version = attr.ib(converter=lambda x: "cu{}".format(x) if x else "cpu")
    # python version with the dot removed, e.g. "3.6" -> "36"
    python = attr.ib(type=str, converter=lambda x: str(x).replace(".", ""))
    torch_version = attr.ib(type=str, converter=fix_version)

    url_template = (
        "http://download.pytorch.org/whl/"
        "{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl"
    )

    def __attrs_post_init__(self):
        # python 2 wheels carry a "u" (wide-unicode) ABI suffix
        self.unicode = "u" if self.python.startswith("2") else ""

    def make_url(self):
        # type: () -> Text
        """Render the wheel download URL for this combination."""
        return self.url_template.format(self)
class PytorchResolutionError(FatalSpecsResolutionError):
    """Raised when no matching PyTorch wheel can be resolved for the current platform."""
    pass
class SimplePytorchRequirement(SimpleSubstitution):
    """
    Requirement substitution for the torch package family: strips local version
    suffixes (+cpu/+cuXX) and points pip at the matching torch_stable.html page.
    """

    name = "torch"
    packages = ("torch", "torchvision", "torchaudio")

    page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
    torch_page_lookup = {
        0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
        80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
        90: 'https://download.pytorch.org/whl/cu90/torch_stable.html',
        92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
        100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
        101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
    }

    def __init__(self, *args, **kwargs):
        super(SimplePytorchRequirement, self).__init__(*args, **kwargs)
        self._matched = False

    def match(self, req):
        # match any of our packages
        return req.name in self.packages

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Get rid of +cpu +cu?? etc.
        # BUGFIX: was a bare `except:` which also swallows SystemExit/KeyboardInterrupt
        try:
            req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
        except Exception:
            # no specs / unexpected shape - leave the requirement untouched
            pass
        self._matched = True
        return Text(req)

    def matching_done(self, reqs, package_manager):
        # type: (Sequence[MarkerRequirement], object) -> ()
        """Once matching finished, register the torch index page as a pip `-f` flag."""
        if not self._matched:
            return
        # TODO: add conda channel support
        from .pip_api.system import SystemPip
        if package_manager and isinstance(package_manager, SystemPip):
            extra_url, _ = self.get_torch_page(self.cuda_version)
            # NOTE(review): SystemPip.add_extra_install_flags is not defined in the
            # visible pip API - confirm it exists on the deployed SystemPip
            package_manager.add_extra_install_flags(('-f', extra_url))

    @classmethod
    def get_torch_page(cls, cuda_version):
        """
        Return (torch_stable page URL, cuda key) best matching *cuda_version*,
        probing unknown versions online and falling back to the closest lower
        known version, and finally to the CPU page.
        """
        try:
            cuda = int(cuda_version)
        except (TypeError, ValueError):
            cuda = 0
        # first check if key is valid
        if cuda in cls.torch_page_lookup:
            return cls.torch_page_lookup[cuda], cuda

        # then try a new cuda version page
        torch_url = cls.page_lookup_template.format(cuda)
        try:
            if requests.get(torch_url, timeout=10).ok:
                cls.torch_page_lookup[cuda] = torch_url
                return cls.torch_page_lookup[cuda], cuda
        except Exception:
            pass

        keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
        for k in keys:
            if k <= cuda:
                return cls.torch_page_lookup[k], k
        # return default - zero
        return cls.torch_page_lookup[0], 0
class PytorchRequirement(SimpleSubstitution):
name = "torch"
packages = ("torch", "torchvision", "torchaudio")
def __init__(self, *args, **kwargs):
os_name = kwargs.pop("os_override", None)
super(PytorchRequirement, self).__init__(*args, **kwargs)
self.log = self._session.get_logger(__name__)
self.package_manager = self.config["agent.package_manager.type"].lower()
self.os = os_name or self.get_platform()
self.cuda = "cuda{}".format(self.cuda_version).lower()
self.python_version_string = str(self.config["agent.default_python"])
self.python_semantic_version = Version.coerce(
self.python_version_string, partial=True
)
self.python = "python{}.{}".format(self.python_semantic_version.major, self.python_semantic_version.minor)
self.exceptions = [
PytorchResolutionError(message)
for message in (
None,
'cuda version "{}" is not supported'.format(self.cuda),
'python version "{}" is not supported'.format(
self.python_version_string
),
)
]
try:
self.validate_python_version()
except PytorchResolutionError as e:
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
@property
def is_conda(self):
return self.package_manager == "conda"
@property
def is_pip(self):
return not self.is_conda
def validate_python_version(self):
"""
Make sure python version has both major and minor versions as required for choosing pytorch wheel
"""
if self.is_pip and not (
self.python_semantic_version.major and self.python_semantic_version.minor
):
raise PytorchResolutionError(
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
"must have both major and minor parts of the version (for example: '3.7')".format(
self.python_version_string
)
)
def match(self, req):
return req.name in self.packages
@staticmethod
def get_platform():
if sys.platform == "linux":
return "linux"
if sys.platform == "win32" or sys.platform == "cygwin":
return "windows"
if sys.platform == "darwin":
return "macos"
raise RuntimeError("unrecognized OS")
def _get_link_from_torch_page(self, req, torch_url):
links_parser = LinksHTMLParser()
links_parser.feed(requests.get(torch_url, timeout=10).text)
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
py_ver = "{0.major}{0.minor}".format(self.python_semantic_version)
url = None
# search for our package
for l in links_parser.links:
parts = l.split('/')[-1].split('-')
if len(parts) < 5:
continue
if parts[0] != req.name:
continue
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
if parts[1].split('%')[0].split('+')[0] != req.specs[0][1]:
continue
if not parts[2].endswith(py_ver):
continue
if platform_wheel not in parts[4]:
continue
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
break
return url
def get_url_for_platform(self, req):
assert self.package_manager == "pip"
assert self.os != "mac"
assert req.specs
try:
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
except:
pass
op, version = req.specs[0]
# assert op == "=="
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version:
if not url:
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
url = self._get_link_from_torch_page(req, torch_url)
if not url:
url = PytorchWheel(
torch_version=fix_version(version),
python="{0.major}{0.minor}".format(self.python_semantic_version),
os_name=self.os,
cuda_version=self.cuda_version,
).make_url()
if url:
# normalize url (sometimes we will get ../ which we should not...
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
self.log.debug("checking url: %s", url)
return url, requests.head(url, timeout=10).ok
@staticmethod
def match_version(req, options):
versioned_options = sorted(
((Version(fix_version(key)), value) for key, value in options.items()),
key=itemgetter(0),
reverse=True,
)
req.specs = [(op, fix_version(version)) for op, version in req.specs]
if req.specs:
specs = Spec(req.format_specs())
else:
specs = None
try:
return next(
replacement
for version, replacement in versioned_options
if not specs or version in specs
)
except StopIteration:
raise PytorchResolutionError(
'Could not find wheel for "{}", '
"Available versions: {}".format(req, list(options))
)
def replace_conda(self, req):
spec = "".join(req.specs[0]) if req.specs else ""
if not self.cuda_version:
return "pytorch-cpu{spec}\ntorchvision-cpu".format(spec=spec)
return "pytorch{spec}\ntorchvision\ncuda{self.cuda_version}".format(
self=self, spec=spec
)
def _table_lookup(self, req):
"""
Look for pytorch wheel matching `req` in table
:param req: python requirement
"""
def check(base_, key_, exception_):
result = base_.get(key_)
if not result:
if key_.startswith('cuda'):
print('Could not locate, {}'.format(exception_))
ver = sorted([float(a.replace('cuda', '').replace('none', '0')) for a in base_.keys()], reverse=True)[0]
key_ = 'cuda'+str(int(ver))
result = base_.get(key_)
print('Reverting to \"{}\"'.format(key_))
if not result:
raise exception_
return result
raise exception_
if isinstance(result, Exception):
raise result
return result
if self.is_conda:
return self.replace_conda(req)
base = self.MAP
for key, exception in zip((self.os, self.cuda, self.python), self.exceptions):
base = check(base, key, exception)
return self.match_version(req, base).replace(" ", "\n")
def replace(self, req):
try:
return self._replace(req)
except Exception as e:
message = "Exception when trying to resolve python wheel"
self.log.debug(message, exc_info=True)
raise PytorchResolutionError("{}: {}".format(message, e))
    def _replace(self, req):
        """
        Resolve ``req`` to a concrete requirement line: first by constructing a
        platform wheel URL, then by looking it up in the static MAP table.
        """
        self.validate_python_version()
        try:
            result, ok = self.get_url_for_platform(req)
            # NOTE(review): ``ok`` (the HEAD-check result) is ignored here and the
            # URL is returned regardless -- confirm this is intentional
            self.log.debug('Replacing requirement "%s" with %r', req, result)
            return result
        # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit
        except:
            pass
        try:
            result = self._table_lookup(req)
        except Exception as e:
            # keep the table-lookup error to report if nothing else works
            exc = e
        else:
            self.log.debug('Replacing requirement "%s" with %r', req, result)
            return result
        self.log.debug(
            "Could not find Pytorch wheel in table, trying manually constructing URL"
        )
        # the manual URL construction below is disabled; result/ok stay None
        result = ok = None
        # try:
        #     result, ok = self.get_url_for_platform(req)
        # except Exception:
        #     pass
        if not ok:
            if result:
                self.log.debug("URL not found: {}".format(result))
            exc = PytorchResolutionError(
                "Was not able to find pytorch wheel URL: {}".format(exc)
            )
            # cancel exception chaining
            six.raise_from(exc, None)
        self.log.debug('Replacing requirement "%s" with %r', req, result)
        return result
MAP = {
"windows": {
"cuda100": {
"python3.7": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-win_amd64.whl"
},
"python3.6": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-win_amd64.whl"
},
"python3.5": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-win_amd64.whl"
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda92": {
"python3.7": {
"0.4.1",
"http://download.pytorch.org/whl/cu92/torch-0.4.1-cp37-cp37m-win_amd64.whl",
},
"python3.6": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-win_amd64.whl"
},
"python3.5": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-win_amd64.whl"
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda91": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-win_amd64.whl"
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-win_amd64.whl"
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda90": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-win_amd64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-win_amd64.whl",
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cuda80": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-win_amd64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp35-cp35m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-win_amd64.whl",
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
"cudanone": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-win_amd64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-win_amd64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-win_amd64.whl",
},
"python2.7": PytorchResolutionError(
"PyTorch does not support Python 2.7 on Windows"
),
},
},
"macos": {
"cuda100": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda92": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda91": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda90": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cuda80": PytorchResolutionError(
"MacOS Binaries dont support CUDA, install from source if CUDA is needed"
),
"cudanone": {
"python3.6": {"0.4.0": "torch"},
"python3.5": {"0.4.0": "torch"},
"python2.7": {"0.4.0": "torch"},
},
},
"linux": {
"cuda100": {
"python3.7": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp37-cp37m-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl",
},
"python3.6": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp36-cp36m-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl",
},
"python3.5": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp35-cp35m-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp35-cp35m-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp35-cp35m-manylinux1_x86_64.whl",
},
"python2.7": {
"1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
"1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp27-cp27mu-linux_x86_64.whl",
"1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl",
"1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp27-cp27mu-manylinux1_x86_64.whl",
},
},
"cuda92": {
"python3.7": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp37-cp37m-manylinux1_x86_64.whl"
},
"python3.6": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl"
},
"python3.5": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp35-cp35m-manylinux1_x86_64.whl"
},
"python2.7": {
"0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
"1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp27-cp27mu-manylinux1_x86_64.whl"
},
},
"cuda91": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-linux_x86_64.whl"
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl"
},
"python2.7": {
"0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl"
},
},
"cuda90": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
},
"python2.7": {
"0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
},
},
"cuda80": {
"python3.6": {
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
"0.3.1": "torch==0.3.1",
"0.3.0.post4": "torch==0.3.0.post4",
"0.1.2.post1": "torch==0.1.2.post1",
"0.1.2": "torch==0.1.2",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
},
"python3.5": {
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
"0.3.1": "torch==0.3.1",
"0.3.0.post4": "torch==0.3.0.post4",
"0.1.2.post1": "torch==0.1.2.post1",
"0.1.2": "torch==0.1.2",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
},
"python2.7": {
"0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
"0.3.1": "torch==0.3.1",
"0.3.0.post4": "torch==0.3.0.post4",
"0.1.2.post1": "torch==0.1.2.post1",
"0.1.2": "torch==0.1.2",
"1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
},
},
"cudanone": {
"python3.6": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
},
"python3.5": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
},
"python2.7": {
"0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
"1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
},
},
},
}

View File

@@ -0,0 +1,340 @@
from __future__ import absolute_import, unicode_literals
import operator
import os
import re
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from itertools import chain, starmap
from operator import itemgetter
from os import path
from typing import Text, List, Type, Optional, Tuple
import semantic_version
from pathlib2 import Path
from pyhocon import ConfigTree
from requirements import parse
# noinspection PyPackageRequirements
from requirements.requirement import Requirement
import six
from trains_agent.definitions import PIP_EXTRA_INDICES
from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session, normalize_cuda_version
from .translator import RequirementsTranslator
class SpecsResolutionError(Exception):
    """Raised when a requirement's version specs cannot be reconciled."""
class FatalSpecsResolutionError(Exception):
    """Raised for resolution failures that must abort processing (never retried)."""
@six.python_2_unicode_compatible
class MarkerRequirement(object):
    """
    Wrapper around ``requirements.requirement.Requirement`` adding environment
    marker extraction and spec manipulation. Unknown attributes are delegated
    to the wrapped requirement via ``__getattr__``.
    """

    def __init__(self, req):  # type: (Requirement) -> None
        self.req = req

    @property
    def marker(self):
        # the environment marker is everything after ';' on the original line
        match = re.search(r';\s*(.*)', self.req.line)
        if match:
            return match.group(1)
        return None

    def tostr(self, markers=True):
        """Render back to a requirements-file line, optionally with the marker."""
        if not self.uri:
            parts = [self.name]
            if self.extras:
                parts.append('[{0}]'.format(','.join(sorted(self.extras))))
            if self.specifier:
                parts.append(self.format_specs())
        else:
            # URI requirements are rendered as the URI itself
            parts = [self.uri]
        if markers and self.marker:
            parts.append('; {0}'.format(self.marker))
        return ''.join(parts)
    __str__ = tostr

    def __repr__(self):
        return '{self.__class__.__name__}[{self}]'.format(self=self)

    def format_specs(self):
        """Join specs like [('==', '1.0')] into a '==1.0' comma-separated string."""
        return ','.join(starmap(operator.add, self.specs))

    def __getattr__(self, item):
        # delegate any attribute not defined here to the wrapped requirement
        return getattr(self.req, item)

    @property
    def specs(self):  # type: () -> List[Tuple[Text, Text]]
        return self.req.specs

    @specs.setter
    def specs(self, value):  # type: (List[Tuple[Text, Text]]) -> None
        self.req.specs = value

    def fix_specs(self):
        """
        Collapse multiple version specs: keep the '==' specs when present,
        otherwise keep one '<=' bound and one '>=' bound.
        """
        def solve_by(func, op_is, specs):
            return func([(op, version) for op, version in specs if op == op_is])

        def solve_equal(specs):
            # NOTE(review): checks self.specs rather than the filtered ``specs``
            # argument -- confirm this is intentional
            if len(set(version for _, version in self.specs)) > 1:
                raise SpecsResolutionError('more than one "==" spec: {}'.format(specs))
            return specs

        # NOTE(review): min/max compare version strings lexicographically, which
        # can misorder e.g. '10.0' vs '9.0' -- confirm before relying on it
        greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs)
        smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs)
        equal = solve_by(solve_equal, '==', self.specs)
        if equal:
            self.specs = equal
        else:
            self.specs = greater + smaller
@six.add_metaclass(ABCMeta)
class RequirementSubstitution(object):
    """
    Base class for requirement rewriting rules applied by RequirementsManager.
    Subclasses decide whether a requirement matches and how to replace it.
    """

    # extra pypi index URLs passed to pip commands
    _pip_extra_index_url = PIP_EXTRA_INDICES

    def __init__(self, session):
        # type: (Session) -> ()
        self._session = session
        self.config = session.config  # type: ConfigTree
        # version suffix for cuda/cudnn-specific builds, e.g. '.post100.dev71'
        self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
        self.package_manager = self.config['agent.package_manager.type']

    @abstractmethod
    def match(self, req):  # type: (MarkerRequirement) -> bool
        """
        Returns whether a requirement needs to be modified by this substitution.
        """
        pass

    @abstractmethod
    def replace(self, req):  # type: (MarkerRequirement) -> Text
        """
        Replace a requirement
        """
        pass

    def post_install(self):
        """Hook invoked after installation; default is a no-op."""
        pass

    @classmethod
    def get_pip_version(cls, package):
        """Return the latest version of ``package`` as reported by ``pip search``."""
        # NOTE(review): ``pip search`` has been removed from recent pip releases;
        # this call will fail there -- confirm the supported pip range
        output = Argv(
            'pip',
            'search',
            package,
            *(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
        ).get_output()
        # ad-hoc pattern to duplicate the behavior of the old code
        return re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(package), output).group(1)

    @property
    def cuda_version(self):
        return self.config['agent.cuda_version']

    @property
    def cudnn_version(self):
        return self.config['agent.cudnn_version']
class SimpleSubstitution(RequirementSubstitution):
    """
    Substitution matched by package name (either as the requirement name or as
    part of an explicit http(s) wheel URI).
    """

    @property
    @abstractmethod
    def name(self):
        """Package name this substitution applies to."""
        pass

    def match(self, req):  # type: (MarkerRequirement) -> bool
        return (self.name == req.name or (
            req.uri and
            re.match(r'https?://', req.uri) and
            self.name in req.uri
        ))

    def replace(self, req):  # type: (MarkerRequirement) -> Text
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        if req.uri:
            # inject the cuda/cudnn suffix into the wheel file name, right
            # before the '-cpXX' python-version tag
            return re.sub(
                r'({})(.*?)(-cp)'.format(self.name),
                r'\1\2{}\3'.format(self.suffix),
                req.uri,
                count=1)
        if req.specs:
            _, version_number = req.specs[0]
            # NOTE(review): assert is stripped under ``python -O``; this is
            # validation-only and silently skipped in optimized mode
            assert semantic_version.Version(version_number, partial=True)
        else:
            # no pinned version: take the latest version known to pip
            version_number = self.get_pip_version(self.name)
        # pin to the suffixed build of the chosen version
        req.specs = [('==', version_number + self.suffix)]
        return Text(req)
@six.add_metaclass(ABCMeta)
class CudaSensitiveSubstitution(SimpleSubstitution):
    """Name-based substitution that only applies when CUDA and CuDNN are known."""

    def match(self, req):  # type: (MarkerRequirement) -> bool
        cuda_available = self.cuda_version and self.cudnn_version
        return cuda_available and super(CudaSensitiveSubstitution, self).match(req)
class CudaNotFound(Exception):
    """Raised when the local CUDA / CuDNN installation cannot be detected."""
class RequirementsManager(object):
    """
    Apply registered RequirementSubstitution handlers to a requirements
    listing, detecting the local CUDA/CuDNN setup to drive CUDA-sensitive
    handlers, and translating URL requirements through the download cache.
    """

    def __init__(self, session, base_interpreter=None):
        # type: (Session, PathLike) -> ()
        self._session = session
        self.config = deepcopy(session.config)  # type: ConfigTree
        self.handlers = []  # type: List[RequirementSubstitution]
        agent = self.config['agent']
        # substitutions are disabled when the agent is forced to CPU-only mode
        self.active = not agent.get('cpu_only', False)
        self.found_cuda = False
        if self.active:
            try:
                agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
                self.found_cuda = True
            except Exception:
                # if we have a cuda version, it is good enough (we dont have to have cudnn version)
                if agent.get('cuda_version'):
                    self.found_cuda = True
        # wheels are cached per CUDA version (or under 'cpu' when none was found)
        pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
            'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
        self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
                                                 cache_dir=pip_cache_dir.as_posix())

    def register(self, cls):  # type: (Type[RequirementSubstitution]) -> None
        """Instantiate and add a substitution handler class."""
        self.handlers.append(cls(self._session))

    def _replace_one(self, req):  # type: (MarkerRequirement) -> Optional[Text]
        """Return the replacement line for ``req``, or None to keep the original."""
        match = re.search(r';\s*(.*)', Text(req))
        if match:
            req.markers = match.group(1).split(',')
        if not self.active:
            return None
        # first matching handler wins
        for handler in self.handlers:
            if handler.match(req):
                return handler.replace(req)
        return None

    def replace(self, requirements):  # type: (Text) -> Text
        """
        Rewrite a requirements listing (single text blob or iterable of lines),
        substituting matched requirements and translating URL requirements.
        """
        parsed_requirements = tuple(
            map(
                MarkerRequirement,
                filter(
                    None,
                    parse(requirements)
                    if isinstance(requirements, six.text_type)
                    else (next(parse(line), None) for line in requirements)
                )
            )
        )
        if not parsed_requirements:
            # return the original requirements just in case
            return requirements

        def replace_one(i, req):
            # type: (int, MarkerRequirement) -> Optional[Text]
            try:
                return self._replace_one(req)
            except FatalSpecsResolutionError:
                raise
            except Exception:
                warning('could not find installed CUDA/CuDNN version for {}, '
                        'using original requirements line: {}'.format(req, i))
                return None

        new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
        conda = is_conda(self.config)
        # fall back to the re-serialized original line when no handler replaced it
        result = map(
            lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
            new_requirements,
            parsed_requirements
        )
        if not conda:
            result = map(self.translator.translate, result)
        return join_lines(result)

    def post_install(self):
        """Give every handler a chance to run post-install work; never raise."""
        for h in self.handlers:
            try:
                h.post_install()
            except Exception as ex:
                print('RequirementsManager handler {} raised exception: {}'.format(h, ex))

    @staticmethod
    def get_cuda_version(config):  # type: (ConfigTree) -> (Text, Text)
        """
        Detect CUDA and CuDNN versions and return them normalized.

        Order of detection: explicit config values, then ``nvcc --version``,
        then ``nvidia-smi``; CuDNN is parsed from ``cudnn.h`` located relative
        to the nvcc installation.
        """
        # we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
        cuda_version = config['agent.cuda_version']
        cudnn_version = config['agent.cudnn_version']
        if cuda_version and cudnn_version:
            return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)
        if not cuda_version:
            try:
                try:
                    output = Argv('nvcc', '--version').get_output()
                except OSError:
                    raise CudaNotFound('nvcc not found')
                # "release X.Y" -> "XY" (e.g. release 10.0 -> '100')
                match = re.search(r'release (.{3})', output).group(1)
                cuda_version = Text(int(float(match) * 10))
            except Exception:
                # best-effort detection: narrowed from a bare except so that
                # KeyboardInterrupt/SystemExit are no longer swallowed
                pass
        if not cuda_version:
            try:
                try:
                    output = Argv('nvidia-smi',).get_output()
                except OSError:
                    # fixed misleading message (was 'nvcc not found')
                    raise CudaNotFound('nvidia-smi not found')
                match = re.search(r'CUDA Version: ([0-9]+).([0-9]+)', output)
                match = match.group(1)+'.'+match.group(2)
                cuda_version = Text(int(float(match) * 10))
            except Exception:
                pass
        if not cudnn_version:
            try:
                cuda_lib = which('nvcc')
                # bug fix: is_windows_platform is a function; the original tested
                # the function object itself, which is always truthy, so the
                # windows branch was taken on every platform
                if is_windows_platform():
                    cudnn_h = path.sep.join(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h'])
                else:
                    cudnn_h = path.join(path.sep, *(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h']))
                cudnn_major, cudnn_minor = None, None
                try:
                    include_file = open(cudnn_h)
                except OSError:
                    raise CudaNotFound('Could not read cudnn.h')
                with include_file:
                    # scan the header for the CUDNN_MAJOR / CUDNN_MINOR macros
                    for line in include_file:
                        if 'CUDNN_MAJOR' in line:
                            cudnn_major = line.split()[-1]
                        if 'CUDNN_MINOR' in line:
                            cudnn_minor = line.split()[-1]
                        if cudnn_major and cudnn_minor:
                            break
                cudnn_version = cudnn_major + (cudnn_minor or '0')
            except Exception:
                pass
        return (normalize_cuda_version(cuda_version or 0),
                normalize_cuda_version(cudnn_version or 0))

View File

@@ -0,0 +1,63 @@
from typing import Text
from furl import furl
from pathlib2 import Path
from trains_agent.config import Config
from .pip_api.system import SystemPip
class RequirementsTranslator(object):
    """
    Translate explicit URLs to local URLs after downloading them to cache
    """

    # URL schemes eligible for download-to-cache translation
    SUPPORTED_SCHEMES = ["http", "https", "ftp"]

    def __init__(self, session, interpreter=None, cache_dir=None):
        self._session = session
        config = session.config
        # default the cache location from configuration when not given explicitly
        self.cache_dir = cache_dir or Path(config["agent.pip_download_cache.path"]).expanduser().as_posix()
        self.enabled = config["agent.pip_download_cache.enabled"]
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        self.config = Config()
        self.pip = SystemPip(interpreter=interpreter)

    def download(self, url):
        """Download ``url`` into the pip cache directory."""
        self.pip.download_package(url, cache_dir=self.cache_dir)

    @classmethod
    def is_supported_link(cls, line):
        # type: (Text) -> bool
        """
        Return whether requirement is a link that should be downloaded to cache
        """
        url = furl(line)
        return (
            url.scheme
            and url.scheme.lower() in cls.SUPPORTED_SCHEMES
            and line.lstrip().lower().startswith(url.scheme.lower())
        )

    def translate(self, line):
        """
        If requirement is supported, download it to cache and return the download path
        """
        if not (self.enabled and self.is_supported_link(line)):
            return line
        command = self.config.command
        command.log('Downloading "{}" to pip cache'.format(line))
        url = furl(line)
        try:
            # the last path segment is taken as the wheel file name
            wheel_name = url.path.segments[-1]
        except IndexError:
            command.error('Could not parse wheel name of "{}"'.format(line))
            return line
        try:
            self.download(line)
            # return a file:// URI pointing at the cached wheel
            downloaded = Path(self.cache_dir, wheel_name).expanduser().as_uri()
        except Exception:
            command.error('Could not download wheel name of "{}"'.format(line))
            return line
        return downloaded

View File

@@ -0,0 +1,115 @@
from typing import Optional, Text
import requests
from pathlib2 import Path
import six
from trains_agent.definitions import CONFIG_DIR
from trains_agent.helper.process import Argv, DEVNULL
from .pip_api.venv import VirtualenvPip
class VenvUpdateAPI(VirtualenvPip):
    """
    Pip API backed by the ``venv-update`` (pip-faster) script, which is
    downloaded once from ``url`` and cached locally for reuse.
    """

    # file remembering which URL the cached script was downloaded from
    URL_FILE_PATH = Path(CONFIG_DIR, "venv-update-url.txt")
    # cached copy of the venv-update script itself
    SCRIPT_PATH = Path(CONFIG_DIR, "venv-update")

    def __init__(self, url, *args, **kwargs):
        super(VenvUpdateAPI, self).__init__(*args, **kwargs)
        self.url = url  # where to fetch the venv-update script from
        self._script_path = None
        self._first_install = True

    @property
    def downloaded_venv_url(self):
        # type: () -> Optional[Text]
        """URL the cached script was downloaded from, or None when not cached."""
        try:
            return self.URL_FILE_PATH.read_text()
        except OSError:
            return None

    @downloaded_venv_url.setter
    def downloaded_venv_url(self, value):
        # NOTE(review): assumes CONFIG_DIR already exists -- confirm
        self.URL_FILE_PATH.write_text(value)

    def _check_script_validity(self, path):
        """
        Make sure script in ``path`` is a valid python script
        :param path:
        :return:
        """
        result = Argv(self.bin, path, "--version").call(
            stdout=DEVNULL, stderr=DEVNULL, stdin=DEVNULL
        )
        return result == 0

    @property
    def script_path(self):
        # type: () -> Text
        """Path to a valid venv-update script, (re)downloading it when needed."""
        if not self._script_path:
            self._script_path = self.SCRIPT_PATH
            # re-download when the cached copy is missing, was fetched from a
            # different URL, or no longer runs
            if not (
                self._script_path.exists()
                and self.downloaded_venv_url
                and self.downloaded_venv_url == self.url
                and self._check_script_validity(self._script_path)
            ):
                with self._script_path.open("wb") as f:
                    for data in requests.get(self.url, stream=True):
                        f.write(data)
                self.downloaded_venv_url = self.url
        return self._script_path

    def install_from_file(self, path):
        """Install a requirements file via venv-update / pip-faster."""
        first_install = (
            Argv(
                self.python,
                six.text_type(self.script_path),
                "venv=",
                "-p",
                self.python,
                self.path,
            )
            + self.create_flags()
            + ("install=", "-r", path)
            + self.install_flags()
        )
        later_install = first_install + (
            "pip-command=",
            "pip-faster",
            "install",
            "--upgrade",  # no --prune
        )
        self._choose_install(first_install, later_install)

    def install_packages(self, *packages):
        """Install the given packages via venv-update / pip-faster."""
        first_install = (
            Argv(
                self.python,
                six.text_type(self.script_path),
                "venv=",
                self.path,
                "install=",
            )
            + packages
        )
        later_install = first_install + (
            "pip-command=",
            "pip-faster",
            "install",
            "--upgrade",  # no --prune
        )
        self._choose_install(first_install, later_install)

    def _choose_install(self, first, rest):
        # the first invocation creates the venv; subsequent ones only upgrade
        if self._first_install:
            command = first
            self._first_install = False
        else:
            command = rest
        command.check_call(stdin=DEVNULL)

    def upgrade_pip(self):
        """
        pip and venv-update versions are coupled, venv-update installs the latest compatible pip
        """
        pass

View File

@@ -0,0 +1,365 @@
from __future__ import unicode_literals, print_function
import abc
import logging
import os
import re
import subprocess
import sys
from contextlib import contextmanager
from copy import deepcopy
from distutils.spawn import find_executable
from itertools import chain, repeat, islice
from os.path import devnull
from typing import Union, Text, Sequence, Any, TypeVar, Callable
import psutil
from furl import furl
from future.builtins import super
from pathlib2 import Path
import six
from trains_agent.definitions import PROGRAM_NAME, CONFIG_FILE
from trains_agent.helper.base import bash_c, is_windows_platform, select_for_platform, chain_map
PathLike = Union[Text, Path]
def get_bash_output(cmd, strip=False, stderr=subprocess.STDOUT, stdin=False):
    """
    Run ``cmd`` through the platform shell and return its decoded output,
    or None if the command exited with a nonzero status.
    """
    # NOTE(review): output is always stripped below, so the ``strip`` parameter
    # currently has no additional effect -- confirm intent before relying on it
    try:
        output = (
            subprocess.check_output(
                bash_c().split() + [cmd],
                stderr=stderr,
                stdin=subprocess.PIPE if stdin else None,
            )
            .decode()
            .strip()
        )
    except subprocess.CalledProcessError:
        # nonzero exit status: signal failure with None
        output = None
    return output if not strip or not output else output.strip()
def kill_all_child_processes(pid=None):
    """
    Kill the entire process subtree of ``pid``.

    When ``pid`` is given, that process is killed as well; when it defaults to
    the current process, only the children are killed (we cannot kill ourselves
    mid-cleanup).
    """
    # get current process if pid not provided
    include_parent = True
    if not pid:
        pid = os.getpid()
        include_parent = False
    print("\nLeaving process id {}".format(pid))
    try:
        parent = psutil.Process(pid)
    except psutil.Error:
        # could not find parent process id
        return
    for child in parent.children(recursive=True):
        child.kill()
    if include_parent:
        parent.kill()
def check_if_command_exists(cmd):
    """Return True when executable ``cmd`` can be located on the system PATH."""
    return find_executable(cmd) is not None
def get_program_invocation():
    """Return the argv prefix that re-invokes this program via ``python -u -m``."""
    module_name = PROGRAM_NAME.replace('-', '_')
    return [sys.executable, "-u", "-m", module_name]
Retval = TypeVar("Retval")  # generic return type of the callable passed to call_subprocess
@six.add_metaclass(abc.ABCMeta)
class Executable(object):
    """
    Base class for objects that can be run as subprocesses (a single argv or a
    sequence of shell commands).
    """

    @abc.abstractmethod
    def call_subprocess(self, func, censor_password=False, *args, **kwargs):
        # type: (Callable[..., Retval]) -> Retval
        """Invoke ``func`` (a ``subprocess`` entry point) on this command."""
        pass

    def get_output(self, *args, **kwargs):
        """Run the command and return its decoded, right-stripped stdout."""
        return (
            self.call_subprocess(subprocess.check_output, *args, **kwargs)
            .decode("utf8")
            .rstrip()
        )

    def check_call(self, *args, **kwargs):
        """Run the command; raises CalledProcessError on nonzero exit status."""
        return self.call_subprocess(subprocess.check_call, *args, **kwargs)

    @staticmethod
    @contextmanager
    def normalize_exception(censor_password=False):
        """
        Normalize CalledProcessError raised within the block: optionally strip
        passwords from the reported command and decode byte output, then re-raise.
        """
        try:
            yield
        except subprocess.CalledProcessError as e:
            if censor_password:
                # remove URL-embedded credentials before the error is surfaced
                e.cmd = [furl(word).remove(password=True).tostr() for word in e.cmd]
            if e.output and not isinstance(e.output, six.text_type):
                e.output = e.output.decode()
            raise

    @abc.abstractmethod
    def pretty(self):
        """Human-readable representation of the command."""
        pass
class Argv(Executable):
    """A single process invocation expressed as a sequence of argv words."""

    # separator used when rendering the argv as one shell string
    ARGV_SEPARATOR = " "

    def __init__(self, *argv, **kwargs):
        # type: (*PathLike, Any) -> ()
        """
        Object representing a series of strings used to invoke a process.
        """
        self.argv = argv
        self._log = kwargs.pop("log", None)
        if not self._log:
            self._log = logging.getLogger(__name__)
            self._log.propagate = False

    def serialize(self):
        """
        Returns a string of the shell command
        """
        # each word is shell-quoted (see module-level ``quote``)
        return self.ARGV_SEPARATOR.join(map(quote, self))

    def call_subprocess(self, func, censor_password=False, *args, **kwargs):
        """Run ``func`` on this argv, normalizing any CalledProcessError."""
        self._log.debug("running: %s: %s", func.__name__, list(self))
        with self.normalize_exception(censor_password):
            return func(list(self), *args, **kwargs)

    def call(self, *args, **kwargs):
        """Run the command and return its exit status."""
        return self.call_subprocess(subprocess.call, *args, **kwargs)

    def get_argv(self):
        return self.argv

    def __repr__(self):
        return "<Argv{}>".format(self.argv)

    def __str__(self):
        return "Executing: {}".format(self.argv)

    def __iter__(self):
        # iterate argv words coerced to text
        return (six.text_type(word) for word in self.argv)

    def __getitem__(self, item):
        return self.argv[item]

    def __add__(self, other):
        # appending any iterable of words yields a new Argv
        try:
            iter(other)
        except TypeError:
            return NotImplemented
        return type(self)(*(self.argv + tuple(other)), log=self._log)

    def __radd__(self, other):
        try:
            iter(other)
        except TypeError:
            return NotImplemented
        return type(self)(*(tuple(other) + self.argv), log=self._log)
    pretty = serialize

    @staticmethod
    def conditional_flag(condition, flag, *flags):
        # type: (Any, PathLike, PathLike) -> Sequence[PathLike]
        """
        Translate a boolean to a flag command like arguments.
        :param condition: condition to translate to flag
        :param flag: flag to use if condition true (at least one)
        :param flags: additional flags to use if condition is true
        """
        return (flag,) + flags if condition else ()
class CommandSequence(Executable):
    """A sequence of shell commands joined with ``&&`` when serialized."""

    JOIN_COMMAND_OPERATOR = "&&"

    def __init__(self, *commands, **kwargs):
        """
        Object representing a sequence of shell commands.
        :param commands: Command elements. Each CommandSequence will be treated as a single command-line argument.
        :type commands: Each command: [str] | Argv
        """
        self._log = kwargs.pop("log", None)
        if not self._log:
            self._log = logging.getLogger(__name__)
            self._log.propagate = False
        self.commands = []
        for c in commands:
            # flatten nested sequences; deep-copy so later mutation does not leak
            if isinstance(c, CommandSequence):
                self.commands.extend(deepcopy(c.commands))
            elif isinstance(c, Argv):
                self.commands.append(deepcopy(c))
            else:
                self.commands.append(Argv(*c, log=self._log))

    def get_argv(self, shell=False):
        """
        Get array of argv's.
        :param bool shell: if True, returns the argv of a process that will invoke a shell running the command sequence
        """
        if shell:
            return tuple(bash_c().split()) + (self.serialize(),)

        def safe_get_argv(obj):
            try:
                func = obj.get_argv
            except AttributeError:
                result = obj
            else:
                result = func()
            return tuple(map(str, result))

        return tuple(map(safe_get_argv, self.commands))

    def serialize(self):
        """Render the whole sequence as one shell command string."""
        def intersperse(delimiter, seq):
            # a, b, c -> a, delimiter, b, delimiter, c
            return islice(chain.from_iterable(zip(repeat(delimiter), seq)), 1, None)

        def normalize(command):
            # NOTE(review): on Windows this yields lists, which the ' '.join
            # below cannot consume -- looks broken there; confirm before use
            return list(command) if is_windows_platform() else command.serialize()

        return ' '.join(list(intersperse(self.JOIN_COMMAND_OPERATOR, map(normalize, self.commands))))

    def call_subprocess(self, func, censor_password=False, *args, **kwargs):
        """Run the serialized sequence through a shell via ``func``."""
        with self.normalize_exception(censor_password):
            return func(
                self.serialize(),
                *args,
                **chain_map(
                    dict(
                        executable=select_for_platform(linux="bash", windows=None),
                        shell=True,
                    ),
                    kwargs,
                )
            )

    def __repr__(self):
        tab = " " * 4
        return "<{}(\n{}{},\n)>".format(
            type(self).__name__, tab, (",\n" + tab).join(map(repr, self.commands))
        )

    def __iter__(self):
        return iter(self.commands)

    def __getitem__(self, item):
        return self.commands[item]

    def __setitem__(self, key, value):
        self.commands[key] = value

    def __add__(self, other):
        try:
            iter(other)
        except TypeError:
            return NotImplemented
        # bug fix: self.commands is a list, so the original expression
        # ``self.commands + tuple(other)`` always raised TypeError (list + tuple)
        return type(self)(*(self.commands + list(other)))

    def pretty(self):
        """Human-readable form of the sequence."""
        serialized = self.serialize()
        if is_windows_platform():
            # NOTE(review): joins the characters of a string -- likely legacy;
            # confirm before relying on the Windows output
            return " ".join(serialized)
        return serialized
class WorkerParams(object):
    """Command-line parameters shared by all worker invocations."""

    def __init__(
        self,
        log_level="INFO",
        config_file=CONFIG_FILE,
        optimization=0,
        debug=False,
        trace=False,
    ):
        self.trace = trace
        self.log_level = log_level
        self.optimization = optimization  # python -O level; 0 disables
        self.config_file = config_file
        self.debug = debug

    def get_worker_flags(self):
        """
        Serialize a WorkerParams instance to a tuple of command-line flags
        :param WorkerParams self: parameters of worker
        :return: a tuple of global flags and "workers execute/daemon" flags
        """
        global_args = ("--config-file", str(self.config_file))
        if self.debug:
            global_args += ("--debug",)
        if self.trace:
            global_args += ("--trace",)
        worker_args = tuple()
        if self.optimization:
            # bug fix: get_optimization_flag() returns a string; the original
            # added it directly to a tuple, which raised TypeError whenever
            # optimization was enabled
            worker_args += (self.get_optimization_flag(),)
        return global_args, worker_args

    def get_optimization_flag(self):
        """Return the ``-O``/``-OO``/... flag matching the optimization level."""
        return "-{}".format("O" * self.optimization)

    def get_argv_for_command(self, command):
        """
        Get argv for a particular worker command.
        """
        global_args, worker_args = self.get_worker_flags()
        command_line = (
            tuple(get_program_invocation())
            + global_args
            + (command, )
            + worker_args
        )
        return Argv(*command_line)
class DaemonParams(WorkerParams):
    """WorkerParams variant for the daemon command (adds foreground/queue flags)."""

    def __init__(self, foreground=False, queues=(), *args, **kwargs):
        super(DaemonParams, self).__init__(*args, **kwargs)
        self.foreground = foreground
        self.queues = tuple(queues)

    def get_worker_flags(self):
        """Extend the base worker flags with daemon-specific arguments."""
        global_args, worker_args = super(DaemonParams, self).get_worker_flags()
        extra = ()
        if self.foreground:
            extra += ("--foreground",)
        if self.queues:
            extra += ("--queue",) + self.queues
        return global_args, worker_args + extra
# module-lifetime sink for suppressed subprocess streams (intentionally never closed)
DEVNULL = open(devnull, "w+")
# shell command that sources a script on each platform
SOURCE_COMMAND = select_for_platform(linux="source", windows="call")
class ExitStatus(object):
    """Process exit codes used throughout the agent."""

    success = 0      # clean completion
    failure = 1      # generic error
    interrupted = 2  # stopped by user / signal
# exit status signalling success
COMMAND_SUCCESS = 0

# matches any character that requires shell quoting (ASCII semantics when available)
_ASCII_FLAG = getattr(re, "ASCII", 0)
_find_unsafe = re.compile(r"[^\w@%+=:,./-]", _ASCII_FLAG).search


def quote(s):
    """
    Backport of shlex.quote():
    Return a shell-escaped version of the string *s*.
    """
    if not s:
        return "''"
    if _find_unsafe(s) is None:
        # nothing dangerous: the string may be used verbatim
        return s
    # use single quotes, and put single quotes into double quotes
    # the string $'b is then quoted as '$'"'"'b'
    escaped = s.replace("'", "'\"'\"'")
    return "'{}'".format(escaped)

555
trains_agent/helper/repo.py Normal file
View File

@@ -0,0 +1,555 @@
import abc
import re
import shutil
import subprocess
from distutils.spawn import find_executable
from hashlib import md5
from os import environ, getenv
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple
import attr
from furl import furl
from pathlib2 import Path
import six
from trains_agent.helper.console import ensure_text, ensure_binary
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import (
select_for_platform,
rm_tree,
ExecutionInfo,
normalize_path,
create_file_if_not_exists,
)
from trains_agent.helper.process import DEVNULL, Argv, PathLike, COMMAND_SUCCESS
from trains_agent.session import Session
class VcsFactory(object):
    """
    Creates VCS instances
    """

    GIT_SUFFIX = ".git"

    @classmethod
    def create(cls, session, execution_info, location):
        # type: (Session, ExecutionInfo, PathLike) -> VCS
        """
        Create a VCS instance for config and url
        :param session: program session
        :param execution_info: task ExecutionInfo
        :param location: (desired) clone location
        """
        repo_url = execution_info.repository
        if repo_url.endswith(cls.GIT_SUFFIX):
            vcs_cls = Git
        else:
            vcs_cls = Hg
        # prefer an explicit commit, then a tag, then a (remote) branch name
        revision = execution_info.version_num or execution_info.tag
        if not revision:
            branch = execution_info.branch or vcs_cls.main_branch
            revision = vcs_cls.remote_branch_name(branch)
        return vcs_cls(session, repo_url, location, revision)
# noinspection PyUnresolvedReferences
@attr.s
class RepoInfo(object):
    """
    Cloned repository information
    :param type: VCS type
    :param url: repository url
    :param branch: revision branch
    :param commit: revision number
    :param root: clone location path
    """
    type = attr.ib(type=str)    # VCS kind (e.g. git / hg)
    url = attr.ib(type=str)     # remote repository URL
    branch = attr.ib(type=str)  # checked-out branch name
    commit = attr.ib(type=str)  # checked-out revision id
    root = attr.ib(type=str)    # local clone directory
RType = TypeVar("RType")
@six.add_metaclass(abc.ABCMeta)
class VCS(object):
    """
    Provides overloaded utilities for handling repositories of different types
    """

    # additional environment variables for VCS commands
    COMMAND_ENV = {}

    # matches "+++ b/<path>" lines of a unified diff, capturing the added file's path
    PATCH_ADDED_FILE_RE = re.compile(r"^\+\+\+ b/(?P<path>.*)")

    def __init__(self, session, url, location, revision):
        # type: (Session, Text, PathLike, Text) -> ()
        """
        Create a VCS instance for config and url
        :param session: program session
        :param url: repository url
        :param location: (desired) clone location
        :param revision: desired clone revision
        """
        self.session = session
        # FIX: the original assigned self.log twice; the second assignment
        # (a plain module logger) silently discarded this class-specific
        # logger. Keep the more specific one.
        self.log = self.session.get_logger(
            "{}.{}".format(__name__, type(self).__name__)
        )
        self.url = url
        self.location = Text(location)
        self.revision = revision

    @property
    def url_with_auth(self):
        """
        Return URL with configured user/password
        """
        return self.add_auth(self.session.config, self.url)

    @abc.abstractproperty
    def executable_name(self):
        """
        Name of command executable
        """
        pass

    @abc.abstractproperty
    def main_branch(self):
        """
        Name of default/main branch
        """
        pass

    @abc.abstractproperty
    def checkout_flags(self):
        # type: () -> Sequence[Text]
        """
        Command-line flags for checkout command
        """
        pass

    @abc.abstractproperty
    def patch_base(self):
        # type: () -> Sequence[Text]
        """
        Command and flags for applying patches
        """
        pass

    def patch(self, location, patch_content):
        # type: (PathLike, Text) -> bool
        """
        Apply patch repository at `location`
        :return: True if the patch applied successfully, False otherwise
        """
        self.log.info("applying diff to %s", location)
        # pre-create every file the diff adds, so the patch command can write into it
        for match in filter(
            None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
        ):
            create_file_if_not_exists(normalize_path(location, match.group("path")))
        return_code, errors = self.call_with_stdin(
            patch_content, *self.patch_base, cwd=location
        )
        if return_code:
            self.log.error("Failed applying diff")
            lines = errors.splitlines()
            if any(l for l in lines if "no such file or directory" in l.lower()):
                self.log.warning(
                    "NOTE: files were not found when applying diff, perhaps you forgot to push your changes?"
                )
            return False
        else:
            self.log.info("successfully applied uncommitted changes")
        return True

    # Command-line flags for clone command
    clone_flags = ()

    @abc.abstractmethod
    def executable_not_found_error_help(self):
        # type: () -> Text
        """
        Instructions for when executable is not found
        """
        pass

    @staticmethod
    def remote_branch_name(branch):
        # type: (Text) -> Text
        """
        Creates name of remote branch from name of local/ambiguous branch.
        Returns same name by default.
        """
        return branch

    # parse scp-like git ssh URLs, e.g: git@host:user/project.git
    SSH_URL_GIT_SYNTAX = re.compile(
        r"""
        ^
        (?:(?P<user>{regular}*?)@)?
        (?P<host>{regular}*?)
        :
        (?P<path>{regular}.*)?
        $
        """.format(
            regular=r"[^/@:#]"
        ),
        re.VERBOSE,
    )

    @classmethod
    def resolve_ssh_url(cls, url):
        # type: (Text) -> Text
        """
        Replace SSH URL with HTTPS URL when applicable
        """

        def get_username(user_, password=None):
            """
            Remove special SSH users hg/git
            """
            return (
                None
                if user_ and user_.lower() in ["hg", "git"] and not password
                else user_
            )

        match = cls.SSH_URL_GIT_SYNTAX.match(url)
        if match:
            user, host, path = match.groups()
            return (
                furl()
                .set(scheme="https", username=get_username(user), host=host, path=path)
                .url
            )
        parsed_url = furl(url)
        if parsed_url.scheme == "ssh":
            return parsed_url.set(
                scheme="https",
                username=get_username(
                    parsed_url.username, password=parsed_url.password
                ),
            ).url
        return url

    def _set_ssh_url(self):
        """
        Replace instance URL with SSH substitution result and report to log.
        According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work.
        """
        if not self.session.config.agent.translate_ssh:
            return
        # only fall back to https when no ssh-agent is available and git user/pass exist
        ssh_agent_variable = "SSH_AUTH_SOCK"
        if not getenv(ssh_agent_variable) and (self.session.config.get('agent.git_user', None) and
                                               self.session.config.get('agent.git_pass', None)):
            new_url = self.resolve_ssh_url(self.url)
            if new_url != self.url:
                print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
                    self.url, new_url))
                self.url = new_url

    def clone(self, branch=None):
        # type: (Text) -> None
        """
        Clone repository to destination and checking out `branch`.
        If not in debug mode, filter VCS password from output.
        """
        self._set_ssh_url()
        clone_command = ("clone", self.url_with_auth, self.location) + self.clone_flags
        if branch:
            clone_command += ("-b", branch)
        if self.session.debug_mode:
            # debug mode: no output filtering, run the command directly
            self.call(*clone_command)
            return

        def normalize_output(result):
            """
            Returns result string without user's password.
            NOTE: ``self.get_stderr``'s result might or might not have the same type as ``e.output`` in case of error.
            """
            string_type = (
                ensure_text
                if isinstance(result, six.text_type)
                else ensure_binary
            )
            return result.replace(
                string_type(self.url),
                string_type(furl(self.url).remove(password=True).tostr()),
            )

        def print_output(output):
            print(ensure_text(output))

        try:
            print_output(normalize_output(self.get_stderr(*clone_command)))
        except subprocess.CalledProcessError as e:
            # In Python 3, subprocess.CalledProcessError has a `stderr` attribute,
            # but since stderr is redirect to `subprocess.PIPE` it will appear in the usual `output` attribute
            if e.output:
                e.output = normalize_output(e.output)
                print_output(e.output)
            raise

    def checkout(self):
        # type: () -> None
        """
        Checkout repository at specified revision
        """
        self.call("checkout", self.revision, *self.checkout_flags, cwd=self.location)

    @abc.abstractmethod
    def pull(self):
        # type: () -> None
        """
        Pull remote changes for revision
        """
        pass

    def call(self, *argv, **kwargs):
        """
        Execute argv without stdout/stdin.
        Remove stdin so git/hg can't ask for passwords.
        ``kwargs`` can override all arguments passed to subprocess.
        """
        return self._call_subprocess(subprocess.check_call, argv, **kwargs)

    def call_with_stdin(self, input_, *argv, **kwargs):
        # type: (...) -> Tuple[int, str]
        """
        Run command with `input_` as stdin
        :return: (return code, decoded stderr output)
        """
        # NOTE(review): latin1 is used as a byte-transparent encoding here —
        # presumably to pass diff content through unmodified; confirm for
        # non-latin1 patch content.
        input_ = input_.encode("latin1")
        if not input_.endswith(b"\n"):
            input_ += b"\n"
        process = self._call_subprocess(
            subprocess.Popen,
            argv,
            **dict(
                kwargs,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        )
        _, stderr = process.communicate(input_)
        if stderr:
            self.log.warning("%s: %s", self._get_vcs_command(argv), stderr)
        return process.returncode, Text(stderr)

    def get_stderr(self, *argv, **kwargs):
        """
        Execute argv without stdout/stdin in <cwd> and get stderr output.
        Remove stdin so git/hg can't ask for passwords.
        ``kwargs`` can override all arguments passed to subprocess.
        :raises subprocess.CalledProcessError: on non-zero return code
        """
        process = self._call_subprocess(
            subprocess.Popen, argv, **dict(kwargs, stderr=subprocess.PIPE, stdout=None)
        )
        _, stderr = process.communicate()
        code = process.poll()
        if code == COMMAND_SUCCESS:
            return stderr
        with Argv.normalize_exception(censor_password=True):
            raise subprocess.CalledProcessError(
                returncode=code, cmd=argv, output=stderr
            )

    def _call_subprocess(self, func, argv, **kwargs):
        # type: (Callable[..., RType], Iterable[Text], dict) -> RType
        """Run `func` on the VCS command line with password-censoring defaults."""
        cwd = kwargs.pop("cwd", None)
        cwd = cwd and str(cwd)
        # caller kwargs override the defaults below
        kwargs = dict(
            dict(
                censor_password=True,
                cwd=cwd,
                stdin=DEVNULL,
                stdout=DEVNULL,
                env=dict(self.COMMAND_ENV, **environ),
            ),
            **kwargs
        )
        command = self._get_vcs_command(argv)
        self.log.debug("Running: %s", list(command))
        return command.call_subprocess(func, **kwargs)

    def _get_vcs_command(self, argv):
        # type: (Iterable[PathLike]) -> Argv
        """Prepend the VCS executable name to the given arguments."""
        return Argv(self.executable_name, *argv)

    @classmethod
    def add_auth(cls, config, url):
        """
        Add username and password to URL if missing from URL and present in config.
        Does not modify ssh URLs.
        """
        parsed_url = furl(url)
        if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"):
            return parsed_url.url
        config_user = config.get("agent.{}_user".format(cls.executable_name), None)
        config_pass = config.get("agent.{}_pass".format(cls.executable_name), None)
        if (
            (not (parsed_url.username and parsed_url.password))
            and config_user
            and config_pass
        ):
            parsed_url.set(username=config_user, password=config_pass)
        return parsed_url.url

    @abc.abstractproperty
    def info_commands(self):
        # type: () -> Mapping[Text, Argv]
        """
        ` Mapping from `RepoInfo` attribute name (except `type`) to command which acquires it
        """
        pass

    def get_repository_copy_info(self, path):
        """
        Get `RepoInfo` instance from copy of clone in `path`
        """
        path = Text(path)
        commands_result = {
            name: command.get_output(cwd=path)
            # name: subprocess.check_output(command.split(), cwd=path).decode().strip()
            for name, command in self.info_commands.items()
        }
        return RepoInfo(type=self.executable_name, **commands_result)
class Git(VCS):
    """Git-specific VCS implementation."""

    executable_name = "git"
    main_branch = "master"
    clone_flags = ("--quiet", "--recursive")
    checkout_flags = ("--force",)
    COMMAND_ENV = {
        # do not prompt for password
        "GIT_TERMINAL_PROMPT": "0",
        # do not prompt for ssh key passphrase
        "GIT_SSH_COMMAND": "ssh -oBatchMode=yes",
    }

    @staticmethod
    def remote_branch_name(branch):
        """Prefix the local branch name with the default remote."""
        return "origin/{}".format(branch)

    def executable_not_found_error_help(self):
        """Installation instructions for the current platform."""
        how_to_get = select_for_platform(
            linux="You can install it by running: sudo apt-get install {}".format(
                self.executable_name
            ),
            windows="You can download it here: {}".format(
                "https://gitforwindows.org/"
            ),
        )
        return 'Cannot find "{}" executable. {}'.format(
            self.executable_name, how_to_get
        )

    def pull(self):
        """Fetch remote changes into the clone."""
        self.call("fetch", "origin", cwd=self.location)

    info_commands = dict(
        url=Argv("git", "remote", "get-url", "origin"),
        branch=Argv("git", "rev-parse", "--abbrev-ref", "HEAD"),
        commit=Argv("git", "rev-parse", "HEAD"),
        root=Argv("git", "rev-parse", "--show-toplevel"),
    )

    patch_base = ("apply",)
class Hg(VCS):
    """Mercurial-specific VCS implementation."""

    executable_name = "hg"
    main_branch = "default"
    checkout_flags = ("--clean",)
    patch_base = ("import", "--no-commit")

    def executable_not_found_error_help(self):
        """Installation instructions for the current platform."""
        how_to_get = select_for_platform(
            linux="You can install it by running: sudo apt-get install {}".format(
                self.executable_name
            ),
            windows="You can download it here: {}".format(
                "https://www.mercurial-scm.org/wiki/Download"
            ),
        )
        return 'Cannot find "{}" executable. {}'.format(
            self.executable_name, how_to_get
        )

    def pull(self):
        """Pull remote changes, limited to the configured revision when set."""
        revision_args = ("-r", self.revision) if self.revision else ()
        self.call(
            "pull",
            self.url_with_auth,
            cwd=self.location,
            *revision_args
        )

    info_commands = dict(
        url=Argv("hg", "paths", "--verbose"),
        branch=Argv("hg", "--debug", "id", "-b"),
        commit=Argv("hg", "--debug", "id", "-i"),
        root=Argv("hg", "root"),
    )
def clone_repository_cached(session, execution, destination):
    # type: (Session, ExecutionInfo, Path) -> Tuple[VCS, RepoInfo]
    """
    Clone a remote repository.
    :param execution: execution info
    :param destination: directory to clone to (in which a directory for the repository will be created)
    :param session: program session
    :return: repository information
    :raises: CommandFailedError if git/hg is not installed
    """
    repo_url = execution.repository  # type: str
    no_password_url = furl(repo_url).copy().remove(password=True).url
    clone_folder_name = Path(str(furl(repo_url).path)).name  # type: str
    clone_folder = Path(destination) / clone_folder_name

    # cache directory is keyed on the full repository URL
    cache_key = md5(ensure_binary(repo_url)).hexdigest()
    cached_repo_path = (
        Path(session.config["agent.vcs_cache.path"]).expanduser()
        / "{}.{}".format(clone_folder_name, cache_key)
        / clone_folder_name
    )  # type: Path

    vcs = VcsFactory.create(
        session, execution_info=execution, location=cached_repo_path
    )
    if not find_executable(vcs.executable_name):
        raise CommandFailedError(vcs.executable_not_found_error_help())

    if session.config["agent.vcs_cache.enabled"] and cached_repo_path.exists():
        print('Using cached repository in "{}"'.format(cached_repo_path))
    else:
        print("cloning: {}".format(no_password_url))
        rm_tree(cached_repo_path)
        vcs.clone(branch=execution.branch)

    vcs.pull()
    vcs.checkout()

    # work on a private copy of the (possibly cached) clone
    rm_tree(destination)
    shutil.copytree(Text(cached_repo_path), Text(clone_folder))
    if not clone_folder.is_dir():
        raise CommandFailedError(
            "copying of repository failed: from {} to {}".format(
                cached_repo_path, clone_folder
            )
        )

    # make sure we have no user/pass in the returned repository structure
    repo_info = vcs.get_repository_copy_info(clone_folder)
    repo_info = attr.evolve(repo_info, url=no_password_url)
    return vcs, repo_info

View File

@@ -0,0 +1,296 @@
from __future__ import unicode_literals, division
import logging
import os
from collections import deque
from itertools import starmap
from threading import Thread, Event
from time import time
from typing import Text, Sequence
import attr
import psutil
from pathlib2 import Path
from trains_agent.session import Session
try:
from .gpu import gpustat
except ImportError:
gpustat = None
log = logging.getLogger(__name__)
class BytesSizes(object):
    """Helpers converting a byte count to larger binary units."""

    @staticmethod
    def kilobytes(x):
        # type: (float) -> float
        """Bytes -> kibibytes (1024 bytes)."""
        return x / float(1 << 10)

    @staticmethod
    def megabytes(x):
        # type: (float) -> float
        """Bytes -> mebibytes (1024**2 bytes)."""
        return x / float(1 << 20)

    @staticmethod
    def gigabytes(x):
        # type: (float) -> float
        """Bytes -> gibibytes (1024**3 bytes)."""
        return x / float(1 << 30)
class ResourceMonitor(object):
    """
    Samples machine statistics (CPU/memory/disk/network, optionally GPU via
    gpustat) on a daemon thread and periodically sends them to the server
    through a workers.status_report request.
    """

    @attr.s
    class StatusReport(object):
        # currently running task id (if any)
        task = attr.ib(default=None, type=str)
        # queue the current task was pulled from (if any)
        queue = attr.ib(default=None, type=str)
        # queues this worker is listening on
        queues = attr.ib(default=None, type=Sequence[str])

        def to_dict(self):
            # serialize, dropping unset (None) fields
            return {
                key: value
                for key, value in attr.asdict(self).items()
                if value is not None
            }

    def __init__(
        self,
        session,  # type: Session
        worker_id,  # type: str
        sample_frequency_per_sec=2.0,
        report_frequency_sec=30.0,
        first_report_sec=None,
    ):
        """
        :param session: program session, used to send the reports
        :param worker_id: worker identifier attached to every report
        :param sample_frequency_per_sec: machine-stat samples taken per second
        :param report_frequency_sec: seconds between aggregated reports
        :param first_report_sec: delay before the first report
            (defaults to report_frequency_sec)
        """
        self.session = session
        # single-slot queue: only the most recently set StatusReport is kept
        self.queue = deque(maxlen=1)
        self.queue.appendleft(self.StatusReport())
        self._worker_id = worker_id
        self._sample_frequency = sample_frequency_per_sec
        self._report_frequency = report_frequency_sec
        self._first_report_sec = first_report_sec or report_frequency_sec
        # accumulated samples since the last report, and their count
        self._num_readouts = 0
        self._readouts = {}
        # previous raw sample, used to turn cumulative counters into rates
        self._previous_readouts = {}
        self._previous_readouts_ts = time()
        self._thread = None
        self._exit_event = Event()
        self._gpustat_fail = 0
        self._gpustat = gpustat
        if not self._gpustat:
            log.warning('Trains-Agent Resource Monitor: GPU monitoring is not available')
        else:
            # limit monitoring to the GPUs made visible via the environment,
            # NVIDIA_VISIBLE_DEVICES taking precedence over CUDA_VISIBLE_DEVICES
            self._active_gpus = None
            try:
                active_gpus = os.environ.get('NVIDIA_VISIBLE_DEVICES', '') or \
                    os.environ.get('CUDA_VISIBLE_DEVICES', '')
                if active_gpus:
                    self._active_gpus = [int(g.strip()) for g in active_gpus.split(',')]
            except Exception:
                pass

    def set_report(self, report):
        # type: (ResourceMonitor.StatusReport) -> ()
        """Replace the status report sent alongside the stats (None is ignored)."""
        if report is not None:
            self.queue.appendleft(report)

    def get_report(self):
        # type: () -> ResourceMonitor.StatusReport
        """Return the most recently set status report."""
        return self.queue[0]

    def start(self):
        """Start the sampling/reporting daemon thread; returns self for chaining."""
        self._exit_event.clear()
        self._thread = Thread(target=self._daemon)
        self._thread.daemon = True
        self._thread.start()
        return self

    def stop(self):
        """Signal the daemon thread to exit and send one final (stats-less) report."""
        self._exit_event.set()
        self.send_report()

    def send_report(self, stats=None):
        """
        Send a workers.status_report with the given aggregated stats.
        :return: True on success, False if the request failed
        """
        report = dict(
            machine_stats=stats,
            # epoch milliseconds (second resolution)
            timestamp=(int(time()) * 1000),
            worker=self._worker_id,
            **self.get_report().to_dict()
        )
        log.debug("sending report: %s", report)
        try:
            self.session.get(service="workers", action="status_report", **report)
        except Exception:
            log.warning("Failed sending report: %s", report)
            return False
        return True

    def _daemon(self):
        """Thread body: keep sampling, send an averaged report every period."""
        seconds_since_started = 0
        reported = 0
        while True:
            last_report = time()
            current_report_frequency = (
                self._report_frequency if reported != 0 else self._first_report_sec
            )
            while (time() - last_report) < current_report_frequency:
                # sleep one sample interval (1/sample_frequency seconds); quit if event set
                if self._exit_event.wait(1 / self._sample_frequency):
                    return
                # noinspection PyBroadException
                try:
                    self._update_readouts()
                except Exception as ex:
                    log.warning("failed getting machine stats: %s", report_error(ex))
                    self._failure()

            seconds_since_started += int(round(time() - last_report))
            # check if we do not report any metric (so it means the last iteration will not be changed)
            # if we do not have last_iteration, we just use seconds as iteration
            # start reporting only when we figured out, if this is seconds based, or iterations based
            average_readouts = self._get_average_readouts()
            stats = {
                # 3 points after the dot
                key: round(value, 3) if isinstance(value, float) else [round(v, 3) for v in value]
                for key, value in average_readouts.items()
            }
            # send actual report
            if self.send_report(stats):
                # clear readouts if this is update was sent
                self._clear_readouts()
            # count reported iterations
            reported += 1

    def _update_readouts(self):
        """Take one machine-stats sample and accumulate it into self._readouts."""
        readouts = self._machine_stats()
        elapsed = time() - self._previous_readouts_ts
        self._previous_readouts_ts = time()

        def fix(k, v):
            # "*_mbs" counters are cumulative; convert to a rate using the previous sample
            if k.endswith("_mbs"):
                v = (v - self._previous_readouts.get(k, v)) / elapsed
            if v is None:
                v = 0
            return k, self._readouts.get(k, 0) + v

        self._readouts.update(starmap(fix, readouts.items()))
        self._num_readouts += 1
        self._previous_readouts = readouts

    def _get_num_readouts(self):
        # number of samples accumulated since the last cleared report
        return self._num_readouts

    def _get_average_readouts(self):
        """Average accumulated readouts and map them to backend stat-field names."""
        def create_general_key(old_key):
            """
            Create key for backend payload
            :param old_key: old stats key
            :type old_key: str
            :return: new key for sending stats
            :rtype: str
            """
            key_parts = old_key.rpartition("_")
            return "{}_*".format(key_parts[0] if old_key.startswith("gpu") else old_key)

        ret = {}
        # make sure the gpu/cpu stats are always ordered in the accumulated values list (general_key)
        ordered_keys = sorted(self._readouts.keys())
        for k in ordered_keys:
            v = self._readouts[k]
            stat_key = self.BACKEND_STAT_MAP.get(k)
            if stat_key:
                # exact key match: scalar field
                ret[stat_key] = v / self._num_readouts
            else:
                # wildcard key (per-gpu/per-cpu): collect into an ordered list
                general_key = create_general_key(k)
                general_key = self.BACKEND_STAT_MAP.get(general_key)
                if general_key:
                    ret.setdefault(general_key, []).append(v / self._num_readouts)
                else:
                    pass  # log.debug("Cannot find key {}".format(k))
        return ret

    def _clear_readouts(self):
        # reset the accumulation window after a successful report
        self._readouts = {}
        self._num_readouts = 0

    def _machine_stats(self):
        """
        :return: machine stats dictionary, all values expressed in megabytes
        """
        cpu_usage = psutil.cpu_percent(percpu=True)
        stats = {"cpu_usage": sum(cpu_usage) / len(cpu_usage)}

        virtual_memory = psutil.virtual_memory()
        stats["memory_used"] = BytesSizes.megabytes(virtual_memory.used)
        stats["memory_free"] = BytesSizes.megabytes(virtual_memory.available)

        # free space is measured on the volume holding the home directory
        disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
        stats["disk_free_percent"] = 100 - disk_use_percentage

        # sensors_temperatures() is not available on all platforms
        sensor_stat = (
            psutil.sensors_temperatures()
            if hasattr(psutil, "sensors_temperatures")
            else {}
        )
        if "coretemp" in sensor_stat and len(sensor_stat["coretemp"]):
            stats["cpu_temperature"] = max([t.current for t in sensor_stat["coretemp"]])

        # update cached measurements
        net_stats = psutil.net_io_counters()
        stats["network_tx_mbs"] = BytesSizes.megabytes(net_stats.bytes_sent)
        stats["network_rx_mbs"] = BytesSizes.megabytes(net_stats.bytes_recv)
        io_stats = psutil.disk_io_counters()
        stats["io_read_mbs"] = BytesSizes.megabytes(io_stats.read_bytes)
        stats["io_write_mbs"] = BytesSizes.megabytes(io_stats.write_bytes)

        # check if we can access the gpu statistics
        if self._gpustat:
            try:
                gpu_stat = self._gpustat.new_query()
                for i, g in enumerate(gpu_stat.gpus):
                    # only monitor the active gpu's, if none were selected, monitor everything
                    if self._active_gpus and i not in self._active_gpus:
                        continue
                    stats["gpu_temperature_{:d}".format(i)] = g["temperature.gpu"]
                    stats["gpu_utilization_{:d}".format(i)] = g["utilization.gpu"]
                    stats["gpu_mem_usage_{:d}".format(i)] = (
                        100.0 * g["memory.used"] / g["memory.total"]
                    )
                    # already in MBs
                    stats["gpu_mem_free_{:d}".format(i)] = (
                        g["memory.total"] - g["memory.used"]
                    )
                    stats["gpu_mem_used_%d" % i] = g["memory.used"]
            except Exception as ex:
                # something happened and we can't use gpu stats,
                log.warning("failed getting machine stats: %s", report_error(ex))
                self._failure()

        return stats

    def _failure(self):
        # after 3 consecutive gpustat failures, give up on GPU monitoring entirely
        self._gpustat_fail += 1
        if self._gpustat_fail >= 3:
            log.error(
                "GPU monitoring failed getting GPU reading, switching off GPU monitoring"
            )
            self._gpustat = None

    # maps internal readout keys (exact names or "<prefix>_*" wildcards)
    # to backend status_report field names
    BACKEND_STAT_MAP = {"cpu_usage_*": "cpu_usage",
                        "cpu_temperature_*": "cpu_temperature",
                        "disk_free_percent": "disk_free_home",
                        "io_read_mbs": "disk_read",
                        "io_write_mbs": "disk_write",
                        "network_tx_mbs": "network_tx",
                        "network_rx_mbs": "network_rx",
                        "memory_free": "memory_free",
                        "memory_used": "memory_used",
                        "gpu_temperature_*": "gpu_temperature",
                        "gpu_mem_used_*": "gpu_memory_used",
                        "gpu_mem_free_*": "gpu_memory_free",
                        "gpu_utilization_*": "gpu_usage"}
def report_error(ex):
    """Format an exception as "<TypeName>: <message>" for log output."""
    exc_type_name = type(ex).__name__
    return "{}: {}".format(exc_type_name, ex)

View File

@@ -0,0 +1,119 @@
import os
import psutil
from time import sleep
from glob import glob
from tempfile import gettempdir, NamedTemporaryFile
from trains_agent.helper.base import warning
class Singleton(object):
    """
    Guards against running two agent instances with the same worker id.
    Each live instance owns a temp file named
    "<prefix><sep><pid><sep><random><ext>" containing "<worker_id>\n<slot>";
    scanning these files detects duplicates and allocates instance slots.
    """
    prefix = 'trainsagent'
    sep = '_'
    ext = '.tmp'
    worker_id = None
    worker_name_sep = ':'
    instance_slot = None
    _pid_file = None
    _lock_file_name = sep+prefix+sep+'global.lock'
    _lock_timeout = 10

    @classmethod
    def register_instance(cls, unique_worker_id=None, worker_name=None):
        """
        # Exit the process if another instance of us is using the same worker_id
        :param unique_worker_id: if already exists, return negative
        :param worker_name: slot number will be added to worker name, based on the available instance slot
        :return: (str worker_id, int slot_number) Return None value on instance already running
        """
        # try to lock file
        lock_file = os.path.join(gettempdir(), cls._lock_file_name)
        timeout = 0
        while os.path.exists(lock_file):
            if timeout > cls._lock_timeout:
                # stale lock: the owner probably crashed without cleaning up
                warning('lock file timed out {}sec - clearing lock'.format(cls._lock_timeout))
                try:
                    os.remove(lock_file)
                except Exception:
                    pass
                break
            sleep(1)
            timeout += 1
        with open(lock_file, 'wb') as f:
            # BUG FIX: the original wrote bytes(os.getpid()), which on Python 3
            # produces a zero-filled buffer of <pid> bytes instead of the pid
            # digits; record the owning pid as text.
            f.write(str(os.getpid()).encode())
            f.flush()
        try:
            ret = cls._register_instance(unique_worker_id=unique_worker_id, worker_name=worker_name)
        except Exception:
            # registration failed for any reason - report "already running"
            ret = None, None
        try:
            os.remove(lock_file)
        except Exception:
            pass
        return ret

    @classmethod
    def _register_instance(cls, unique_worker_id=None, worker_name=None):
        """
        Scan existing slot files, drop stale ones, detect a duplicate worker id
        and allocate the lowest free instance slot for this process.
        """
        if cls.worker_id:
            return cls.worker_id, cls.instance_slot
        # make sure we have a unique name
        instance_num = 0
        temp_folder = gettempdir()
        files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
        slots = {}
        for file in files:
            # ROBUSTNESS FIX: parse only the file name; splitting the full path
            # breaks when the temp folder itself contains the separator char
            parts = os.path.basename(file).split(cls.sep)
            try:
                pid = int(parts[1])
            except Exception:
                # something is wrong, use non existing pid and delete the file
                pid = -1
            # count active instances and delete dead files
            if not psutil.pid_exists(pid):
                # delete the file
                try:
                    os.remove(file)
                except Exception:
                    pass
                continue
            instance_num += 1
            try:
                with open(file, 'r') as f:
                    uid, slot = str(f.read()).split('\n')
                    slot = int(slot)
            except Exception:
                continue
            if uid == unique_worker_id:
                # a live process already owns this worker id
                return None, None
            slots[slot] = uid
        # get a new slot
        if not slots:
            cls.instance_slot = 0
        else:
            # guarantee we have the minimal slot possible
            for i in range(max(slots.keys())+2):
                if i not in slots:
                    cls.instance_slot = i
                    break
        # build worker id based on slot
        if not unique_worker_id:
            unique_worker_id = worker_name + cls.worker_name_sep + str(cls.instance_slot)
        # create lock; keep a reference so the temp file lives as long as we do
        cls._pid_file = NamedTemporaryFile(dir=gettempdir(), prefix=cls.prefix + cls.sep + str(os.getpid()) + cls.sep,
                                           suffix=cls.ext)
        cls._pid_file.write(('{}\n{}'.format(unique_worker_id, cls.instance_slot)).encode())
        cls._pid_file.flush()
        cls.worker_id = unique_worker_id
        return cls.worker_id, cls.instance_slot

View File

@@ -0,0 +1,144 @@
from __future__ import unicode_literals, print_function, absolute_import
import linecache
import os
import sys
import time
import trace
from itertools import chain
from types import ModuleType
from typing import Text, Sequence, Union
from pathlib2 import Path
import six
try:
from functools import lru_cache
except ImportError:
from functools32 import lru_cache
def inclusive_parents(path):
    """
    Return an iterator over *path* itself followed by each of its parents.
    """
    return chain([path], path.parents)
def get_module_path(module):
    """
    :param module: Module object or name
    :return: module path (package directory for a package's __init__ file)
    """
    if isinstance(module, six.string_types):
        # a name was given - look the module object up by name
        module = sys.modules[module]
    path = Path(module.__file__)
    if path.stem == '__init__':
        return path.parent
    return path
Module = Union[ModuleType, Text]
class PackageTraceIgnore(object):
    """
    Object that includes package modules in trace and excludes sub modules and all other code.
    """

    def __init__(self, package, ignore_submodules):
        # type: (Module, Sequence[Module]) -> None
        """
        Modules given by name will be searched for in sys.modules, enabling use of "__name__".
        :param package: Package to include modules of
        :param ignore_submodules: sub modules of package to ignore
        """
        self.ignore_submodules = tuple(map(get_module_path, ignore_submodules))
        self.package = package
        self.package_path = get_module_path(package)
        # FIX: the original used @lru_cache(None) on the `names` method, which
        # keys the cache on `self` and keeps every instance alive for the
        # lifetime of the process; use an unbounded per-instance cache instead.
        self._names_cache = {}

    def names(self, file_name, module_name=None):
        # type: (Text, Text) -> bool
        """
        Return whether a file should be ignored based on it's path and module name.
        Ignore files which are not part of self.package.
        trace.Ignore's documentation states that module_name is unreliable for packages,
        therefore, it is not used here.
        Results are memoized per (file_name, module_name).
        :param file_name: source file path
        :param module_name: module name
        :return: whether file should be ignored
        """
        cache_key = (file_name, module_name)
        try:
            return self._names_cache[cache_key]
        except KeyError:
            pass
        file_path = Path(file_name).resolve()
        include = self.include(file_path)
        result = not include
        self._names_cache[cache_key] = result
        return result

    def include(self, base):
        # type: (Path) -> bool
        """
        Walk from `base` upwards: excluded as soon as an ignored submodule is
        met, included when the package root is met first.
        """
        for path in inclusive_parents(base):
            if not path.exists():
                continue
            if any(path.samefile(sub) for sub in self.ignore_submodules):
                return False
            if path.samefile(self.package_path):
                return True
        return False
class PackageTrace(trace.Trace, object):
    """
    Trace object for tracing only lines from a specific package.
    Some functions are copied and modified for lack of modularity of ``trace.Trace``.
    """

    def __init__(self, package, out_file, ignore_submodules=(), *args, **kwargs):
        # type: (Module, object, Sequence[Module], object, object) -> None
        """
        :param package: package whose modules should be traced
        :param out_file: writable file-like object receiving trace output
        :param ignore_submodules: sub modules of `package` to exclude from the trace
        """
        super(PackageTrace, self).__init__(*args, **kwargs)
        self.ignore = PackageTraceIgnore(package, ignore_submodules)
        self.__out_file = out_file

    def __out(self, *args, **kwargs):
        # all trace output is redirected to the configured file
        print(*args, file=self.__out_file, **kwargs)

    def globaltrace_lt(self, frame, why, arg):
        """
        ## Copied from trace module ##
        Handler for call events.
        If the code block being entered is to be ignored, returns `None',
        else returns self.localtrace.
        """
        if why == 'call':
            code = frame.f_code
            filename = frame.f_globals.get('__file__', None)
            if filename:
                # XXX modname() doesn't work right for packages, so
                # the ignore support won't work right for packages
                ignore_it = self.ignore.names(filename)
                if not ignore_it:
                    if self.trace:
                        # derive a dotted module name relative to the traced package
                        filename = Path(filename)
                        modulename = '.'.join(
                            filename.relative_to(self.ignore.package_path).parts[:-1] + (filename.stem,)
                        )
                        self.__out(' --- modulename: %s, funcname: %s' % (modulename, code.co_name))
                    return self.localtrace
            else:
                return None

    def localtrace_trace(self, frame, why, arg):
        """
        ## Copied from trace module ##
        """
        if why == "line":
            # record the file name and line number of every trace
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno

            if self.start_time:
                self.__out('%.2f' % (time.time() - self.start_time), end='')
            bname = os.path.basename(filename)
            self.__out('%s(%d): %s' % (bname, lineno, linecache.getline(filename, lineno)), end='')
        return self.localtrace

    # line counting is not needed here; counting mode behaves like plain tracing
    localtrace_trace_and_count = localtrace_trace

View File

@@ -0,0 +1,31 @@
from functools import partial
from importlib import import_module
import argparse
from trains_agent.definitions import PROGRAM_NAME
from .base import Parser, base_arguments, add_service, OnlyPluralChoicesHelpFormatter
# Top-level CLI services exposed by the agent; each gets its own sub-parser
SERVICES = [
    'worker',
]
def get_parser():
    """
    Build the top-level argument parser for the trains-agent CLI.
    :return: Parser instance with base arguments and one sub-parser per worker command
    """
    top_parser = Parser(
        prog=PROGRAM_NAME,
        add_help=False,
        formatter_class=partial(
            OnlyPluralChoicesHelpFormatter,
            max_help_position=120,
            width=120,
        ),
    )
    base_arguments(top_parser)
    # deferred import (presumably to avoid a circular import with .worker — TODO confirm)
    from .worker import COMMANDS
    subparsers = top_parser.add_subparsers(dest='command')
    for c in COMMANDS:
        parser = subparsers.add_parser(name=c, help=COMMANDS[c]['help'])
        for a in COMMANDS[c].get('args', {}).keys():
            parser.add_argument(a, **COMMANDS[c]['args'][a])
    return top_parser

View File

@@ -0,0 +1,424 @@
from __future__ import print_function
import abc
import argparse
from copy import deepcopy
from functools import partial
import six
from pathlib2 import Path
from trains_agent import definitions
from trains_agent.session import Session
HEADER = 'TRAINS-AGENT Deep Learning DevOps'
class Parser(argparse.ArgumentParser):
    """
    ArgumentParser extension: @file argument files, sub-parser lookup by name,
    and an optional default sub-parser injected when none was named.
    """

    # name of the sub-parser to inject when the command line names none
    __default_subparser = None

    def __init__(self, usage_on_error=True, *args, **kwargs):
        # usage_on_error: print full help when too few arguments were given
        super(Parser, self).__init__(fromfile_prefix_chars=definitions.FROM_FILE_PREFIX_CHARS, *args, **kwargs)
        self._usage_on_error = usage_on_error

    @property
    def choices(self):
        """Mapping of sub-command name -> sub-parser (empty when none were added)."""
        try:
            subparser = next(
                action for action in self._actions
                if isinstance(action, argparse._SubParsersAction))
        except StopIteration:
            return {}
        return subparser.choices

    def error(self, message):
        # on "too few arguments" optionally show the full help before exiting
        # (argparse._ is the gettext translation hook used by argparse itself)
        if self._usage_on_error and message == argparse._('too few arguments'):
            self.print_help()
            print()
            self.exit(2, argparse._('%s: error: %s\n') % (self.prog, message))
        super(Parser, self).error(message)

    def __getitem__(self, name):
        # parser['cmd'] is shorthand for the named sub-parser
        return self.choices[name]

    def remove_top_level_results(self, parse_results):
        """
        Remove useless, artifact values
        :param parse_results: resulting namespace of parse_args, converted to dict ( vars(args) )
        """
        for action in self._actions:
            if action.dest != 'version':
                parse_results.pop(action.dest, None)
        for key in ('func', 'command', 'subcommand', 'action'):
            parse_results.pop(key, None)

    def set_default_subparser(self, name):
        # record the sub-command to assume when the command line names none
        self.__default_subparser = name

    def get_default_subparser(self):
        # the parser object registered for the default sub-command
        return self.choices[self.__default_subparser]

    def _parse_known_args(self, arg_strings, *args, **kwargs):
        # inject the default sub-command unless one was given or help was requested
        in_args = set(arg_strings)
        d_sp = self.__default_subparser
        if d_sp is not None and not {'-h', '--help'}.intersection(in_args):
            for x in self._subparsers._actions:
                subparser_found = (
                    isinstance(x, argparse._SubParsersAction) and
                    in_args.intersection(x._name_parser_map.keys())
                )
                if subparser_found:
                    break
            else:
                # insert default in first position, this implies no
                # global options without a sub_parsers specified
                arg_strings = [d_sp] + arg_strings
        return super(Parser, self)._parse_known_args(
            arg_strings, *args, **kwargs
        )
class AliasedPseudoAction(argparse.Action):
    """
    Action for choosing between sub-commands, including aliases
    Exists only to render a help entry of the form "name (alias1,alias2)".
    """

    def __init__(self, name, aliases, help):
        # the command's own name never appears among its aliases in the help text
        extra_names = [alias for alias in aliases if alias != name]
        if extra_names:
            dest = name + ' (%s)' % ','.join(extra_names)
        else:
            dest = name
        super(AliasedPseudoAction, self).__init__(option_strings=[], dest=dest, help=help)
class AliasedSubParsersAction(argparse._SubParsersAction):
    """
    Action for adding aliases for sub-commands
    """

    def add_parser(self, name, **kwargs):
        """Create a sub-parser; `aliases` (extra names) all resolve to the same parser."""
        aliases = kwargs.pop('aliases', [])
        parser = super(AliasedSubParsersAction, self).add_parser(name, **kwargs)

        # Make the aliases work
        self._name_parser_map.update((alias, parser) for alias in aliases)

        # Make the help text reflect them, first removing old help entry.
        help = kwargs.pop('help', None)
        if help:
            self._choices_actions.pop()
            self._choices_actions.append(AliasedPseudoAction(name, aliases, help))

        return parser
class OnlyPluralChoicesHelpFormatter(argparse.HelpFormatter):
    """
    Help formatter that hides a singular choice when its plural form is also a
    choice (e.g. shows "{workers}" rather than "{worker,workers}").
    """

    @staticmethod
    def _metavar_formatter(action, default_metavar):
        if action.metavar is not None:
            result = action.metavar
        elif action.choices is not None:
            all_choices = [str(choice) for choice in action.choices]
            # drop any choice whose plural ("<choice>s") is itself a choice
            kept = [choice for choice in all_choices if choice + 's' not in all_choices]
            result = '{%s}' % ','.join(kept)
        else:
            result = default_metavar

        def format(tuple_size):
            return result if isinstance(result, tuple) else (result, ) * tuple_size

        return format
def hyphenate(s):
    """Turn underscores into hyphens (CLI-friendly option spelling)."""
    return '-'.join(s.split('_'))
def add_args(parser, args):
    """
    Add arguments to parser from args mapping
    :param parser: parser to add arguments to
    :type parser: argparse.ArgumentParser
    :param args: mapping of name -> other arguments to ArgumentParser.add_argument
    :type args: dict
    """
    for arg_name in args:
        arg_params = args[arg_name]
        # 'aliases' is a local extension: extra option strings for the same argument
        aliases = arg_params.pop('aliases', tuple())
        parser.add_argument(arg_name, *aliases, **arg_params)
def add_mutually_exclusive_groups(parser, groups):
    """
    Add mutually exclusive groups to parser from list
    :param parser: parser to add groups to
    :param groups: list of dictionaries, each containing:
        1. 'args': parameter to add_args
        2. arguments to ArgumentParser.add_mutually_exclusive_group
    """
    for group in groups:
        group_args = group.pop('args', {})
        group_parser = parser.add_mutually_exclusive_group(**group)
        # inlined add_args: register each argument (with optional aliases)
        for name, params in group_args.items():
            extra_names = params.pop('aliases', tuple())
            group_parser.add_argument(name, *extra_names, **params)
def add_service(subparsers, name, commands, command_name_dest='command', formatter_class=argparse.RawDescriptionHelpFormatter, **kwargs):
    """
    Add service commands to parser from arguments dictionary
    :param subparsers: subparsers object of ArgumentParser
    :param name: name of service
    :param commands: mapping of names to dictionaries, each of them containing:
        1. 'args' - mapping of name -> other arguments to ArgumentParser.add_argument
        2. 'help' - command description
        3. 'mutually_exclusive_groups' - see add_mutually_exclusive_groups
        4. 'subparsers' - optional nested services, handled recursively
        5. 'type' - optional CommandType instance applied to the command parser
        6. 'func' - optional handler stored as the "func" default (defaults to the command name)
    :param command_name_dest: name of attribute in which to store selected sub-command
    :param formatter_class: help formatter class
    :param kwargs: any other arguments to add_parser method of subparser object
    :return: service subparser
    """
    # work on a copy: the pop() calls below would otherwise mutate the
    # caller's command definitions
    commands = deepcopy(commands)
    service_parser = subparsers.add_parser(
        name,
        # aliases=(name.strip('s'),),
        formatter_class=formatter_class,
        **kwargs
    )
    # allow nested sub-commands to declare aliases
    service_parser.register('action', 'parsers', AliasedSubParsersAction)
    service_parser.set_defaults(**{command_name_dest: name})
    service_subparsers = service_parser.add_subparsers(
        title='{} commands'.format(name.capitalize()),
        parser_class=partial(Parser, usage_on_error=False),
        dest='action')
    # This is a fix for a bug in python3's argparse: running "trains-agent some_service" fails
    service_subparsers.required = True
    # NOTE: this loop rebinds "name", shadowing the service name parameter;
    # the parameter is not used after this point
    for name, subparser in commands.pop('subparsers', {}).items():
        add_service(service_subparsers, name, command_name_dest='subcommand', **subparser)
    for command_name, command in commands.items():
        command_type = command.pop('type', None)
        mutually_exclusive_groups = command.pop('mutually_exclusive_groups', [])
        # handler defaults to the command name itself
        func = command.pop('func', command_name)
        args = command.pop('args', {})
        # remaining keys in "command" (e.g. 'help') are forwarded to add_parser
        command_parser = service_subparsers.add_parser(hyphenate(command_name), **command)
        if command_type:
            command_type.make(command_parser)
        command_parser.set_defaults(func=func)
        add_mutually_exclusive_groups(command_parser, mutually_exclusive_groups)
        add_args(command_parser, args)
    return service_parser
@six.add_metaclass(abc.ABCMeta)
class CommandType(object):
    """
    Abstract parser "decorator": stores construction arguments and applies
    them to an argparse parser via make().
    """
    def __init__(self, *args, **kwargs):
        # arguments forwarded to _make() when the type is applied to a parser
        self._args = args
        self._kwargs = kwargs

    def make(self, parser):
        """Apply this command type to the given parser using the stored arguments."""
        return self._make(parser, *self._args, **self._kwargs)

    @abc.abstractmethod
    def _make(self, *args, **kwargs):
        """Subclass hook that performs the actual parser configuration."""
        pass
class ListCommand(CommandType):
    """
    Adds the standard "list" command options to a parser: table/csv output
    selection, sorting, and optionally pagination and tree view.
    """
    @staticmethod
    def _make(parser, default_value, pagination=False, tree=False):
        # default_value is the default column selection for --table
        if tree:
            table_group = parser.add_mutually_exclusive_group()
            tree_action = table_group.add_argument(
                '--tree', help='Tree view output', action='store_true', default=False)
            csv_group = parser.add_mutually_exclusive_group()
            # hack: tree cannot be used with either csv or table, which can be used
            # with each other; injecting the action into the second group's private
            # list makes --tree exclusive with both
            csv_group._group_actions.append(tree_action)
        else:
            # no tree view: --table and --csv go straight on the parser
            csv_group = table_group = parser
        table_group.add_argument(
            '--table',
            help='Select table columns ("#" separated, default: %(default)s)',
            default=default_value)
        csv_group.add_argument(
            '--csv',
            help='Generate CSV output to specified path',
            default=None)
        parser.add_argument(
            '--no-headers',
            action='store_false',
            dest='headers',
            help='Do not print table/csv headers')
        parser.add_argument(
            '--sort',
            help='Fields to sort by (same format as --table)')
        # --ascending/--descending both write sort_reverse (default None)
        parser.add_argument(
            '--ascending',
            default=None,
            dest='sort_reverse',
            action='store_false',
            help='Sort in ascending order (default)')
        parser.add_argument(
            '--descending',
            default=None,
            dest='sort_reverse',
            action='store_true',
            help='Sort in descending order')
        if not pagination:
            return
        # pagination options; bounds enforced by bound_number_type validators
        parser.add_argument(
            '--page',
            help='Page number to show (default: %(default)s)',
            default=0,
            type=bound_number_type(minimum=0))
        parser.add_argument(
            '--page-size',
            help='Size of page (default: %(default)s)',
            default=50,
            type=bound_number_type(minimum=1))
        parser.add_argument(
            '--no-pagination',
            action='store_false',
            dest='pagination',
            help='Disable pagination (return all results)')
class _HelpAction(argparse._HelpAction):
    """
    "-h" action: prints the application header, then the standard summary
    help, and exits.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        # print header
        print(HEADER + '\n')
        parser.print_help()
        print('')
        parser.exit()
class _DetailedHelpAction(argparse._HelpAction):
    """
    "--help" action: prints the header, the top level help, and then the
    help of every registered sub-parser and its sub-commands.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        # print header
        print(HEADER + '\n')
        parser.print_help()
        print('\n')
        # retrieve subparsers from parser (argparse keeps them as private actions)
        subparsers_actions = [
            action for action in parser._actions
            if isinstance(action, argparse._SubParsersAction)
        ]
        # iterate and print help for each subparser
        for subparsers_action in subparsers_actions:
            # get all subparsers and print help
            for choice, subparsercmd in subparsers_action.choices.items():
                # split help into lines so we can skip the header
                text = subparsercmd.format_help().split('\n')
                # find first line of command (skip usage and header)
                # NOTE(review): if no line starts with the capitalized choice,
                # "i" falls through as the last index - assumed not to happen
                for i, t in enumerate(text):
                    if t.startswith(choice.title()):
                        break
                # print help command prefix
                print(text[i])
                # print help per sub-commands, we actually assume only one
                subact = [
                    action for action in subparsercmd._actions
                    if isinstance(action, argparse._SubParsersAction)
                ]
                # per action print all parameters
                for j, t in enumerate(text[i + 2:]):
                    print(t)
                    k = t.split()
                    if not k:
                        continue
                    try:
                        # NOTE(review): assumes "subact" is non-empty; an empty list
                        # would raise an uncaught IndexError here - confirm
                        subc = subact[0].choices[k[0]]
                    except KeyError:
                        continue
                    # hack so we can control formatting in one place
                    # otherwise we need to update all
                    # the parsers when we create them
                    subc.formatter_class = lambda prog: argparse.HelpFormatter(prog, width=120)
                    subchelp = subc.format_help().split('\n')
                    # skip until we reach "optional arguments:"
                    for si, st in enumerate(subchelp):
                        if st.startswith('optional arguments:'):
                            break
                    # print help with single tab indent
                    for st in subchelp[si + 1:]:
                        print('\t %s' % st)
        parser.exit()
def base_arguments(top_parser):
    """
    Register the global command line options on the top level parser:
    help/version flags, configuration file override and debug/trace switches.

    :param top_parser: argparse.ArgumentParser to configure
    """
    # allow sub-commands to declare aliases
    top_parser.register('action', 'parsers', AliasedSubParsersAction)
    # -h prints the short summary, --help the detailed per-command help
    top_parser.add_argument('-h', action=_HelpAction, help='Displays summary of all commands')
    top_parser.add_argument(
        '--help',
        action=_DetailedHelpAction,
        help='Detailed help of command line interface')
    top_parser.add_argument(
        '--version',
        action='version',
        version='TRAINS-AGENT version %s' % Session.version,
        help='TRAINS-AGENT version number')
    top_parser.add_argument(
        '--config-file',
        help='Use a different configuration file (default: "{}")'.format(definitions.CONFIG_FILE))
    top_parser.add_argument('--debug', '-d', action='store_true', help='print debug information')
    top_parser.add_argument(
        '--trace', '-t',
        action='store_true',
        help='Trace execution to a temporary file.',
    )
def bound_number_type(minimum=None, maximum=None):
    """
    Create a bounded integer "type" (validator function)
    for use with argparse.ArgumentParser.add_argument.
    At least one of ``minimum`` and ``maximum`` must be passed.

    :param minimum: minimum allowed value (inclusive), or None for no lower bound
    :param maximum: maximum allowed value (inclusive), or None for no upper bound
    :return: callable converting a string argument to a bounds-checked int
    :raises ValueError: if neither bound was provided
    """
    if minimum is maximum is None:
        raise ValueError('either "minimum" or "maximum" must be provided')

    def bound_int(arg):
        num = int(arg)
        if minimum is not None and num < minimum:
            raise argparse.ArgumentTypeError('minimum value is {}'.format(minimum))
        if maximum is not None and num > maximum:
            # fixed: the error message previously reported the minimum here
            raise argparse.ArgumentTypeError('maximum value is {}'.format(maximum))
        return num
    return bound_int
def real_path_type(string):
    """
    argparse "type" validator: expand ``~`` in the argument and verify
    that the resulting path exists.

    :param string: path argument as given on the command line
    :return: expanded Path object
    :raises argparse.ArgumentTypeError: if the path does not exist
    """
    candidate = Path(string).expanduser()
    if candidate.exists():
        return candidate
    raise argparse.ArgumentTypeError('"{}": No such file or directory'.format(candidate))
class ObjectID(object):
    """
    Reference to a backend object: the object name/ID plus, optionally,
    the API service it belongs to (e.g. 'tasks', 'queues').
    """
    def __init__(self, name, service=None):
        self.name = name
        self.service = service

    def __repr__(self):
        # added for debuggability; does not affect existing behavior
        return '{}(name={!r}, service={!r})'.format(type(self).__name__, self.name, self.service)
def foreign_object_id(service):
    """
    Return an argparse "type" callable that wraps a value as an ObjectID
    bound to the given API service (e.g. 'tasks', 'queues').
    """
    return partial(ObjectID, service=service)

View File

@@ -0,0 +1,111 @@
import argparse
from textwrap import dedent
from trains_agent.helper.base import warning, is_windows_platform
from trains_agent.interface.base import foreign_object_id
class DeprecatedFlag(argparse.Action):
    """argparse action that emits a deprecation warning for the flag it handles."""
    def __call__(self, parser, namespace, values, option_string=None):
        # NOTE: the value is intentionally not stored on the namespace
        warning('argument "{}" is deprecated'.format(option_string))
# Command line arguments shared by all worker commands: mapping of
# argument name -> kwargs for ArgumentParser.add_argument (merged into
# each command's 'args' below).
WORKER_ARGS = {
    '-O': {
        'help': 'Compile optimized pyc code (see python documentation). Repeat for more optimization.',
        'action': 'count',
        'default': 0,
        'dest': 'optimization',
    },
    '--git-user': {
        'help': 'git username for repository access',
    },
    '--git-pass': {
        'help': 'git password for repository access',
    },
    '--log-level': {
        'help': 'SDK log level',
        'choices': ['DEBUG', 'INFO', 'WARN', 'WARNING', 'ERROR', 'CRITICAL'],
        # accept lower-case level names on the command line
        'type': lambda x: x.upper(),
        'default': 'INFO',
    },
}
# Arguments for the daemon command: daemon-specific flags plus every
# shared worker argument (merged from WORKER_ARGS).
DAEMON_ARGS = dict({
    '--foreground': {
        'help': 'Pipe full log to stdout/stderr, should not be used if running in background',
        'action': 'store_true',
    },
    '--docker': {
        # fixed: the last two fragments were concatenated without a separator,
        # rendering as "...argumentsset NVIDIA_VISIBLE_DEVICES..."
        'help': 'Run execution task inside a docker (v19.03 and above). Optional args <image> <arguments> or '
                'specify default docker image in agent.default_docker.image / agent.default_docker.arguments. '
                'Set NVIDIA_VISIBLE_DEVICES to limit gpu visibility for docker',
        'nargs': '*',
        'default': False,
    },
    '--queue': {
        'help': 'Queue ID(s)/Name(s) to pull tasks from (\'default\' queue)',
        'nargs': '+',
        'default': tuple(),
        'dest': 'queues',
        # values are wrapped as ObjectID references to the "queues" service
        'type': foreign_object_id('queues'),
    },
}, **WORKER_ARGS)
# Top-level CLI command definitions, consumed by the interface layer's
# add_service(): each entry provides 'help' and optionally 'args'
# (mapping of argument name -> add_argument kwargs).
COMMANDS = {
    'execute': {
        'help': 'Build & Execute a selected experiment',
        'args': dict({
            '--id': {
                'help': 'Task ID to run',
                'required': True,
                'dest': 'task_id',
                'type': foreign_object_id('tasks'),
            },
            '--log-file': {
                'help': 'Output task execution (stdout/stderr) into text file',
            },
            '--disable-monitoring': {
                'help': 'Disable logging & monitoring (stdout is still visible)',
                'action': 'store_true',
            },
            '--full-monitoring': {
                'help': 'Full environment setup log & task logging & monitoring (stdout is still visible)',
                'action': 'store_true',
            },
        }, **WORKER_ARGS),
    },
    'build': {
        'help': 'Build selected experiment environment '
                '(including pip packages, cloned code and git diff)\n'
                'Used mostly for debugging purposes',
        'args': dict({
            '--id': {
                'help': 'Task ID to build',
                'required': True,
                'dest': 'task_id',
                'type': foreign_object_id('tasks'),
            },
            '--target-folder': {
                'help': 'Where to build the task\'s virtual environment and source code',
            },
            '--python-version': {
                'help': 'Virtual environment python version to use',
            },
        }, **WORKER_ARGS),
    },
    'list': {
        'help': 'List all worker machines and status',
    },
    'daemon': {
        'help': 'Start Trains-Agent daemon worker',
        'args': DAEMON_ARGS,
    },
    'config': {
        'help': 'Check daemon configuration and print it',
    },
    'init': {
        'help': 'Trains-Agent configuration wizard',
    }
}

314
trains_agent/session.py Normal file
View File

@@ -0,0 +1,314 @@
from __future__ import print_function, unicode_literals
import json
import logging
import os
import platform
import sys
from copy import deepcopy
from typing import Any, Callable
import attr
from pathlib2 import Path
from pyhocon import ConfigFactory, HOCONConverter, ConfigTree
from trains_agent.backend_api.session import Session as _Session, Request
from trains_agent.backend_api.session.client import APIClient
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR, LOCAL_CONFIG_FILES
from trains_agent.definitions import ENVIRONMENT_CONFIG
from trains_agent.errors import APIError
from trains_agent.helper.base import HOCONEncoder
from trains_agent.helper.process import Argv
from .version import __version__
POETRY = "poetry"
@attr.s
class ConfigValue(object):
    """
    Manages a single config key
    """

    # configuration tree holding the key
    config = attr.ib(type=ConfigTree)
    # dot-separated path of the key inside the tree
    key = attr.ib(type=str)

    def get(self, default=None):
        """
        Get value of key with default
        """
        return self.config.get(self.key, default=default)

    def set(self, value):
        """
        Change the value of key
        """
        self.config.put(self.key, value)

    def modify(self, fn):
        # type: (Callable[[Any], Any]) -> ()
        """
        Change the value of a key using a function

        :param fn: callable receiving the current value and returning the new one
        """
        self.set(fn(self.get()))
def tree(*args):
    """
    Helper function for creating config trees

    :param args: entries forwarded verbatim to the ConfigTree constructor
    :return: new ConfigTree
    """
    return ConfigTree(args)
class Session(_Session):
    """
    TRAINS-AGENT API session.

    Extends the backend session with agent-specific initialization:
    resolving the configuration file, applying environment-variable
    overrides, synchronizing CUDA/NVIDIA visibility variables and
    creating the local cache folders.
    """
    version = __version__

    def __init__(self, *args, **kwargs):
        # make sure we set the environment variable so the parent session opens the correct file
        if kwargs.get('config_file'):
            config_file = Path(os.path.expandvars(kwargs.get('config_file'))).expanduser().absolute().as_posix()
            kwargs['config_file'] = config_file
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file
            if not Path(config_file).is_file():
                raise ValueError("Could not open configuration file: {}".format(config_file))
        super(Session, self).__init__(*args, **kwargs)
        self.log = self.get_logger(__name__)
        self.trace = kwargs.get('trace', False)
        self._config_file = kwargs.get('config_file') or \
            os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR) or LOCAL_CONFIG_FILES[0]
        self.api_client = APIClient(session=self, api_version="2.4")
        # HACK make sure we have python version to execute,
        # if nothing was specific, use the one that runs us
        def_python = ConfigValue(self.config, "agent.default_python")
        if not def_python.get():
            def_python.set("{version.major}.{version.minor}".format(version=sys.version_info))
        # HACK: backwards compatibility
        os.environ['ALG_CONFIG_FILE'] = self._config_file
        os.environ['SM_CONFIG_FILE'] = self._config_file
        # older configurations may only define api.api_server
        if not self.config.get('api.host', None) and self.config.get('api.api_server', None):
            self.config['api']['host'] = self.config.get('api.api_server')
        # initialize nvidia visibility variable, and keep CUDA_VISIBLE_DEVICES
        # and NVIDIA_VISIBLE_DEVICES in sync (whichever was provided)
        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        if os.environ.get('NVIDIA_VISIBLE_DEVICES') and not os.environ.get('CUDA_VISIBLE_DEVICES'):
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
        elif os.environ.get('CUDA_VISIBLE_DEVICES') and not os.environ.get('NVIDIA_VISIBLE_DEVICES'):
            os.environ['NVIDIA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES')
        # override with environment variables
        # cuda_version & cudnn_version are overridden with os.environ here, and normalized in the next section
        for config_key, env_config in ENVIRONMENT_CONFIG.items():
            value = env_config.get()
            if not value:
                continue
            env_key = ConfigValue(self.config, config_key)
            env_key.set(value)
        # initialize cuda versions (best effort: detection failure is not fatal)
        try:
            from trains_agent.helper.package.requirements import RequirementsManager
            agent = self.config['agent']
            agent['cuda_version'], agent['cudnn_version'] = \
                RequirementsManager.get_cuda_version(self.config)
        except Exception:
            pass
        # initialize worker name (default: local hostname)
        worker_name = ConfigValue(self.config, "agent.worker_name")
        if not worker_name.get():
            worker_name.set(platform.node())
        self.create_cache_folders()

    @staticmethod
    def get_logger(name):
        """Return a TrainsAgentLogger wrapping a propagating logging.Logger."""
        logger = logging.getLogger(name)
        logger.propagate = True
        return TrainsAgentLogger(logger)

    @property
    def debug_mode(self):
        # True when "agent.debug" is enabled in the configuration
        return self.config.get("agent.debug", False)

    @property
    def config_file(self):
        # path of the configuration file this session was loaded from
        return self._config_file

    def create_cache_folders(self, slot_index=0):
        """
        create and update the cache folders
        notice we support multiple instances sharing the same cache on some folders
        and on some we use "instance slot" numbers in order to differentiate between the different instances running
        notice slot_index=0 is the default, meaning no suffix is added to the singleton_folders
        Note: do not call this function twice with non zero slot_index
        it will add a suffix to the folders on each call

        :param slot_index: integer
        """
        # create target folders:
        folder_keys = ('agent.venvs_dir', 'agent.vcs_cache.path',
                       'agent.pip_download_cache.path',
                       'agent.docker_pip_cache', 'agent.docker_apt_cache')
        singleton_folders = ('agent.venvs_dir', 'agent.vcs_cache.path',)
        for key in folder_keys:
            folder_key = ConfigValue(self.config, key)
            if not folder_key.get():
                continue
            if slot_index and key in singleton_folders:
                # suffix singleton folders with the slot number so multiple
                # agent instances do not share them
                f = folder_key.get()
                if f.endswith(os.path.sep):
                    f = f[:-1]
                folder_key.set(f + '.{}'.format(slot_index))
            # update the configuration for full path
            folder = Path(os.path.expandvars(folder_key.get())).expanduser().absolute()
            folder_key.set(folder.as_posix())
            try:
                folder.mkdir(parents=True, exist_ok=True)
            except Exception:
                # fixed: was a bare "except:" which also swallowed SystemExit /
                # KeyboardInterrupt; folder creation remains best-effort
                pass

    def print_configuration(self, remove_secret_keys=("secret", "pass", "token", "account_key")):
        """
        Print the effective configuration, hiding secret values.

        :param remove_secret_keys: any config key containing one of these
            substrings is removed (recursively) before printing
        """
        # remove all the secrets from the print
        def recursive_remove_secrets(dictionary, secret_keys=()):
            for k in list(dictionary):
                for s in secret_keys:
                    if s in k:
                        dictionary.pop(k)
                        break
                if isinstance(dictionary.get(k, None), dict):
                    recursive_remove_secrets(dictionary[k], secret_keys=secret_keys)
                elif isinstance(dictionary.get(k, None), (list, tuple)):
                    for item in dictionary[k]:
                        if isinstance(item, dict):
                            recursive_remove_secrets(item, secret_keys=secret_keys)

        config = deepcopy(self.config.to_dict())
        # remove the env variable, it's not important
        config.pop('env', None)
        if remove_secret_keys:
            recursive_remove_secrets(config, secret_keys=remove_secret_keys)
        config = ConfigFactory.from_dict(config)
        self.log.debug("Run by interpreter: %s", sys.executable)
        print(
            "Current configuration (trains_agent v{}, location: {}):\n"
            "----------------------\n{}\n".format(
                self.version, self._config_file, HOCONConverter.convert(config, "properties")
            )
        )

    def send_api(self, request):
        # type: (Request) -> Any
        """Send an API request, returning its response or raising APIError on failure."""
        result = self.send(request)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result.response

    def get(self, service, action, version=None, headers=None,
            data=None, json=None, async_enable=False, **kwargs):
        """Manual GET request; extra kwargs become the json payload when json is None."""
        return self._manual_request(service=service, action=action,
                                    version=version, method="get", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def post(self, service, action, version=None, headers=None,
             data=None, json=None, async_enable=False, **kwargs):
        """Manual POST request; extra kwargs become the json payload when json is None."""
        return self._manual_request(service=service, action=action,
                                    version=version, method="post", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def _manual_request(self, service, action, version=None, method="get", headers=None,
                        data=None, json=None, async_enable=False, **kwargs):
        """Send a raw request and return the "data" payload, raising APIError on any failure."""
        res = self.send_request(service=service, action=action,
                                version=version, method=method, headers=headers,
                                data=data, async_enable=async_enable,
                                json=json or kwargs)
        try:
            res_json = res.json()
            return_code = res_json["meta"]["result_code"]
        except (ValueError, KeyError, TypeError):
            raise APIError(res)
        # check return code
        if return_code != 200:
            raise APIError(res)
        return res_json["data"]

    def to_json(self):
        """Serialize the configuration to a JSON string."""
        return json.dumps(
            self.config.as_plain_ordered_dict(), cls=HOCONEncoder, indent=4
        )

    def command(self, *args):
        """Create an Argv command wrapper that logs through this session."""
        return Argv(*args, log=self.get_logger(Argv.__module__))
@attr.s
class TrainsAgentLogger(object):
    """
    Proxy around logging.Logger because inheriting from it is difficult.
    """

    # wrapped standard-library logger; unknown attributes are delegated to it
    logger = attr.ib(type=logging.Logger)

    def _log_with_error(self, level, *args, **kwargs):
        """
        Include error information when in debug mode
        """
        # attach exc_info automatically when the logger is enabled for DEBUG
        kwargs.setdefault("exc_info", self.logger.isEnabledFor(logging.DEBUG))
        return self.logger.log(level, *args, **kwargs)

    def warning(self, *args, **kwargs):
        # warning that automatically includes exc_info in debug mode
        return self._log_with_error(logging.WARNING, *args, **kwargs)

    def error(self, *args, **kwargs):
        # error that automatically includes exc_info in debug mode
        return self._log_with_error(logging.ERROR, *args, **kwargs)

    def __getattr__(self, item):
        # delegate any other logger API (info, debug, etc.) to the wrapped logger
        return getattr(self.logger, item)

    def __call__(self, *args, **kwargs):
        """
        Compatibility with old ``Command.log()`` method
        """
        return self.logger.info(*args, **kwargs)
def normalize_cuda_version(value):
    # type: (Any) -> str
    """
    Take variably formatted cuda version string/number and return it in the same format:
    string decimal representation of 10 * major + minor.

    >>> normalize_cuda_version(100)
    '100'
    >>> normalize_cuda_version("100")
    '100'
    >>> normalize_cuda_version(10)
    '10'
    >>> normalize_cuda_version(10.0)
    '100'
    >>> normalize_cuda_version("10.0")
    '100'
    >>> normalize_cuda_version("10.0.130")
    '100'
    """
    text = str(value)
    if "." not in text:
        return text
    # keep only "major.minor" and convert to the 10*major+minor form;
    # anything non-numeric is returned unchanged
    major_minor = ".".join(text.split(".")[:2])
    try:
        return str(int(float(major_minor) * 10))
    except (ValueError, TypeError):
        return text

Some files were not shown because too many files have changed in this diff Show More