mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Initial release
This commit is contained in:
13
.gitignore
vendored
Normal file
13
.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
# Mac
|
||||
.DS_Store
|
||||
|
||||
# IntelliJ
|
||||
.idea
|
||||
|
||||
# Python
|
||||
*.pyc
|
||||
__pycache__
|
||||
build/
|
||||
dist/
|
||||
*.egg-info
|
||||
|
||||
235
docs/trains.conf
Normal file
235
docs/trains.conf
Normal file
@@ -0,0 +1,235 @@
|
||||
# TRAINS-AGENT configuration file
|
||||
api {
|
||||
api_server: https://demoapi.trainsai.io
|
||||
web_server: https://demoapp.trainsai.io
|
||||
files_server: https://demofiles.trainsai.io
|
||||
|
||||
# Credentials are generated in the webapp, https://demoapp.trainsai.io/profile
|
||||
# Overridden with os environment: TRAINS_API_ACCESS_KEY / TRAINS_API_SECRET_KEY
|
||||
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
|
||||
|
||||
# verify host ssl certificate, set to False only if you have a very good reason
|
||||
verify_certificate: True
|
||||
}
|
||||
|
||||
agent {
|
||||
# Set GIT user/pass credentials
|
||||
# leave blank for GIT SSH credentials
|
||||
git_user=""
|
||||
git_pass=""
|
||||
|
||||
|
||||
# unique name of this worker, if None, created based on hostname:process_id
|
||||
# Overridden with os environment: TRAINS_WORKER_NAME
|
||||
# worker_id: "trains-agent-machine1:gpu0"
|
||||
worker_id: ""
|
||||
|
||||
# worker name, replaces the hostname when creating a unique name for this worker
|
||||
# Overridden with os environment: TRAINS_WORKER_ID
|
||||
# worker_name: "trains-agent-machine1"
|
||||
worker_name: ""
|
||||
|
||||
# select python package manager:
|
||||
# currently supported pip and conda
|
||||
# poetry is used if pip selected and repository contains poetry.lock file
|
||||
package_manager: {
|
||||
# supported options: pip, conda
|
||||
type: pip,
|
||||
# virtual environment inheres packages from system
|
||||
system_site_packages: false,
|
||||
# install with --upgrade
|
||||
force_upgrade: false,
|
||||
|
||||
# additional artifact repositories to use when installing python packages
|
||||
# extra_index_url: ["https://allegroai.jfrog.io/trainsai/api/pypi/public/simple"]
|
||||
extra_index_url: []
|
||||
|
||||
# additional conda channels to use when installing with conda package manager
|
||||
conda_channels: ["pytorch", "conda-forge", ]
|
||||
},
|
||||
|
||||
# target folder for virtual environments builds, created when executing experiment
|
||||
venvs_dir = ~/.trains/venvs-builds
|
||||
|
||||
# cached git clone folder
|
||||
vcs_cache: {
|
||||
enabled: true,
|
||||
path: ~/.trains/vcs-cache
|
||||
},
|
||||
|
||||
# use venv-update in order to accelerate python virtual environment building
|
||||
# Still in beta, turned off by default
|
||||
venv_update: {
|
||||
enabled: false,
|
||||
},
|
||||
|
||||
# cached folder for specific python package download (mostly pytorch versions)
|
||||
pip_download_cache {
|
||||
enabled: true,
|
||||
path: ~/.trains/pip-download-cache
|
||||
},
|
||||
|
||||
translate_ssh: true,
|
||||
# reload configuration file every daemon execution
|
||||
reload_config: false,
|
||||
|
||||
# pip cache folder used mapped into docker, for python package caching
|
||||
docker_pip_cache = ~/.trains/pip-cache
|
||||
# apt cache folder used mapped into docker, for ubuntu package caching
|
||||
docker_apt_cache = ~/.trains/apt-cache
|
||||
|
||||
default_docker: {
|
||||
# default docker image to use when running in docker mode
|
||||
image: "nvidia/cuda"
|
||||
|
||||
# optional arguments to pass to docker image
|
||||
# arguments: ["--ipc=host"]
|
||||
}
|
||||
}
|
||||
|
||||
sdk {
|
||||
# TRAINS - default SDK configuration
|
||||
|
||||
storage {
|
||||
cache {
|
||||
# Defaults to system temp folder / cache
|
||||
default_base_dir: "~/.trains/cache"
|
||||
}
|
||||
|
||||
direct_access: [
|
||||
# Objects matching are considered to be available for direct access, i.e. they will not be downloaded
|
||||
# or cached, and any download request will return a direct reference.
|
||||
# Objects are specified in glob format, available for url and content_type.
|
||||
{ url: "file://*" } # file-urls are always directly referenced
|
||||
]
|
||||
}
|
||||
|
||||
metrics {
|
||||
# History size for debug files per metric/variant. For each metric/variant combination with an attached file
|
||||
# (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
|
||||
# X files are stored in the upload destination for each metric/variant combination.
|
||||
file_history_size: 100
|
||||
|
||||
# Settings for generated debug images
|
||||
images {
|
||||
format: JPEG
|
||||
quality: 87
|
||||
subsampling: 0
|
||||
}
|
||||
}
|
||||
|
||||
network {
|
||||
metrics {
|
||||
# Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
|
||||
# a specific iteration
|
||||
file_upload_threads: 4
|
||||
|
||||
# Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
|
||||
# being sent for upload
|
||||
file_upload_starvation_warning_sec: 120
|
||||
}
|
||||
|
||||
iteration {
|
||||
# Max number of retries when getting frames if the server returned an error (http code 500)
|
||||
max_retries_on_server_error: 5
|
||||
# Backoff factory for consecutive retry attempts.
|
||||
# SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
|
||||
retry_backoff_factor_sec: 10
|
||||
}
|
||||
}
|
||||
aws {
|
||||
s3 {
|
||||
# S3 credentials, used for read/write access by various SDK elements
|
||||
|
||||
# default, used for any bucket not specified below
|
||||
key: ""
|
||||
secret: ""
|
||||
region: ""
|
||||
|
||||
credentials: [
|
||||
# specifies key/secret credentials to use when handling s3 urls (read or write)
|
||||
# {
|
||||
# bucket: "my-bucket-name"
|
||||
# key: "my-access-key"
|
||||
# secret: "my-secret-key"
|
||||
# },
|
||||
# {
|
||||
# # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
|
||||
# host: "my-minio-host:9000"
|
||||
# key: "12345678"
|
||||
# secret: "12345678"
|
||||
# multipart: false
|
||||
# secure: false
|
||||
# }
|
||||
]
|
||||
}
|
||||
boto3 {
|
||||
pool_connections: 512
|
||||
max_multipart_concurrency: 16
|
||||
}
|
||||
}
|
||||
google.storage {
|
||||
# # Default project and credentials file
|
||||
# # Will be used when no bucket configuration is found
|
||||
# project: "trains"
|
||||
# credentials_json: "/path/to/credentials.json"
|
||||
|
||||
# # Specific credentials per bucket and sub directory
|
||||
# credentials = [
|
||||
# {
|
||||
# bucket: "my-bucket"
|
||||
# subdir: "path/in/bucket" # Not required
|
||||
# project: "trains"
|
||||
# credentials_json: "/path/to/credentials.json"
|
||||
# },
|
||||
# ]
|
||||
}
|
||||
azure.storage {
|
||||
# containers: [
|
||||
# {
|
||||
# account_name: "trains"
|
||||
# account_key: "secret"
|
||||
# # container_name:
|
||||
# }
|
||||
# ]
|
||||
}
|
||||
|
||||
log {
|
||||
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
|
||||
null_log_propagate: False
|
||||
task_log_buffer_capacity: 66
|
||||
|
||||
# disable urllib info and lower levels
|
||||
disable_urllib3_info: True
|
||||
}
|
||||
|
||||
development {
|
||||
# Development-mode options
|
||||
|
||||
# dev task reuse window
|
||||
task_reuse_time_window_in_hours: 72.0
|
||||
|
||||
# Run VCS repository detection asynchronously
|
||||
vcs_repo_detect_async: True
|
||||
|
||||
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
|
||||
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
|
||||
store_uncommitted_code_diff_on_train: True
|
||||
|
||||
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
|
||||
support_stopping: True
|
||||
|
||||
# Development mode worker
|
||||
worker {
|
||||
# Status report period in seconds
|
||||
report_period_sec: 2
|
||||
|
||||
# ping to the server - check connectivity
|
||||
ping_period_sec: 30
|
||||
|
||||
# Log all stdout & stderr
|
||||
log_stdout: True
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
6
main.py
Normal file
6
main.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sys
|
||||
from trains_agent import __main__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(__main__.main())
|
||||
23
requirements.txt
Normal file
23
requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
attrs>=18.0
|
||||
enum34>=0.9 ; python_version < '3.6'
|
||||
furl>=2.0.0
|
||||
future>=0.16.0
|
||||
humanfriendly>=2.1
|
||||
jsonmodels>=2.2
|
||||
jsonschema>=2.6.0
|
||||
pathlib2>=2.3.0
|
||||
psutil>=3.4.2
|
||||
pyhocon>=0.3.38
|
||||
pyparsing>=2.0.3
|
||||
python-dateutil>=2.4.2
|
||||
pyjwt>=1.6.4
|
||||
PyYAML>=3.12
|
||||
requests-file>=1.4.2
|
||||
requests>=2.20.0
|
||||
requirements_parser>=0.2.0
|
||||
semantic_version>=2.6.0
|
||||
six>=1.11.0
|
||||
tqdm>=4.19.5
|
||||
typing>=3.6.4
|
||||
urllib3>=1.21.1
|
||||
virtualenv>=16
|
||||
76
setup.py
Normal file
76
setup.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
TRAINS - Artificial Intelligence Version Control
|
||||
TRAINS-AGENT DevOps for machine/deep learning
|
||||
https://github.com/allegroai/trains-agent
|
||||
"""
|
||||
|
||||
# Always prefer setuptools over distutils
|
||||
from setuptools import setup, find_packages
|
||||
from six import exec_
|
||||
from pathlib2 import Path
|
||||
|
||||
|
||||
here = Path(__file__).resolve().parent
|
||||
|
||||
# Get the long description from the README file
|
||||
long_description = (here / 'README.md').read_text()
|
||||
|
||||
|
||||
def read_version_string():
|
||||
result = {}
|
||||
exec_((here / 'trains_agent/version.py').read_text(), result)
|
||||
return result['__version__']
|
||||
|
||||
|
||||
version = read_version_string()
|
||||
|
||||
requirements = (here / 'requirements.txt').read_text().splitlines()
|
||||
|
||||
|
||||
setup(
|
||||
name='trains_agent',
|
||||
version=version,
|
||||
description='Trains-Agent DevOps for deep learning (DevOps for TRAINS)',
|
||||
long_description=long_description,
|
||||
# The project's main homepage.
|
||||
url='https://github.com/allegroai/trains-agent',
|
||||
author='Allegroai',
|
||||
author_email='trains@allegro.ai',
|
||||
license='Apache License 2.0',
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: System Administrators',
|
||||
'Intended Audience :: Science/Research',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
'Operating System :: MacOS :: MacOS X',
|
||||
'Operating System :: Microsoft',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Scientific/Engineering :: Image Recognition',
|
||||
'Topic :: System :: Logging',
|
||||
'Topic :: System :: Monitoring',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
],
|
||||
|
||||
keywords='trains devops machine deep learning agent automation hpc cluster',
|
||||
|
||||
packages=find_packages(exclude=['contrib', 'docs', 'data', 'examples', 'tests*']),
|
||||
|
||||
install_requires=requirements,
|
||||
extras_require={
|
||||
},
|
||||
package_data={
|
||||
'trains_agent': ['backend_api/config/default/*.conf']
|
||||
},
|
||||
include_package_data=True,
|
||||
# To provide executable scripts, use entry points in preference to the
|
||||
# "scripts" keyword. Entry points provide cross-platform support and allow
|
||||
# pip to create the appropriate form of executable for the target platform.
|
||||
entry_points={
|
||||
'console_scripts': ['trains-agent=trains_agent.__main__:main'],
|
||||
},
|
||||
)
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
51
tests/conftest.py
Normal file
51
tests/conftest.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from argparse import Namespace
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pathlib2 import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def run_trains_agent(script_runner):
|
||||
""" Execute trains_agent agent app in subprocess and return stdout as a string.
|
||||
Args:
|
||||
script_runner (object): a pytest plugin for testing python scripts
|
||||
installed via console_scripts entry point of setup.py.
|
||||
It can run the scripts under test in a separate process or using the interpreter that's running
|
||||
the test suite. The former mode ensures that the script will run in an environment that
|
||||
is identical to normal execution whereas the latter one allows much quicker test runs during development
|
||||
while simulating the real runs as muh as possible.
|
||||
For more details: https://pypi.python.org/pypi/pytest-console-scripts
|
||||
Returns:
|
||||
string: The return value. stdout output
|
||||
"""
|
||||
def _method(*args):
|
||||
trains_agent_file = str(PROJECT_ROOT / "trains_agent.sh")
|
||||
ret = script_runner.run(trains_agent_file, *args)
|
||||
return ret
|
||||
return _method
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def trains_agentyaml(tmpdir):
|
||||
@contextmanager
|
||||
def _method(template_file):
|
||||
file = tmpdir.join("trains_agent.yaml")
|
||||
with (PROJECT_ROOT / "tests/templates" / template_file).open() as f:
|
||||
code = yaml.load(f)
|
||||
yield Namespace(code=code, file=file.strpath)
|
||||
file.write(yaml.dump(code))
|
||||
return _method
|
||||
|
||||
|
||||
# class Test(object):
|
||||
# def yaml_file(self, tmpdir, template_file):
|
||||
# file = tmpdir.join("trains_agent.yaml")
|
||||
# with open(template_file) as f:
|
||||
# test_object = yaml.load(f)
|
||||
# self.let(test_object)
|
||||
# file.write(yaml.dump(test_object))
|
||||
# return file.strpath
|
||||
0
tests/package/__init__.py
Normal file
0
tests/package/__init__.py
Normal file
35
tests/package/ssh_conversion.py
Normal file
35
tests/package/ssh_conversion.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from trains_agent.helper.repo import VCS
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["url", "expected"],
|
||||
(
|
||||
("a", None),
|
||||
("foo://a/b", None),
|
||||
("foo://a/b/", None),
|
||||
("https://a/b/", None),
|
||||
("https://example.com/a/b", None),
|
||||
("https://example.com/a/b/", None),
|
||||
("ftp://example.com/a/b", None),
|
||||
("ftp://example.com/a/b/", None),
|
||||
("github.com:foo/bar.git", "https://github.com/foo/bar.git"),
|
||||
("git@github.com:foo/bar.git", "https://github.com/foo/bar.git"),
|
||||
("bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("hg@bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("ssh://bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("ssh://git@github.com/foo/bar.git", "https://github.com/foo/bar.git"),
|
||||
("ssh://user@github.com/foo/bar.git", "https://user@github.com/foo/bar.git"),
|
||||
("ssh://git:password@github.com/foo/bar.git", "https://git:password@github.com/foo/bar.git"),
|
||||
("ssh://user:password@github.com/foo/bar.git", "https://user:password@github.com/foo/bar.git"),
|
||||
("ssh://hg@bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("ssh://user@bitbucket.org/foo/bar.git", "https://user@bitbucket.org/foo/bar.git"),
|
||||
("ssh://hg:password@bitbucket.org/foo/bar.git", "https://hg:password@bitbucket.org/foo/bar.git"),
|
||||
("ssh://user:password@bitbucket.org/foo/bar.git", "https://user:password@bitbucket.org/foo/bar.git"),
|
||||
),
|
||||
)
|
||||
def test(url, expected):
|
||||
result = VCS.resolve_ssh_url(url)
|
||||
expected = expected or url
|
||||
assert result == expected
|
||||
76
tests/package/test_pip_download_cache.py
Normal file
76
tests/package/test_pip_download_cache.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
from furl import furl
|
||||
|
||||
from trains_agent.helper.package.translator import RequirementsTranslator
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"line",
|
||||
(
|
||||
furl()
|
||||
.set(
|
||||
scheme=scheme,
|
||||
host=host,
|
||||
path=path,
|
||||
query=query,
|
||||
fragment=fragment,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
.url
|
||||
for scheme in ("http", "https", "ftp")
|
||||
for host in ("a", "example.com")
|
||||
for path in (None, "/", "a", "a/", "a/b", "a/b/", "a b", "a b ")
|
||||
for query in (None, "foo", "foo=3", "foo=3&bar")
|
||||
for fragment in (None, "foo")
|
||||
for port in (None, 1337)
|
||||
for username in (None, "", "user")
|
||||
for password in (None, "", "password")
|
||||
),
|
||||
)
|
||||
def test_supported(line):
|
||||
assert "://" in line
|
||||
assert RequirementsTranslator.is_supported_link(line)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"line",
|
||||
[
|
||||
"pytorch",
|
||||
"foo",
|
||||
"foo1",
|
||||
"bar",
|
||||
"bar1",
|
||||
"foo-bar",
|
||||
"foo-bar1",
|
||||
"foo-bar-1",
|
||||
"foo_bar",
|
||||
"foo_bar1",
|
||||
"foo_bar_1",
|
||||
" https://a",
|
||||
" https://a/b",
|
||||
" http://a",
|
||||
" http://a/b",
|
||||
" ftp://a/b",
|
||||
"file://a/b",
|
||||
"ssh://a/b",
|
||||
"foo://a/b",
|
||||
"git//a/b",
|
||||
"git+https://a/b",
|
||||
"https+git://a/b",
|
||||
"git+http://a/b",
|
||||
"http+git://a/b",
|
||||
"",
|
||||
" ",
|
||||
"-e ",
|
||||
"-e x",
|
||||
"-e http://a",
|
||||
"-e http://a/b",
|
||||
"-e https://a",
|
||||
"-e https://a/b",
|
||||
"-e file://a/b",
|
||||
],
|
||||
)
|
||||
def test_not_supported(line):
|
||||
assert not RequirementsTranslator.is_supported_link(line)
|
||||
33
tests/package/test_pytorch_map.py
Normal file
33
tests/package/test_pytorch_map.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import attr
|
||||
import pytest
|
||||
import requests
|
||||
from furl import furl
|
||||
|
||||
import six
|
||||
from trains_agent.helper.package.pytorch import PytorchRequirement
|
||||
|
||||
|
||||
@attr.s
|
||||
class PytorchURLWheel(object):
|
||||
os = attr.ib()
|
||||
cuda = attr.ib()
|
||||
python = attr.ib()
|
||||
pytorch = attr.ib()
|
||||
url = attr.ib()
|
||||
|
||||
|
||||
wheels = [
|
||||
PytorchURLWheel(os=os, cuda=cuda, python=python, pytorch=pytorch_version, url=url)
|
||||
for os, os_d in PytorchRequirement.MAP.items()
|
||||
for cuda, cuda_d in os_d.items()
|
||||
if isinstance(cuda_d, dict)
|
||||
for python, python_d in cuda_d.items()
|
||||
if isinstance(python_d, dict)
|
||||
for pytorch_version, url in python_d.items()
|
||||
if isinstance(url, six.string_types) and furl(url).scheme
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wheel', wheels, ids=[','.join(map(str, attr.astuple(wheel))) for wheel in wheels])
|
||||
def test_map(wheel):
|
||||
assert requests.head(wheel.url).ok
|
||||
75
tests/package/test_repo_url_auth.py
Normal file
75
tests/package/test_repo_url_auth.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
|
||||
from trains_agent.helper.repo import Git
|
||||
|
||||
NO_CHANGE = object()
|
||||
|
||||
|
||||
def param(url, expected, user=False, password=False):
|
||||
"""
|
||||
Helper function for creating parametrization arguments.
|
||||
:param url: input url
|
||||
:param expected: expected output URL or NO_CHANGE if the same as input URL
|
||||
:param user: Add `agent.git_user=USER` to config
|
||||
:param password: Add `agent.git_password=PASSWORD` to config
|
||||
"""
|
||||
expected_repr = "NO_CHANGE" if expected is NO_CHANGE else None
|
||||
user = "USER" if user else None
|
||||
password = "PASSWORD" if password else None
|
||||
return pytest.param(
|
||||
url,
|
||||
expected,
|
||||
user,
|
||||
password,
|
||||
id="-".join(filter(None, (url, user, password, expected_repr))),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,expected,user,password",
|
||||
[
|
||||
param("https://bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("https://bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param("https://bitbucket.org/company/repo", NO_CHANGE, password=True),
|
||||
param(
|
||||
"https://bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
|
||||
),
|
||||
param("https://user@bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("https://user@bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param("https://user@bitbucket.org/company/repo", NO_CHANGE, password=True),
|
||||
param(
|
||||
"https://user@bitbucket.org/company/repo",
|
||||
"https://USER:PASSWORD@bitbucket.org/company/repo",
|
||||
user=True,
|
||||
password=True,
|
||||
),
|
||||
param("https://user:password@bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("https://user:password@bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param(
|
||||
"https://user:password@bitbucket.org/company/repo", NO_CHANGE, password=True
|
||||
),
|
||||
param(
|
||||
"https://user:password@bitbucket.org/company/repo",
|
||||
NO_CHANGE,
|
||||
user=True,
|
||||
password=True,
|
||||
),
|
||||
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, password=True),
|
||||
param(
|
||||
"ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
|
||||
),
|
||||
param("git@bitbucket.org:company/repo.git", NO_CHANGE),
|
||||
param("git@bitbucket.org:company/repo.git", NO_CHANGE, user=True),
|
||||
param("git@bitbucket.org:company/repo.git", NO_CHANGE, password=True),
|
||||
param(
|
||||
"git@bitbucket.org:company/repo.git", NO_CHANGE, user=True, password=True
|
||||
),
|
||||
],
|
||||
)
|
||||
def test(url, user, password, expected):
|
||||
config = {"agent": {"git_user": user, "git_pass": password}}
|
||||
result = Git.add_auth(config, url)
|
||||
expected = result if expected is NO_CHANGE else expected
|
||||
assert result == expected
|
||||
254
tests/package/test_task_script.py
Normal file
254
tests/package/test_task_script.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Test handling of jupyter notebook tasks.
|
||||
Logging is enabled in `trains_agent/tests/pytest.ini`. Search for `pytest live logging` for more info.
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
import select
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, ContextManager, Sequence, IO, Text
|
||||
from uuid import uuid4
|
||||
|
||||
from trains_agent.backend_api.services.tasks import Script
|
||||
from trains_agent.backend_api.session.client import APIClient
|
||||
from pathlib2 import Path
|
||||
from pytest import fixture
|
||||
|
||||
from trains_agent.helper.process import Argv
|
||||
|
||||
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_TASK_ARGS = {"type": "testing", "name": "test", "input": {"view": {}}}
|
||||
HERE = Path(__file__).resolve().parent
|
||||
SHORT_TIMEOUT = 30
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def client():
|
||||
return APIClient(api_version="2.2")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def create_task(client, **params):
|
||||
"""
|
||||
Create task in backend
|
||||
"""
|
||||
log.info("creating new task")
|
||||
task = client.tasks.create(**params)
|
||||
try:
|
||||
yield task
|
||||
finally:
|
||||
log.info("deleting task, id=%s", task.id)
|
||||
task.delete(force=True)
|
||||
|
||||
|
||||
def select_read(file_obj, timeout):
|
||||
return select.select([file_obj], [], [], timeout)[0]
|
||||
|
||||
|
||||
def run_task(task):
|
||||
return Argv("trains_agent", "--debug", "worker", "execute", "--id", task.id)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def iterate_output(timeout, command):
|
||||
# type: (int, Argv) -> ContextManager[Iterator[Text]]
|
||||
"""
|
||||
Run `command` in a subprocess and return a contextmanager of an iterator
|
||||
over its output's lines. If `timeout` seconds have passed, iterator ends and
|
||||
the process is killed.
|
||||
:param timeout: maximum amount of time to wait for command to end
|
||||
:param command: command to run
|
||||
"""
|
||||
log.info("running: %s", command)
|
||||
process = command.call_subprocess(
|
||||
subprocess.Popen, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||
)
|
||||
try:
|
||||
yield _iterate_output(timeout, process)
|
||||
finally:
|
||||
status = process.poll()
|
||||
if status is not None:
|
||||
log.info("command %s terminated, status: %s", command, status)
|
||||
else:
|
||||
log.info("killing command %s", command)
|
||||
process.kill()
|
||||
|
||||
|
||||
def _iterate_output(timeout, process):
|
||||
# type: (int, subprocess.Popen) -> Iterator[Text]
|
||||
"""
|
||||
Return an iterator over process's output lines.
|
||||
over its output's lines.
|
||||
If `timeout` seconds have passed, iterator ends and the process is killed.
|
||||
:param timeout: maximum amount of time to wait for command to end
|
||||
:param process: process to iterate over its output lines
|
||||
"""
|
||||
start = time.time()
|
||||
exit_loop = [] # type: Sequence[IO]
|
||||
|
||||
def loop_helper(file_obj):
|
||||
# type: (IO) -> Sequence[IO]
|
||||
diff = timeout - (time.time() - start)
|
||||
if diff <= 0:
|
||||
return exit_loop
|
||||
return select_read(file_obj, timeout=diff)
|
||||
|
||||
buffer = ""
|
||||
|
||||
for output, in iter(lambda: loop_helper(process.stdout), exit_loop):
|
||||
try:
|
||||
added = output.read(1024).decode("utf8")
|
||||
except EOFError:
|
||||
if buffer:
|
||||
yield buffer
|
||||
return
|
||||
buffer += added
|
||||
lines = buffer.split("\n", 1)
|
||||
|
||||
while len(lines) > 1:
|
||||
line, buffer = lines
|
||||
log.debug("--- %s", line)
|
||||
yield line
|
||||
lines = buffer.split("\n", 1)
|
||||
|
||||
|
||||
def search_lines(lines, search_for, error):
|
||||
# type: (Iterator[Text], Text, Text) -> None
|
||||
"""
|
||||
Fail test if `search_for` string appears nowhere in `lines`.
|
||||
Consumes `lines` up to point where `search_for` is found.
|
||||
:param lines: lines to search in
|
||||
:param search_for: string to search lines for
|
||||
:param error: error to show if not found
|
||||
"""
|
||||
for line in lines:
|
||||
if search_for in line:
|
||||
break
|
||||
else:
|
||||
assert False, error
|
||||
|
||||
|
||||
def search_lines_pattern(lines, pattern, error):
|
||||
# type: (Iterator[Text], Text, Text) -> None
|
||||
"""
|
||||
Like `search_lines` but searches for a pattern.
|
||||
:param lines: lines to search in
|
||||
:param pattern: pattern to search lines for
|
||||
:param error: error to show if not found
|
||||
"""
|
||||
for line in lines:
|
||||
if re.search(pattern, line):
|
||||
break
|
||||
else:
|
||||
assert False, error
|
||||
|
||||
|
||||
def test_entry_point_warning(client):
|
||||
"""
|
||||
non-empty script.entry_point should output a warning
|
||||
"""
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(diff="print('hello')", entry_point="foo.py", repository=""),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
for line in output:
|
||||
if "found non-empty script.entry_point" in line:
|
||||
break
|
||||
else:
|
||||
assert False, "did not find warning in output"
|
||||
|
||||
|
||||
def test_run_no_dirs(client):
|
||||
"""
|
||||
The arbitrary `code` directory should be selected when there is no `script.repository`
|
||||
"""
|
||||
uuid = uuid4().hex
|
||||
script = "print('{}')".format(uuid)
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(diff=script, entry_point="", repository="", working_dir=""),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
search_lines(
|
||||
output,
|
||||
"found literal script",
|
||||
"task was not recognized as a literal script",
|
||||
)
|
||||
search_lines_pattern(
|
||||
output,
|
||||
r"selected execution directory:.*code",
|
||||
r"did not selected empty `code` dir as execution dir",
|
||||
)
|
||||
search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))
|
||||
|
||||
|
||||
def test_run_working_dir(client):
|
||||
"""
|
||||
Literal script tasks should respect `working_dir`
|
||||
"""
|
||||
uuid = uuid4().hex
|
||||
script = "print('{}')".format(uuid)
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(
|
||||
diff=script,
|
||||
entry_point="",
|
||||
repository="git@bitbucket.org:seematics/roee_test_git.git",
|
||||
working_dir="space dir",
|
||||
),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(120, run_task(task)) as output:
|
||||
search_lines(
|
||||
output,
|
||||
"found literal script",
|
||||
"task was not recognized as a literal script",
|
||||
)
|
||||
search_lines_pattern(
|
||||
output,
|
||||
r"selected execution directory:.*space dir",
|
||||
r"did not selected working_dir as set in execution_info",
|
||||
)
|
||||
search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))
|
||||
|
||||
|
||||
def test_regular_task(client):
|
||||
"""
|
||||
Test a plain old task
|
||||
"""
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(
|
||||
entry_point="noop.py",
|
||||
repository="git@bitbucket.org:seematics/roee_test_git.git",
|
||||
),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
message = "Done"
|
||||
search_lines(
|
||||
output, message, "did not reach dummy output message {}".format(message)
|
||||
)
|
||||
|
||||
|
||||
def test_regular_task_nested(client):
|
||||
"""
|
||||
`entry_point` should be relative to `working_dir` if present
|
||||
"""
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(
|
||||
entry_point="noop_nested.py",
|
||||
working_dir="no_reqs",
|
||||
repository="git@bitbucket.org:seematics/roee_test_git.git",
|
||||
),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
message = "Done"
|
||||
search_lines(
|
||||
output, message, "did not reach dummy output message {}".format(message)
|
||||
)
|
||||
11
tests/pytest.ini
Normal file
11
tests/pytest.ini
Normal file
@@ -0,0 +1,11 @@
|
||||
[pytest]
|
||||
python_files=*.py
|
||||
script_launch_mode = subprocess
|
||||
env =
|
||||
PYTHONIOENCODING=utf-8
|
||||
log_cli = 1
|
||||
log_cli_level = DEBUG
|
||||
log_cli_format = [%(name)s] [%(asctime)s] [%(levelname)8s] %(message)s
|
||||
log_print = False
|
||||
log_cli_date_format=%Y-%m-%d %H:%M:%S
|
||||
addopts = -p no:warnings
|
||||
2
tests/requirements.txt
Normal file
2
tests/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
pytest
|
||||
-r ../requirements.txt
|
||||
0
tests/requirements/__init__.py
Normal file
0
tests/requirements/__init__.py
Normal file
13
tests/requirements/conftest.py
Normal file
13
tests/requirements/conftest.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import os
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--cpu', action='store_true')
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
if not config.option.cpu:
|
||||
return
|
||||
os.environ['PATH'] = ':'.join(p for p in os.environ['PATH'].split(':') if 'cuda' not in p)
|
||||
os.environ['CUDA_VERSION'] = ''
|
||||
os.environ['CUDNN_VERSION'] = ''
|
||||
350
tests/requirements/requirements_substitution.py
Normal file
350
tests/requirements/requirements_substitution.py
Normal file
@@ -0,0 +1,350 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
from itertools import chain
|
||||
from os import path
|
||||
from os.path import sep
|
||||
from sys import platform as sys_platform
|
||||
|
||||
import pytest
|
||||
import requirements
|
||||
|
||||
from trains_agent.commands.worker import Worker
|
||||
from trains_agent.helper.package.pytorch import PytorchRequirement
|
||||
from trains_agent.helper.package.requirements import RequirementsManager, \
|
||||
RequirementSubstitution, MarkerRequirement
|
||||
from trains_agent.helper.process import get_bash_output
|
||||
from trains_agent.session import Session
|
||||
|
||||
_cuda_based_packages_hack = ('seematics.caffe', 'lightnet')
|
||||
|
||||
|
||||
def old_get_suffix(session):
    """Legacy CUDA/CuDNN version detection, kept for parity-testing the new code.

    Versions come from the agent config when both are set; otherwise they are
    detected locally by parsing ``nvcc --version`` output and the version
    macros in ``cudnn.h`` next to the nvcc binary.

    :param session: agent session supplying ``agent.cuda_version`` /
        ``agent.cudnn_version`` config values
    :return: tuple ``(pkg_suffix, nvcc_ver, cudnn_ver)`` where ``pkg_suffix``
        is ``'.post<cuda>.dev<cudnn>'``
    :raises Exception: any detection/parsing failure propagates to the caller
    """
    cuda_version = session.config['agent.cuda_version']
    cudnn_version = session.config['agent.cudnn_version']
    if cuda_version and cudnn_version:
        nvcc_ver = cuda_version.strip()
        cudnn_ver = cudnn_version.strip()
    else:
        # Detect the CUDA release from `nvcc --version` output
        # (Windows output needs '\r' stripped first).
        if sys_platform == 'win32':
            nvcc_ver = subprocess.check_output('nvcc --version'.split()).decode('utf-8'). \
                replace('\r', '').split('\n')
        else:
            nvcc_ver = subprocess.check_output(
                'nvcc --version'.split()).decode('utf-8').split('\n')
        nvcc_ver = [l for l in nvcc_ver if 'release ' in l]
        nvcc_ver = nvcc_ver[0][nvcc_ver[0].find(
            'release ') + len('release '):][:3]
        # e.g. "9.1" -> "91"
        nvcc_ver = str(int(float(nvcc_ver) * 10))
        # Locate cudnn.h relative to the nvcc binary.
        if sys_platform == 'win32':
            cuda_lib = subprocess.check_output('where nvcc'.split()).decode('utf-8'). \
                replace('\r', '').split('\n')[0]
            cudnn_h = sep.join(cuda_lib.split(
                sep)[:-2] + ['include', 'cudnn.h'])
        else:
            cuda_lib = subprocess.check_output(
                'which nvcc'.split()).decode('utf-8').split('\n')[0]
            cudnn_h = path.join(
                sep, *(cuda_lib.split(sep)[:-2] + ['include', 'cudnn.h']))
        cudnn_mj, cudnn_mi = None, None
        # FIX: close the header file deterministically (it was left open).
        with open(cudnn_h, 'r') as header:
            for header_line in header:
                if 'CUDNN_MAJOR' in header_line:
                    cudnn_mj = header_line.split()[-1]
                if 'CUDNN_MINOR' in header_line:
                    cudnn_mi = header_line.split()[-1]
                if cudnn_mj and cudnn_mi:
                    break
        cudnn_ver = cudnn_mj + ('0' if not cudnn_mi else cudnn_mi)
    # build cuda + cudnn version suffix
    # make sure these are integers, someone else will catch the exception
    pkg_suffix_ver = '.post' + \
        str(int(nvcc_ver)) + '.dev' + str(int(cudnn_ver))
    return pkg_suffix_ver, nvcc_ver, cudnn_ver
|
||||
|
||||
|
||||
def old_replace(session, line):
    """Legacy CUDA-suffix substitution for a single requirements line.

    Reference implementation kept only so the tests can compare the new
    RequirementsManager behavior against it. For packages listed in
    ``_cuda_based_packages_hack`` it appends a ``.post<cuda>.dev<cudnn>``
    suffix: direct wheel URLs get it patched into the filename, pypi entries
    get pinned to ``==<version><suffix>``. On any failure (no CUDA info,
    unparsable version, ...) the line is returned unchanged.

    :param session: agent session (supplies agent.cuda_version / cudnn_version)
    :param line: one requirements.txt line
    :return: possibly-rewritten requirements line
    """
    try:
        cuda_ver_suffix, cuda_ver, cuda_cudnn_ver = old_get_suffix(session)
    except Exception:
        # no usable CUDA/CuDNN info - leave the line untouched
        return line
    if line.lstrip().startswith('#'):
        return line
    for package_name in _cuda_based_packages_hack:
        if package_name not in line:
            continue
        try:
            line_lstrip = line.lstrip()
            if line_lstrip.startswith('http://') or line_lstrip.startswith('https://'):
                pos = line.find(package_name) + len(package_name)
                # patch line with specific version
                line = line[:pos] + \
                    line[pos:].replace('-cp', cuda_ver_suffix + '-cp', 1)
            else:
                # this is a pypi package
                tokens = line.replace('=', ' ').replace('<', ' ').replace('>', ' ').replace(';', ' '). \
                    replace('!', ' ').split()
                if package_name != tokens[0]:
                    # how did we get here, probably a mistake
                    # (FIX: removed dead assignment to an unused local
                    # `found_cuda_based_package` that existed here)
                    continue

                version_number = None
                if len(tokens) > 1:
                    # get the package version info
                    test_version_number = tokens[1]
                    # check if we have a valid version, i.e. does not contain post/dev
                    version_number = '.'.join([v for v in test_version_number.split('.')
                                               if v and '0' <= v[0] <= '9'])
                    if version_number != test_version_number:
                        raise ValueError()

                # we have no version, but we have to have one
                if not version_number:
                    # get the latest version from the extra index list
                    pip_search_cmd = ['pip', 'search']
                    if Worker._pip_extra_index_url:
                        pip_search_cmd.extend(
                            chain.from_iterable(('-i', x) for x in Worker._pip_extra_index_url))
                    pip_search_cmd += [package_name]
                    pip_search_output = get_bash_output(
                        ' '.join(pip_search_cmd), strip=True)
                    version_number = pip_search_output.split(package_name)[1]
                    version_number = version_number.replace(
                        '(', ' ').replace(')', ' ').split()[0]
                    version_number = '.'.join([v for v in version_number.split('.')
                                               if v and '0' <= v[0] <= '9'])
                    if not version_number:
                        # somewhere along the way we failed
                        raise ValueError()

                package_name_version = package_name + '==' + version_number + cuda_ver_suffix
                if version_number in line:
                    # make sure we have the specific version not >=
                    tokens = line.split(';')
                    line = ';'.join([package_name_version] + tokens[1:])
                else:
                    # add version to the package_name
                    line = line.replace(package_name, package_name_version, 1)

        except ValueError:
            # could not determine a valid version for this package;
            # keep the original line
            pass
    return line
|
||||
|
||||
|
||||
win_condition = 'sys_platform != "win_32"'
|
||||
versions = ('1', '1.4', '1.4.9', '1.5.3.dev0',
|
||||
'1.5.3.dev3', '43.1.2.dev0.post1')
|
||||
|
||||
|
||||
def normalize(result):
    """Canonicalize marker separators by collapsing optional spaces around ';'.

    Falsy input (None / empty string) is returned unchanged so comparisons
    against "no result" keep working.
    """
    if not result:
        return result
    return re.sub(' ?; ?', ';', result)
|
||||
|
||||
|
||||
def parse_one(requirement):
    """Parse a single requirement line and wrap it in a MarkerRequirement."""
    parsed = requirements.parse(requirement)
    return MarkerRequirement(next(parsed))
|
||||
|
||||
|
||||
def compare(manager, current_session, pytest_config, requirement, expected):
    """Run `requirement` through both substitution paths and check the results.

    The new path (``manager._replace_one``) is compared against `expected`;
    the legacy path (``old_replace``) is compared against the new one unless
    the requirement is a known divergence listed in FAILURES.
    In --cpu mode the new path must perform no substitution at all.
    Returns the (normalized) new-path result.
    """
    try:
        res1 = normalize(manager._replace_one(parse_one(requirement)))
    except ValueError:
        # the new implementation signals "no substitution" via ValueError
        res1 = None
    res2 = old_replace(current_session, requirement)
    if res2 == requirement:
        # the legacy implementation returns the input unchanged when it does
        # nothing; map that to None so both paths use the same convention
        res2 = None
    res2 = normalize(res2)
    expected = normalize(expected)
    if pytest_config.option.cpu:
        assert res1 is None
    else:
        assert res1 == expected
        # NOTE(review): parity check assumed nested under the non-cpu branch;
        # in cpu mode both results are None either way - confirm intent.
        if requirement not in FAILURES:
            assert res1 == res2
    return res1
|
||||
|
||||
|
||||
ARG_VERSIONED = pytest.mark.parametrize('arg', (
|
||||
'nothing{op}{version}{extra}',
|
||||
'something_else{op}{version}{extra}',
|
||||
'something-else{op}{version}{extra}',
|
||||
'something.else{op}{version}{extra}',
|
||||
'seematics.caffe{op}{version}{extra}',
|
||||
'lightnet{op}{version}{extra}',
|
||||
))
|
||||
OP = pytest.mark.parametrize('op', ('==', '<=', '>='))
|
||||
VERSION = pytest.mark.parametrize('version', versions)
|
||||
EXTRA = pytest.mark.parametrize('extra', ('', ' ; ' + win_condition))
|
||||
ARG_PLAIN = pytest.mark.parametrize('arg', (
|
||||
'nothing',
|
||||
'something_else',
|
||||
'something-else',
|
||||
'something.else',
|
||||
'seematics.caffe',
|
||||
'lightnet',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/seematics.config-1.0.2-py2.py3-none-any.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/seematics.api-1.0.2-py2.py3-none-any.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/seematics.sdk-1.1.0-py2.py3-none-any.whl',
|
||||
))
|
||||
|
||||
|
||||
@ARG_VERSIONED
@OP
@VERSION
@EXTRA
def test_with_version(manager, current_session, pytestconfig, arg, op, version, extra):
    """Versioned requirement lines: new and old substitution must agree."""
    requirement = arg.format(op=op, version=version, extra=extra)
    # expected = EXPECTED.get((arg, op, version))
    expected = get_expected(current_session, (arg, op, version))
    if expected:
        expected += extra
    compare(manager, current_session, pytestconfig, requirement, expected)
|
||||
|
||||
|
||||
@ARG_PLAIN
@EXTRA
def test_plain(arg, current_session, pytestconfig, extra, manager):
    """Requirement lines without a version spec (names and direct wheel URLs)."""
    # expected = EXPECTED.get(arg)
    expected = get_expected(current_session, arg)
    if expected:
        expected += extra
    compare(manager, current_session, pytestconfig, arg + extra, expected)
|
||||
|
||||
|
||||
def get_expected(session, key):
    """Look up the expected substitution for `key` and fill in the CUDA suffix.

    Returns None when there is no expectation for the key; the suffix is empty
    when the session has no CUDA/CuDNN versions configured.
    """
    template = EXPECTED.get(key)
    cuda_version = session.config['agent.cuda_version']
    cudnn_version = session.config['agent.cudnn_version']
    if cuda_version and cudnn_version:
        suffix = '.post{cuda_version}.dev{cudnn_version}'.format(
            cuda_version=cuda_version, cudnn_version=cudnn_version)
    else:
        suffix = ''
    return template and template.format(suffix=suffix)
|
||||
|
||||
|
||||
@ARG_VERSIONED
@OP
@VERSION
@EXTRA
def test_str_versioned(arg, op, version, extra):
    """str() of a parsed versioned requirement must round-trip (modulo spacing)."""
    requirement = arg.format(op=op, version=version, extra=extra)
    assert normalize(str(parse_one(requirement))) == normalize(requirement)
|
||||
|
||||
|
||||
@ARG_PLAIN
@EXTRA
def test_str_plain(arg, extra, manager):
    """str() of a parsed plain requirement must round-trip (modulo spacing)."""
    # plain args carry no format placeholders, so this is effectively identity
    requirement = arg.format(extra=extra)
    assert normalize(str(parse_one(requirement))) == normalize(requirement)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def current_session(pytestconfig):
    """Agent session fixture; --cpu mode clears the CUDA/CuDNN config entries."""
    session = Session()
    if pytestconfig.option.cpu:
        session.config['agent.cuda_version'] = None
        session.config['agent.cudnn_version'] = None
    return session
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def manager(current_session):
    """RequirementsManager wired with the handlers exercised by these tests."""
    requirements_manager = RequirementsManager(current_session.config)
    requirements_manager.register(PytorchRequirement)
    return requirements_manager
|
||||
|
||||
|
||||
SPECS = (
|
||||
# plain
|
||||
'',
|
||||
|
||||
# greater than
|
||||
'>=0', '>0.1', '>=0.1', '>0.1.2', '>=0.1.2', '>0.1.2.post30', '>=0.1.3.post30', '>0.post30.dev40',
|
||||
'>=0.post30.dev40', '>0.post30-dev40', '>=0.post30-dev40', '>0.0', '>=0.0', '>0.3', '>=0.3', '>=0.4',
|
||||
|
||||
# smaller than
|
||||
'<4', '<4.0', '<4.1.0', '<4.1.0.post80.dev60', '<4.1.0-post80.dev60', '<3', '<2.0', '<1.0.3', '<=4', '<=4.0',
|
||||
'<=4.1.0', '<=4.1.0.post80.dev60', '<=4.1.0-post80.dev60', '<=3', '<=2.0', '<=1.0.3',
|
||||
|
||||
# equals
|
||||
'==0.4.0',
|
||||
|
||||
# equals and
|
||||
'==0.4.0,>=0', '==0.4.0,>=0.1', '==0.4.0,>0.1', '==0.4.0,>=0.1.2', '==0.4.0,>0.1.2', '==0.4.0,<4', '==0.4.0,<=4',
|
||||
'==0.4.0,<4.0', '==0.4.0,<=4.0', '==0.4.0,<4.0.2',
|
||||
|
||||
# smaller and greater
|
||||
'>=0,<4', '>=0,<4.0', '>0.1,<1', '>0.1,<1.1.2', '>0.1,<1.1.2', '>=0,<=4', '>=0,<4.0',
|
||||
'>=0.1,<1', '>0.1,<=1.1.2', '>=0.1,<1.1.2',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('package_manager', ('pip', 'conda'))
@pytest.mark.parametrize('os', ('linux', 'windows', 'macos'))
@pytest.mark.parametrize('cuda', ('80', '90', '91'))
@pytest.mark.parametrize('python', ('2.7', '3.5', '3.6'))
@pytest.mark.parametrize('spec', SPECS)
@pytest.mark.parametrize('condition', ('', ' ; ' + win_condition))
def test_pytorch_success(manager, package_manager, os, cuda, python, spec, condition):
    """Exhaustively check torch requirement substitution across every
    (package manager, OS, CUDA, python, version-spec, marker) combination."""
    # The pytorch handler is the last handler registered by the manager fixture.
    pytorch_handler = manager.handlers[-1]
    pytorch_handler.package_manager = package_manager
    pytorch_handler.os = os
    pytorch_handler.cuda = cuda_ver = 'cuda{}'.format(cuda)
    pytorch_handler.python = python_ver = 'python{}'.format(python)
    req = 'torch{}{}'.format(spec, condition)
    # MAP entries are either a {version: substitution} mapping or an Exception
    # instance marking an unsupported combination.
    expected = pytorch_handler.MAP[package_manager][os][cuda_ver][python_ver]
    if isinstance(expected, Exception):
        # unsupported combination: substitution must raise the mapped error type
        with pytest.raises(type(expected)):
            manager._replace_one(parse_one(req))
    else:
        expected = expected['0.4.0']
        result = manager._replace_one(parse_one(req))
        assert result == expected
|
||||
|
||||
|
||||
get_pip_version = RequirementSubstitution.get_pip_version
|
||||
|
||||
EXPECTED = {
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp35'
|
||||
'-cp35m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp27'
|
||||
'-cp27m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp35-cp35m'
|
||||
'-linux_x86_64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp27-cp27mu'
|
||||
'-linux_x86_64.whl',
|
||||
'seematics.caffe': 'seematics.caffe=={}{{suffix}}'.format(get_pip_version('seematics.caffe')),
|
||||
'lightnet': 'lightnet=={}{{suffix}}'.format(get_pip_version('lightnet')),
|
||||
('seematics.caffe{op}{version}{extra}', '<=', '1'): 'seematics.caffe==1{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '<=', '1.4'): 'seematics.caffe==1.4{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '<=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '==', '1'): 'seematics.caffe==1{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '==', '1.4'): 'seematics.caffe==1.4{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '==', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '>=', '1'): 'seematics.caffe==1{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '>=', '1.4'): 'seematics.caffe==1.4{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '>=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
|
||||
('lightnet{op}{version}{extra}', '<=', '1'): 'lightnet==1{suffix}',
|
||||
('lightnet{op}{version}{extra}', '<=', '1.4'): 'lightnet==1.4{suffix}',
|
||||
('lightnet{op}{version}{extra}', '<=', '1.4.9'): 'lightnet==1.4.9{suffix}',
|
||||
('lightnet{op}{version}{extra}', '==', '1'): 'lightnet==1{suffix}',
|
||||
('lightnet{op}{version}{extra}', '==', '1.4'): 'lightnet==1.4{suffix}',
|
||||
('lightnet{op}{version}{extra}', '==', '1.4.9'): 'lightnet==1.4.9{suffix}',
|
||||
('lightnet{op}{version}{extra}', '>=', '1'): 'lightnet==1{suffix}',
|
||||
('lightnet{op}{version}{extra}', '>=', '1.4'): 'lightnet==1.4{suffix}',
|
||||
('lightnet{op}{version}{extra}', '>=', '1.4.9'): 'lightnet==1.4.9{suffix}',
|
||||
}
|
||||
FAILURES = {
|
||||
'seematics.caffe ; {}'.format(win_condition),
|
||||
'lightnet ; {}'.format(win_condition)
|
||||
}
|
||||
0
tests/scripts/__init__.py
Normal file
0
tests/scripts/__init__.py
Normal file
7
tests/scripts/python2-test.py
Normal file
7
tests/scripts/python2-test.py
Normal file
@@ -0,0 +1,7 @@
|
||||
def main():
    # NOTE: this script is intentionally Python-2 only: under py2, `1 / 2` is
    # floor division and equals 0, so the assert passes; under py3 it would
    # raise AssertionError. Used to verify the agent picked the right
    # interpreter for a task.
    assert 1 / 2 == 0
    print('success')


if __name__ == "__main__":
    main()
|
||||
8
tests/scripts/python3-test.py
Normal file
8
tests/scripts/python3-test.py
Normal file
@@ -0,0 +1,8 @@
|
||||
def main():
    """Sanity-check that we run under Python 3: `/` must be true division."""
    if 1 / 2 != 0.5:
        raise ValueError('failure')
    print('success')


if __name__ == "__main__":
    main()
|
||||
1
trains_agent/__init__.py
Normal file
1
trains_agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
99
trains_agent/__main__.py
Normal file
99
trains_agent/__main__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import print_function, unicode_literals, absolute_import
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from trains_agent.backend_api.session.datamodel import UnusedKwargsWarning
|
||||
|
||||
import trains_agent
|
||||
from trains_agent.config import get_config
|
||||
from trains_agent.definitions import FileBuffering, CONFIG_FILE
|
||||
from trains_agent.helper.base import reverse_home_folder_expansion, chain_map, named_temporary_file
|
||||
from trains_agent.helper.process import ExitStatus
|
||||
from . import interface, session, definitions, commands
|
||||
from .errors import ConfigFileNotFound, Sigterm, APIError
|
||||
from .helper.trace import PackageTrace
|
||||
from .interface import get_parser
|
||||
|
||||
|
||||
def run_command(parser, args, command_name):
    """Resolve the command class and method from parsed args and execute it.

    Returns the command's return value (or an ExitStatus code), converting
    well-known failures into user-friendly messages.
    """

    debug = args.debug
    # Dispatch forms:
    #   bare command name          -> commands.Worker method
    #   subparser that set a func  -> <Command>.<args.func>
    #   dotted "<service>.<method>"
    if len(command_name.split('.')) < 2:
        command_class = commands.Worker
    elif hasattr(args, 'func') and getattr(args, 'func'):
        command_class = getattr(commands, command_name.capitalize())
        command_name = args.func
    else:
        command_class, command_name = command_name.split('.')
        command_class = getattr(commands, command_class.capitalize())

    args_dict = dict(vars(args))
    # strip the top-level CLI options so only command kwargs remain
    parser.remove_top_level_results(args_dict)

    # the API data model warns about unknown kwargs; we pass everything through
    warnings.simplefilter('ignore', UnusedKwargsWarning)

    try:
        command = command_class(**vars(args))
        get_config()['command'] = command
        # the session config may override the CLI debug flag
        debug = command._session.debug_mode
        func = getattr(command, command_name)
        return func(**args_dict)
    except ConfigFileNotFound:
        message = 'Cannot find configuration file in "{}".\n' \
            'To create a configuration file, run:\n' \
            '$ trains_agent init'.format(reverse_home_folder_expansion(CONFIG_FILE))
        command_class.exit(message)
    except APIError as api_error:
        if not debug:
            command_class.error(api_error)
            return ExitStatus.failure
        # debug mode: show the server-side traceback too, then re-raise
        traceback = api_error.format_traceback()
        if traceback:
            print(traceback)
            print('Own traceback:')
        raise
    except Exception as e:
        if debug:
            raise
        command_class.error(e)
        return ExitStatus.failure
    except (KeyboardInterrupt, Sigterm):
        # NOTE(review): this clause follows `except Exception`; if Sigterm
        # derives from Exception it is swallowed by the clause above and never
        # reaches here (KeyboardInterrupt derives from BaseException and does).
        # Confirm Sigterm's base class.
        return ExitStatus.interrupted
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, dispatch via run_command, and
    optionally record a line-by-line execution trace (--trace)."""
    parser = get_parser()
    args = parser.parse_args()

    try:
        command_name = args.command
        if not command_name:
            return parser.print_help()
    except AttributeError:
        # NOTE(review): argparse._ is argparse's *private* gettext alias; this
        # mimics argparse's own "too few arguments" message but depends on a
        # non-public attribute.
        parser.error(argparse._('too few arguments'))

    if not args.trace:
        return run_command(parser, args, command_name)

    # Trace mode: run the command under a package-level tracer; the trace file
    # is intentionally kept after exit (delete=False) so it can be inspected.
    with named_temporary_file(
        mode='w',
        buffering=FileBuffering.LINE_BUFFERING,
        prefix='.trains_agent_trace_',
        suffix='.txt',
        delete=False,
    ) as output:
        print(
            'Saving trace for command '
            '"{definitions.PROGRAM_NAME} {command_name} {args.func}" to "{output.name}"'.format(
                **chain_map(locals(), globals())))
        tracer = PackageTrace(
            package=trains_agent,
            out_file=output,
            # these modules are plumbing, not interesting in a trace
            ignore_submodules=(__name__, interface, definitions, session))
        return tracer.runfunc(run_command, parser, args, command_name)


if __name__ == "__main__":
    sys.exit(main())
|
||||
2
trains_agent/backend_api/__init__.py
Normal file
2
trains_agent/backend_api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
|
||||
from .config import load as load_config
|
||||
16
trains_agent/backend_api/config/__init__.py
Normal file
16
trains_agent/backend_api/config/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from ...backend_config import Config
|
||||
from pathlib2 import Path
|
||||
|
||||
|
||||
def load(*additional_module_paths):
    # type: (str) -> Config
    """
    Load configuration with the API defaults, using the additional module
    paths provided.

    :param additional_module_paths: Additional config paths for modules whose
        default configurations should be loaded as well
    :return: Config object
    """
    config = Config(verbose=False)
    here = Path(__file__).parent
    config.load_relative_to(here, *additional_module_paths)
    return config
|
||||
79
trains_agent/backend_api/config/default/agent.conf
Normal file
79
trains_agent/backend_api/config/default/agent.conf
Normal file
@@ -0,0 +1,79 @@
|
||||
{
|
||||
# unique name of this worker, if None, created based on hostname:process_id
|
||||
# Override with os environment: TRAINS_WORKER_ID
|
||||
# worker_id: "trains-agent-machine1:gpu0"
|
||||
worker_id: ""
|
||||
|
||||
# worker name, replaces the hostname when creating a unique name for this worker
|
||||
# Override with os environment: TRAINS_WORKER_NAME
|
||||
# worker_name: "trains-agent-machine1"
|
||||
worker_name: ""
|
||||
|
||||
# Set GIT user/pass credentials for cloning code, leave blank for GIT SSH credentials.
|
||||
# git_user: ""
|
||||
# git_pass: ""
|
||||
|
||||
# select python package manager:
|
||||
# currently supported pip and conda
|
||||
# poetry is used if pip selected and repository contains poetry.lock file
|
||||
package_manager: {
|
||||
# supported options: pip, conda
|
||||
type: pip,
|
||||
|
||||
        # virtual environment inherits packages from system
|
||||
system_site_packages: false,
|
||||
|
||||
# install with --upgrade
|
||||
force_upgrade: false,
|
||||
|
||||
# additional artifact repositories to use when installing python packages
|
||||
# extra_index_url: ["https://allegroai.jfrog.io/trainsai/api/pypi/public/simple"]
|
||||
extra_index_url: []
|
||||
|
||||
# additional conda channels to use when installing with conda package manager
|
||||
conda_channels: ["defaults", "conda-forge", "pytorch", ]
|
||||
},
|
||||
|
||||
# target folder for virtual environments builds, created when executing experiment
|
||||
venvs_dir = ~/.trains/venvs-builds
|
||||
|
||||
# cached git clone folder
|
||||
vcs_cache: {
|
||||
enabled: true,
|
||||
path: ~/.trains/vcs-cache
|
||||
},
|
||||
|
||||
# use venv-update in order to accelerate python virtual environment building
|
||||
# Still in beta, turned off by default
|
||||
venv_update: {
|
||||
enabled: false,
|
||||
},
|
||||
|
||||
# cached folder for specific python package download (used for pytorch package caching)
|
||||
pip_download_cache {
|
||||
enabled: true,
|
||||
path: ~/.trains/pip-download-cache
|
||||
},
|
||||
|
||||
translate_ssh: true,
|
||||
# reload configuration file every daemon execution
|
||||
reload_config: false,
|
||||
|
||||
# pip cache folder used mapped into docker, for python package caching
|
||||
docker_pip_cache = ~/.trains/pip-cache
|
||||
# apt cache folder used mapped into docker, for ubuntu package caching
|
||||
docker_apt_cache = ~/.trains/apt-cache
|
||||
|
||||
default_docker: {
|
||||
# default docker image to use when running in docker mode
|
||||
image: "nvidia/cuda"
|
||||
|
||||
# optional arguments to pass to docker image
|
||||
# arguments: ["--ipc=host", ]
|
||||
}
|
||||
|
||||
# cuda versions used for solving pytorch wheel packages
|
||||
# should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
|
||||
# cuda_version: 10.1
|
||||
# cudnn_version: 7.6
|
||||
}
|
||||
37
trains_agent/backend_api/config/default/api.conf
Normal file
37
trains_agent/backend_api/config/default/api.conf
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
version: 1.5
|
||||
# verify host ssl certificate, set to False only if you have a very good reason
|
||||
verify_certificate: True
|
||||
|
||||
# default version assigned to requests with no specific version. this is not expected to change
|
||||
# as it keeps us backwards compatible.
|
||||
default_version: 1.5
|
||||
|
||||
http {
|
||||
max_req_size = 15728640 # request size limit (smaller than that configured in api server)
|
||||
|
||||
retries {
|
||||
# retry values (int, 0 means fail on first retry)
|
||||
total: 240 # Total number of retries to allow. Takes precedence over other counts.
|
||||
connect: 240 # How many times to retry on connection-related errors (never reached server)
|
||||
read: 240 # How many times to retry on read errors (waiting for server)
|
||||
redirect: 240 # How many redirects to perform (HTTP response with a status code 301, 302, 303, 307 or 308)
|
||||
status: 240 # How many times to retry on bad status codes
|
||||
|
||||
# backoff parameters
|
||||
# timeout between retries is min({backoff_max}, {backoff factor} * (2 ^ ({number of total retries} - 1))
|
||||
backoff_factor: 1.0
|
||||
backoff_max: 120.0
|
||||
}
|
||||
|
||||
wait_on_maintenance_forever: true
|
||||
|
||||
pool_maxsize: 512
|
||||
pool_connections: 512
|
||||
}
|
||||
|
||||
auth {
|
||||
# When creating a request, if token will expire in less than this value, try to refresh the token
|
||||
token_expiration_threshold_sec = 360
|
||||
}
|
||||
}
|
||||
8
trains_agent/backend_api/config/default/logging.conf
Normal file
8
trains_agent/backend_api/config/default/logging.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
version: 1
|
||||
loggers {
|
||||
urllib3 {
|
||||
level: ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
150
trains_agent/backend_api/config/default/sdk.conf
Normal file
150
trains_agent/backend_api/config/default/sdk.conf
Normal file
@@ -0,0 +1,150 @@
|
||||
{
|
||||
# TRAINS - default SDK configuration
|
||||
|
||||
storage {
|
||||
cache {
|
||||
# Defaults to system temp folder / cache
|
||||
default_base_dir: "~/.trains/cache"
|
||||
size {
|
||||
# max_used_bytes = -1
|
||||
min_free_bytes = 10GB
|
||||
# cleanup_margin_percent = 5%
|
||||
}
|
||||
}
|
||||
|
||||
direct_access: [
|
||||
# Objects matching are considered to be available for direct access, i.e. they will not be downloaded
|
||||
# or cached, and any download request will return a direct reference.
|
||||
# Objects are specified in glob format, available for url and content_type.
|
||||
{ url: "file://*" } # file-urls are always directly referenced
|
||||
]
|
||||
}
|
||||
|
||||
metrics {
|
||||
# History size for debug files per metric/variant. For each metric/variant combination with an attached file
|
||||
# (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
|
||||
# X files are stored in the upload destination for each metric/variant combination.
|
||||
file_history_size: 100
|
||||
|
||||
# Settings for generated debug images
|
||||
images {
|
||||
format: JPEG
|
||||
quality: 87
|
||||
subsampling: 0
|
||||
}
|
||||
}
|
||||
|
||||
network {
|
||||
metrics {
|
||||
# Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
|
||||
# a specific iteration
|
||||
file_upload_threads: 4
|
||||
|
||||
# Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
|
||||
# being sent for upload
|
||||
file_upload_starvation_warning_sec: 120
|
||||
}
|
||||
|
||||
iteration {
|
||||
# Max number of retries when getting frames if the server returned an error (http code 500)
|
||||
max_retries_on_server_error: 5
|
||||
# Backoff factory for consecutive retry attempts.
|
||||
# SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
|
||||
retry_backoff_factor_sec: 10
|
||||
}
|
||||
}
|
||||
aws {
|
||||
s3 {
|
||||
# S3 credentials, used for read/write access by various SDK elements
|
||||
|
||||
# default, used for any bucket not specified below
|
||||
key: ""
|
||||
secret: ""
|
||||
region: ""
|
||||
|
||||
credentials: [
|
||||
# specifies key/secret credentials to use when handling s3 urls (read or write)
|
||||
# {
|
||||
# bucket: "my-bucket-name"
|
||||
# key: "my-access-key"
|
||||
# secret: "my-secret-key"
|
||||
# },
|
||||
# {
|
||||
# # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
|
||||
# host: "my-minio-host:9000"
|
||||
# key: "12345678"
|
||||
# secret: "12345678"
|
||||
# multipart: false
|
||||
# secure: false
|
||||
# }
|
||||
]
|
||||
}
|
||||
boto3 {
|
||||
pool_connections: 512
|
||||
max_multipart_concurrency: 16
|
||||
}
|
||||
}
|
||||
google.storage {
|
||||
# # Default project and credentials file
|
||||
# # Will be used when no bucket configuration is found
|
||||
# project: "trains"
|
||||
# credentials_json: "/path/to/credentials.json"
|
||||
|
||||
# # Specific credentials per bucket and sub directory
|
||||
# credentials = [
|
||||
# {
|
||||
# bucket: "my-bucket"
|
||||
# subdir: "path/in/bucket" # Not required
|
||||
# project: "trains"
|
||||
# credentials_json: "/path/to/credentials.json"
|
||||
# },
|
||||
# ]
|
||||
}
|
||||
azure.storage {
|
||||
# containers: [
|
||||
# {
|
||||
# account_name: "trains"
|
||||
# account_key: "secret"
|
||||
# # container_name:
|
||||
# }
|
||||
# ]
|
||||
}
|
||||
|
||||
log {
|
||||
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
|
||||
null_log_propagate: False
|
||||
task_log_buffer_capacity: 66
|
||||
|
||||
# disable urllib info and lower levels
|
||||
disable_urllib3_info: True
|
||||
}
|
||||
|
||||
development {
|
||||
# Development-mode options
|
||||
|
||||
# dev task reuse window
|
||||
task_reuse_time_window_in_hours: 72.0
|
||||
|
||||
# Run VCS repository detection asynchronously
|
||||
vcs_repo_detect_async: True
|
||||
|
||||
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
|
||||
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
|
||||
store_uncommitted_code_diff_on_train: True
|
||||
|
||||
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
|
||||
support_stopping: True
|
||||
|
||||
# Development mode worker
|
||||
worker {
|
||||
# Status report period in seconds
|
||||
report_period_sec: 2
|
||||
|
||||
# ping to the server - check connectivity
|
||||
ping_period_sec: 30
|
||||
|
||||
# Log all stdout & stderr
|
||||
log_stdout: True
|
||||
}
|
||||
}
|
||||
}
|
||||
0
trains_agent/backend_api/schema/__init__.py
Normal file
0
trains_agent/backend_api/schema/__init__.py
Normal file
38
trains_agent/backend_api/schema/action.py
Normal file
38
trains_agent/backend_api/schema/action.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
import attr
|
||||
from attr.converters import optional as optional_converter
|
||||
from attr.validators import instance_of, optional, and_
|
||||
from six import string_types
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
sequence = instance_of((list, tuple))
|
||||
|
||||
|
||||
def sequence_of(types):
    """Build an attrs validator: value must be a list/tuple of `types` items."""
    def check_items(_, attrib, value):
        # fail with the attribute name for easier debugging
        assert all(isinstance(item, types) for item in value), attrib.name

    # combine the container check with the per-item check
    return and_(sequence, check_items)
|
||||
|
||||
|
||||
@attr.s
class Action(object):
    """Schema description of a single API action (endpoint) of a service."""

    # action (endpoint) name
    name = attr.ib()
    # schema version this action belongs to
    version = attr.ib()
    # owning service name
    service = attr.ib()
    # names of service-level definitions referenced by this action's schemas
    definitions_keys = attr.ib(validator=sequence)
    # server-side behavior flags
    authorize = attr.ib(validator=instance_of(bool), default=True)
    log_data = attr.ib(validator=instance_of(bool), default=True)
    log_result_data = attr.ib(validator=instance_of(bool), default=True)
    internal = attr.ib(default=False)
    # optional whitelist of roles allowed to invoke the action
    allow_roles = attr.ib(default=None, validator=optional(sequence_of(string_types)))
    # JSON schemas for request / batch-request / response payloads (None = unspecified)
    request = attr.ib(validator=optional(instance_of(dict)), default=None)
    batch_request = attr.ib(validator=optional(instance_of(dict)), default=None)
    response = attr.ib(validator=optional(instance_of(dict)), default=None)
    # HTTP method, when explicitly declared in the schema
    method = attr.ib(default=None)
    # free-text documentation for the action
    description = attr.ib(
        default=None,
        validator=optional(instance_of(string_types)),
    )
|
||||
201
trains_agent/backend_api/schema/service.py
Normal file
201
trains_agent/backend_api/schema/service.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import itertools
|
||||
import re
|
||||
|
||||
import attr
|
||||
import six
|
||||
|
||||
import pyhocon
|
||||
|
||||
from .action import Action
|
||||
|
||||
|
||||
class Service(object):
    """ Service schema handler """

    # Matches JSON-schema local reference strings: "#/definitions/<name>"
    __jsonschema_ref_ex = re.compile("^#/definitions/(.*)$")

    @property
    def default(self):
        # Default action configuration (the "_default" section), merged into
        # every action's config before building the Action object.
        return self._default

    @property
    def actions(self):
        # Parsed actions: {action_name: {version (float): Action}}
        return self._actions

    @property
    def definitions(self):
        """ Raw service definitions (each might be dependant on some of its siblings) """
        return self._definitions

    @property
    def definitions_refs(self):
        # {definition_name: set of definition names it references}
        return self._definitions_refs

    @property
    def name(self):
        # Service name this schema was parsed for
        return self._name

    @property
    def doc(self):
        # Human-readable description: "<name> service" plus optional "_description"
        return self._doc

    def __init__(self, name, service_config):
        """
        :param name: service name
        :param service_config: pyhocon ConfigTree holding the service schema
        """
        self._name = name
        self._default = None
        self._actions = []
        self._definitions = None
        self._definitions_refs = None
        self._doc = None
        self.parse(service_config)

    @classmethod
    def get_ref_name(cls, ref_string):
        """Return the definition name from a "#/definitions/<name>" reference,
        or None if the string is not a local reference."""
        m = cls.__jsonschema_ref_ex.match(ref_string)
        if m:
            return m.group(1)

    def parse(self, service_config):
        """Parse the service configuration: defaults, doc text, definitions
        and all versioned actions.

        :raises ValueError: if any definition references a definition that does
            not exist in this service's "_definitions" section.
        """
        self._default = service_config.get(
            "_default", pyhocon.ConfigTree()
        ).as_plain_ordered_dict()

        self._doc = '{} service'.format(self.name)
        description = service_config.get('_description', '')
        if description:
            self._doc += '\n\n{}'.format(description)
        self._definitions = service_config.get(
            "_definitions", pyhocon.ConfigTree()
        ).as_plain_ordered_dict()
        # Pre-compute which definitions each definition references, so action
        # schemas can later be expanded with their transitive dependencies.
        self._definitions_refs = {
            k: self._get_schema_references(v) for k, v in self._definitions.items()
        }
        all_refs = set(itertools.chain(*self.definitions_refs.values()))
        if not all_refs.issubset(self.definitions):
            raise ValueError(
                "Unresolved references (%s) in %s/definitions"
                % (", ".join(all_refs.difference(self.definitions)), self.name)
            )

        # Everything not starting with "_" is an action section.
        actions = {
            k: v.as_plain_ordered_dict()
            for k, v in service_config.items()
            if not k.startswith("_")
        }
        # Keep only actions that produced at least one generated version.
        self._actions = {
            action_name: action
            for action_name, action in (
                (action_name, self._parse_action_versions(action_name, action_versions))
                for action_name, action_versions in actions.items()
            )
            if action
        }

    def _parse_action_versions(self, action_name, action_versions):
        """Parse all versions of a single action.

        :return: {version (float): Action}, skipping versions whose action was
            not generated (``generate: false``).
        :raises ValueError: if a version key cannot be parsed as a float.
        """
        def parse_version(action_version):
            try:
                return float(action_version)
            except (ValueError, TypeError) as ex:
                raise ValueError(
                    "Failed parsing version number {} ({}) in {}/{}".format(
                        action_version, ex.args[0], self.name, action_name
                    )
                )

        def add_internal(cfg):
            # Propagate an action-level "internal" flag down to each version
            # config (per-version value, if present, wins via setdefault).
            if "internal" in action_versions:
                cfg.setdefault("internal", action_versions["internal"])
            return cfg

        # NOTE(review): "allow_roles" and "authorize" keys are filtered out of
        # the per-version iteration below but, unlike "internal", are not
        # merged into the per-version config — confirm this is intended.
        return {
            parsed_version: action
            for parsed_version, action in (
                (parsed_version, self._parse_action(action_name, parsed_version, add_internal(cfg)))
                for parsed_version, cfg in (
                    (parse_version(version), cfg)
                    for version, cfg in action_versions.items()
                    if version not in ["internal", "allow_roles", "authorize"]
                )
            )
            if action
        }

    def _get_schema_references(self, s):
        """Recursively collect definition names referenced (via "$ref") by a
        schema dict. Returns a set of names; non-dict input yields an empty set.
        Lists are only traversed when found under "oneOf"/"anyOf" keys.
        """
        refs = set()
        if isinstance(s, dict):
            for k, v in s.items():
                if isinstance(v, six.string_types):
                    m = self.__jsonschema_ref_ex.match(v)
                    if m:
                        refs.add(m.group(1))
                    continue
                elif k in ("oneOf", "anyOf") and isinstance(v, list):
                    refs.update(*map(self._get_schema_references, v))
                refs.update(self._get_schema_references(v))
        return refs

    def _expand_schema_references_with_definitions(self, schema, refs=None):
        """Return the set of service-level definition names required by
        *schema* (transitively), excluding definitions the schema already
        carries in its own "definitions" section.

        :raises ValueError: if a reference cannot be resolved against the
            service definitions.
        """
        definitions = schema.get("definitions", {})
        refs = refs if refs is not None else self._get_schema_references(schema)
        required_refs = set(refs).difference(definitions)
        if not required_refs:
            # Nothing to pull in; returns the (empty) set.
            return required_refs
        if not required_refs.issubset(self.definitions):
            raise ValueError(
                "Unresolved references (%s)"
                % ", ".join(required_refs.difference(self.definitions))
            )

        # update required refs with all sub requirements
        # (fixed-point iteration: keep adding referenced definitions until the
        # set stops growing)
        last_required_refs = None
        while last_required_refs != required_refs:
            last_required_refs = required_refs.copy()
            additional_refs = set(
                itertools.chain(
                    *(self.definitions_refs.get(ref, []) for ref in required_refs)
                )
            )
            required_refs.update(additional_refs)
        return required_refs

    def _resolve_schema_references(self, schema, refs=None):
        """Merge the required service-level definitions (*refs*) into the
        schema's own "definitions" section, in place.
        NOTE(review): ``refs`` is used directly (``k in refs``) — callers
        always pass a set; passing None would fail.
        """
        definitions = schema.get("definitions", {})
        definitions.update({k: v for k, v in self.definitions.items() if k in refs})
        schema["definitions"] = definitions

    def _parse_action(self, action_name, action_version, action_config):
        """Build a single Action from its config merged over the defaults.

        :return: an Action instance, or None when the action is marked with
            ``generate: false``.
        :raises ValueError: on unresolved schema references, annotated with the
            full service/action/version/schema-key location.
        """
        data = self.default.copy()
        data.update(action_config)

        if not action_config.get("generate", True):
            return None

        definitions_keys = set()
        for schema_key in ("request", "response"):
            if schema_key in action_config:
                try:
                    schema = action_config[schema_key]
                    refs = self._expand_schema_references_with_definitions(schema)
                    self._resolve_schema_references(schema, refs=refs)
                    definitions_keys.update(refs)
                except ValueError as ex:
                    name = "%s.%s/%.1f/%s" % (
                        self.name,
                        action_name,
                        action_version,
                        schema_key,
                    )
                    raise ValueError("%s in %s" % (str(ex), name))

        return Action(
            name=action_name,
            version=action_version,
            definitions_keys=list(definitions_keys),
            service=self.name,
            # Only pass through keys Action actually declares as fields.
            **(
                {
                    key: value
                    for key, value in data.items()
                    if key in attr.fields_dict(Action)
                }
            )
        )
|
||||
13
trains_agent/backend_api/services/__init__.py
Normal file
13
trains_agent/backend_api/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .v2_4 import auth
|
||||
from .v2_4 import debug
|
||||
from .v2_4 import queues
|
||||
from .v2_4 import tasks
|
||||
from .v2_4 import workers
|
||||
|
||||
__all__ = [
|
||||
'auth',
|
||||
'debug',
|
||||
'queues',
|
||||
'tasks',
|
||||
'workers',
|
||||
]
|
||||
0
trains_agent/backend_api/services/v2_4/__init__.py
Normal file
0
trains_agent/backend_api/services/v2_4/__init__.py
Normal file
623
trains_agent/backend_api/services/v2_4/auth.py
Normal file
623
trains_agent/backend_api/services/v2_4/auth.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
auth service
|
||||
|
||||
This service provides authentication management and authorization
|
||||
validation for the entire system.
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class Credentials(NonStrictDataModel):
    """
    An access-key/secret-key credentials pair.

    :param access_key: Credentials access key
    :type access_key: str
    :param secret_key: Credentials secret key
    :type secret_key: str
    """
    _schema = {
        'properties': {
            'access_key': {
                'description': 'Credentials access key',
                'type': ['string', 'null'],
            },
            'secret_key': {
                'description': 'Credentials secret key',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(self, access_key=None, secret_key=None, **kwargs):
        super(Credentials, self).__init__(**kwargs)
        self.access_key = access_key
        self.secret_key = secret_key

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        # None clears the property; anything else must be a string.
        if value is not None:
            self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property('secret_key')
    def secret_key(self):
        return self._property_secret_key

    @secret_key.setter
    def secret_key(self, value):
        # None clears the property; anything else must be a string.
        if value is not None:
            self.assert_isinstance(value, "secret_key", six.string_types)
        self._property_secret_key = value
|
||||
|
||||
|
||||
class CredentialKey(NonStrictDataModel):
    """
    A single credentials entry as reported by the server (access key only,
    no secret).

    :param access_key:
    :type access_key: str
    :param last_used:
    :type last_used: datetime.datetime
    :param last_used_from:
    :type last_used_from: str
    """
    _schema = {
        'properties': {
            'access_key': {'description': '', 'type': ['string', 'null']},
            'last_used': {
                'description': '',
                'format': 'date-time',
                'type': ['string', 'null'],
            },
            'last_used_from': {'description': '', 'type': ['string', 'null']},
        },
        'type': 'object',
    }
    def __init__(
            self, access_key=None, last_used=None, last_used_from=None, **kwargs):
        super(CredentialKey, self).__init__(**kwargs)
        self.access_key = access_key
        self.last_used = last_used
        self.last_used_from = last_used_from

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        if value is None:
            self._property_access_key = None
            return

        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property('last_used')
    def last_used(self):
        return self._property_last_used

    @last_used.setter
    def last_used(self, value):
        if value is None:
            self._property_last_used = None
            return

        # Accepts either a datetime or a date string; strings are parsed so
        # the stored property is always a datetime object.
        self.assert_isinstance(value, "last_used", six.string_types + (datetime,))
        if not isinstance(value, datetime):
            value = parse_datetime(value)
        self._property_last_used = value

    @schema_property('last_used_from')
    def last_used_from(self):
        return self._property_last_used_from

    @last_used_from.setter
    def last_used_from(self, value):
        if value is None:
            self._property_last_used_from = None
            return

        self.assert_isinstance(value, "last_used_from", six.string_types)
        self._property_last_used_from = value
|
||||
|
||||
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Request):
    """
    Creates a new set of credentials for the authenticated user.
    New key/secret is returned.
    Note: Secret will never be returned in any other API call.
    If a secret is lost or compromised, the key should be revoked
    and a new set of credentials can be created.

    """

    _service = "auth"
    _action = "create_credentials"
    _version = "2.1"
    # No payload: the endpoint operates on the authenticated user only.
    _schema = {
        'additionalProperties': False,
        'definitions': {},
        'properties': {},
        'type': 'object',
    }
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Response):
    """
    Response of auth.create_credentials endpoint.

    :param credentials: Created credentials
    :type credentials: Credentials
    """
    _service = "auth"
    _action = "create_credentials"
    _version = "2.1"

    _schema = {
        'definitions': {
            'credentials': {
                'properties': {
                    'access_key': {
                        'description': 'Credentials access key',
                        'type': ['string', 'null'],
                    },
                    'secret_key': {
                        'description': 'Credentials secret key',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'credentials': {
                'description': 'Created credentials',
                'oneOf': [{'$ref': '#/definitions/credentials'}, {'type': 'null'}],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, credentials=None, **kwargs):
        super(CreateCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        if value is None:
            self._property_credentials = None
            return
        # A raw dict (e.g. deserialized JSON) is coerced into a Credentials
        # instance; otherwise the value must already be a Credentials object.
        if isinstance(value, dict):
            value = Credentials.from_dict(value)
        else:
            self.assert_isinstance(value, "credentials", Credentials)
        self._property_credentials = value
|
||||
|
||||
|
||||
|
||||
|
||||
class EditUserRequest(Request):
    """
    Edit a users' auth data properties

    :param user: User ID
    :type user: str
    :param role: The new user's role within the company
    :type role: str
    """

    _service = "auth"
    _action = "edit_user"
    _version = "2.1"
    _schema = {
        'definitions': {},
        'properties': {
            'role': {
                'description': "The new user's role within the company",
                'enum': ['admin', 'superuser', 'user', 'annotator'],
                'type': ['string', 'null'],
            },
            'user': {'description': 'User ID', 'type': ['string', 'null']},
        },
        'type': 'object',
    }
    def __init__(
            self, user=None, role=None, **kwargs):
        super(EditUserRequest, self).__init__(**kwargs)
        self.user = user
        self.role = role

    @schema_property('user')
    def user(self):
        return self._property_user

    @user.setter
    def user(self, value):
        if value is None:
            self._property_user = None
            return

        self.assert_isinstance(value, "user", six.string_types)
        self._property_user = value

    @schema_property('role')
    def role(self):
        return self._property_role

    @role.setter
    def role(self, value):
        # NOTE(review): only a string type check is performed here; the
        # schema's enum of allowed roles is not enforced client-side.
        if value is None:
            self._property_role = None
            return

        self.assert_isinstance(value, "role", six.string_types)
        self._property_role = value
|
||||
|
||||
|
||||
class EditUserResponse(Response):
    """
    Response of auth.edit_user endpoint.

    :param updated: Number of users updated (0 or 1)
    :type updated: float
    :param fields: Updated fields names and values
    :type fields: dict
    """
    _service = "auth"
    _action = "edit_user"
    _version = "2.1"

    _schema = {
        'definitions': {},
        'properties': {
            'fields': {
                'additionalProperties': True,
                'description': 'Updated fields names and values',
                'type': ['object', 'null'],
            },
            'updated': {
                'description': 'Number of users updated (0 or 1)',
                'enum': [0, 1],
                'type': ['number', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, updated=None, fields=None, **kwargs):
        super(EditUserResponse, self).__init__(**kwargs)
        self.updated = updated
        self.fields = fields

    @schema_property('updated')
    def updated(self):
        return self._property_updated

    @updated.setter
    def updated(self, value):
        if value is None:
            self._property_updated = None
            return

        # Schema type is 'number', so both ints and floats are accepted.
        self.assert_isinstance(value, "updated", six.integer_types + (float,))
        self._property_updated = value

    @schema_property('fields')
    def fields(self):
        return self._property_fields

    @fields.setter
    def fields(self, value):
        if value is None:
            self._property_fields = None
            return

        self.assert_isinstance(value, "fields", (dict,))
        self._property_fields = value
|
||||
|
||||
|
||||
class GetCredentialsRequest(Request):
    """
    Returns all existing credential keys for the authenticated user.
    Note: Only credential keys are returned.

    """

    _service = "auth"
    _action = "get_credentials"
    _version = "2.1"
    # No payload: the endpoint operates on the authenticated user only.
    _schema = {
        'additionalProperties': False,
        'definitions': {},
        'properties': {},
        'type': 'object',
    }
|
||||
|
||||
|
||||
class GetCredentialsResponse(Response):
    """
    Response of auth.get_credentials endpoint.

    :param credentials: List of credentials, each with an empty secret field.
    :type credentials: Sequence[CredentialKey]
    """
    _service = "auth"
    _action = "get_credentials"
    _version = "2.1"

    _schema = {
        'definitions': {
            'credential_key': {
                'properties': {
                    'access_key': {'description': '', 'type': ['string', 'null']},
                    'last_used': {
                        'description': '',
                        'format': 'date-time',
                        'type': ['string', 'null'],
                    },
                    'last_used_from': {
                        'description': '',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'credentials': {
                'description': 'List of credentials, each with an empty secret field.',
                'items': {'$ref': '#/definitions/credential_key'},
                'type': ['array', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, credentials=None, **kwargs):
        super(GetCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        if value is None:
            self._property_credentials = None
            return

        self.assert_isinstance(value, "credentials", (list, tuple))
        # If the list contains any raw dicts, coerce just those into
        # CredentialKey instances; note that in this branch non-dict items
        # are kept as-is without per-item type validation.
        if any(isinstance(v, dict) for v in value):
            value = [CredentialKey.from_dict(v) if isinstance(v, dict) else v for v in value]
        else:
            self.assert_isinstance(value, "credentials", CredentialKey, is_array=True)
        self._property_credentials = value
|
||||
|
||||
|
||||
|
||||
|
||||
class LoginRequest(Request):
    """
    Get a token based on supplied credentials (key/secret).
    Intended for use by users with key/secret credentials that wish to obtain a token
    for use with other services. Token will be limited by the same permissions that
    exist for the credentials used in this call.

    :param expiration_sec: Requested token expiration time in seconds. Not
        guaranteed, might be overridden by the service
    :type expiration_sec: int
    """

    _service = "auth"
    _action = "login"
    _version = "2.1"
    _schema = {
        'definitions': {},
        'properties': {
            'expiration_sec': {
                'description': 'Requested token expiration time in seconds. \n Not guaranteed, might be overridden by the service',
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, expiration_sec=None, **kwargs):
        super(LoginRequest, self).__init__(**kwargs)
        self.expiration_sec = expiration_sec

    @schema_property('expiration_sec')
    def expiration_sec(self):
        return self._property_expiration_sec

    @expiration_sec.setter
    def expiration_sec(self, value):
        if value is None:
            self._property_expiration_sec = None
            return
        # Whole-number floats (e.g. from JSON) are silently coerced to int;
        # non-integral floats fail the isinstance check below.
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "expiration_sec", six.integer_types)
        self._property_expiration_sec = value
|
||||
|
||||
|
||||
class LoginResponse(Response):
    """
    Response of auth.login endpoint.

    :param token: Token string
    :type token: str
    """
    _service = "auth"
    _action = "login"
    _version = "2.1"

    _schema = {
        'definitions': {},
        'properties': {
            'token': {'description': 'Token string', 'type': ['string', 'null']},
        },
        'type': 'object',
    }
    def __init__(
            self, token=None, **kwargs):
        super(LoginResponse, self).__init__(**kwargs)
        self.token = token

    @schema_property('token')
    def token(self):
        return self._property_token

    @token.setter
    def token(self, value):
        if value is None:
            self._property_token = None
            return

        self.assert_isinstance(value, "token", six.string_types)
        self._property_token = value
|
||||
|
||||
|
||||
class LogoutRequest(Request):
    """
    Removes the authentication cookie from the current session

    """

    _service = "auth"
    _action = "logout"
    _version = "2.2"
    # No payload.
    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class LogoutResponse(Response):
    """
    Response of auth.logout endpoint.

    """
    _service = "auth"
    _action = "logout"
    _version = "2.2"

    # No response payload.
    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class RevokeCredentialsRequest(Request):
    """
    Revokes (and deletes) a set (key, secret) of credentials for
    the authenticated user.

    :param access_key: Credentials key
    :type access_key: str
    """

    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.1"
    # NOTE(review): the schema marks 'key_id' as required, but the only
    # declared property (and the constructor argument) is 'access_key' —
    # looks like a generated-schema inconsistency; confirm against the
    # server-side API schema before relying on client-side validation.
    _schema = {
        'definitions': {},
        'properties': {
            'access_key': {
                'description': 'Credentials key',
                'type': ['string', 'null'],
            },
        },
        'required': ['key_id'],
        'type': 'object',
    }
    def __init__(
            self, access_key=None, **kwargs):
        super(RevokeCredentialsRequest, self).__init__(**kwargs)
        self.access_key = access_key

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        if value is None:
            self._property_access_key = None
            return

        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value
|
||||
|
||||
|
||||
class RevokeCredentialsResponse(Response):
    """
    Response of auth.revoke_credentials endpoint.

    :param revoked: Number of credentials revoked
    :type revoked: int
    """
    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.1"

    _schema = {
        'definitions': {},
        'properties': {
            'revoked': {
                'description': 'Number of credentials revoked',
                'enum': [0, 1],
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, revoked=None, **kwargs):
        super(RevokeCredentialsResponse, self).__init__(**kwargs)
        self.revoked = revoked

    @schema_property('revoked')
    def revoked(self):
        return self._property_revoked

    @revoked.setter
    def revoked(self, value):
        if value is None:
            self._property_revoked = None
            return
        # Whole-number floats (e.g. from JSON) are coerced to int before the
        # integer type check below.
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "revoked", six.integer_types)
        self._property_revoked = value
|
||||
|
||||
|
||||
|
||||
|
||||
# Maps each request class of the auth service to its matching response class.
# Presumably consumed by the session layer to resolve response types for a
# given request (e.g. backend_api.utils.get_response_cls) — confirm usage.
response_mapping = {
    LoginRequest: LoginResponse,
    LogoutRequest: LogoutResponse,
    CreateCredentialsRequest: CreateCredentialsResponse,
    GetCredentialsRequest: GetCredentialsResponse,
    RevokeCredentialsRequest: RevokeCredentialsResponse,
    EditUserRequest: EditUserResponse,
}
|
||||
194
trains_agent/backend_api/services/v2_4/debug.py
Normal file
194
trains_agent/backend_api/services/v2_4/debug.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
debug service
|
||||
|
||||
Debugging utilities
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class ApiexRequest(Request):
    """
    Debug endpoint request for debug.apiex (no payload).
    """

    _service = "debug"
    _action = "apiex"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
|
||||
|
||||
|
||||
class ApiexResponse(Response):
    """
    Response of debug.apiex endpoint.

    """
    _service = "debug"
    _action = "apiex"
    _version = "1.5"

    # No response payload.
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class EchoRequest(Request):
    """
    Return request data

    """

    _service = "debug"
    _action = "echo"
    _version = "1.5"
    # No declared properties; the server echoes whatever payload is sent.
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class EchoResponse(Response):
    """
    Response of debug.echo endpoint.

    """
    _service = "debug"
    _action = "echo"
    _version = "1.5"

    # No declared response properties.
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class ExRequest(Request):
    """
    Debug endpoint request for debug.ex (no payload).
    """

    _service = "debug"
    _action = "ex"
    _version = "1.5"
    _schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
|
||||
|
||||
|
||||
class ExResponse(Response):
    """
    Response of debug.ex endpoint.

    """
    _service = "debug"
    _action = "ex"
    _version = "1.5"

    # No response payload.
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingRequest(Request):
    """
    Return a message. Does not require authorization.

    """

    _service = "debug"
    _action = "ping"
    _version = "1.5"
    # No payload.
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingResponse(Response):
    """
    Response of debug.ping endpoint.

    :param msg: A friendly message
    :type msg: str
    """
    _service = "debug"
    _action = "ping"
    _version = "1.5"

    _schema = {
        'definitions': {},
        'properties': {
            'msg': {
                'description': 'A friendly message',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, msg=None, **kwargs):
        super(PingResponse, self).__init__(**kwargs)
        self.msg = msg

    @schema_property('msg')
    def msg(self):
        return self._property_msg

    @msg.setter
    def msg(self, value):
        if value is None:
            self._property_msg = None
            return

        self.assert_isinstance(value, "msg", six.string_types)
        self._property_msg = value
|
||||
|
||||
|
||||
class PingAuthRequest(Request):
    """
    Return a message. Requires authorization.

    """

    _service = "debug"
    _action = "ping_auth"
    _version = "1.5"
    # No payload.
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingAuthResponse(Response):
    """
    Response of debug.ping_auth endpoint.

    :param msg: A friendly message
    :type msg: str
    """
    _service = "debug"
    _action = "ping_auth"
    _version = "1.5"

    _schema = {
        'definitions': {},
        'properties': {
            'msg': {
                'description': 'A friendly message',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, msg=None, **kwargs):
        super(PingAuthResponse, self).__init__(**kwargs)
        self.msg = msg

    @schema_property('msg')
    def msg(self):
        return self._property_msg

    @msg.setter
    def msg(self, value):
        if value is None:
            self._property_msg = None
            return

        self.assert_isinstance(value, "msg", six.string_types)
        self._property_msg = value
|
||||
|
||||
|
||||
# Maps each request class of the debug service to its matching response class.
# Presumably consumed by the session layer to resolve response types for a
# given request (e.g. backend_api.utils.get_response_cls) — confirm usage.
response_mapping = {
    EchoRequest: EchoResponse,
    PingRequest: PingResponse,
    PingAuthRequest: PingAuthResponse,
    ApiexRequest: ApiexResponse,
    ExRequest: ExResponse,
}
|
||||
2198
trains_agent/backend_api/services/v2_4/queues.py
Normal file
2198
trains_agent/backend_api/services/v2_4/queues.py
Normal file
File diff suppressed because it is too large
Load Diff
6643
trains_agent/backend_api/services/v2_4/tasks.py
Normal file
6643
trains_agent/backend_api/services/v2_4/tasks.py
Normal file
File diff suppressed because it is too large
Load Diff
2368
trains_agent/backend_api/services/v2_4/workers.py
Normal file
2368
trains_agent/backend_api/services/v2_4/workers.py
Normal file
File diff suppressed because it is too large
Load Diff
7
trains_agent/backend_api/session/__init__.py
Normal file
7
trains_agent/backend_api/session/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .session import Session
|
||||
from .datamodel import DataModel, NonStrictDataModel, schema_property, StringEnum
|
||||
from .request import Request, BatchRequest, CompoundRequest
|
||||
from .response import Response
|
||||
from .token_manager import TokenManager
|
||||
from .errors import TimeoutExpiredError, ResultNotReadyError
|
||||
from .callresult import CallResult
|
||||
8
trains_agent/backend_api/session/apimodel.py
Normal file
8
trains_agent/backend_api/session/apimodel.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .datamodel import DataModel
|
||||
|
||||
|
||||
class ApiModel(DataModel):
    """ API-related data model """
    # Endpoint identification; concrete request/response classes override these
    # with the service name, action name and schema version they bind to.
    _service = None
    _action = None
    _version = None
|
||||
131
trains_agent/backend_api/session/callresult.py
Normal file
131
trains_agent/backend_api/session/callresult.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import sys
|
||||
import time
|
||||
|
||||
from ...backend_api.utils import get_response_cls
|
||||
|
||||
from .response import ResponseMeta, Response
|
||||
from .errors import ResultNotReadyError, TimeoutExpiredError
|
||||
|
||||
|
||||
class CallResult(object):
    """
    Result of a single API call: response metadata, the parsed response object
    (when a response class is known) and the raw response payload.
    Also drives polling for calls the server accepted asynchronously (HTTP 202).
    """

    @property
    def meta(self):
        # ResponseMeta for this call (result code/subcode, message, call id, endpoint)
        return self.__meta

    @property
    def response(self):
        # Parsed Response instance, or None if parsing failed / no response class known
        return self.__response

    @property
    def response_data(self):
        # Raw response payload as a dict (None when unavailable)
        return self.__response_data

    @property
    def async_accepted(self):
        # True when the server accepted the call for asynchronous processing (HTTP 202)
        return self.meta.result_code == 202

    @property
    def request_cls(self):
        # Request class that produced this call (used to resolve its response class)
        return self.__request_cls

    def __init__(self, meta, response=None, response_data=None, request_cls=None, session=None):
        """
        :param meta: ResponseMeta instance describing the call (required)
        :param response: parsed Response instance (optional)
        :param response_data: raw payload dict; derived from ``response`` if omitted
        :param request_cls: originating request class (optional)
        :param session: session used for async polling (optional)
        :raises ValueError: if ``response`` is not a Response instance
        :raises TypeError: if ``response_data`` is not a dict
        """
        assert isinstance(meta, ResponseMeta)
        if response and not isinstance(response, Response):
            raise ValueError('response should be an instance of %s' % Response.__name__)
        elif response_data and not isinstance(response_data, dict):
            raise TypeError('data should be an instance of {}'.format(dict.__name__))

        self.__meta = meta
        self.__response = response
        self.__request_cls = request_cls
        self.__session = session
        # Populated by ready() once the asynchronous call completes
        self.__async_result = None

        if response_data is not None:
            self.__response_data = response_data
        elif response is not None:
            # Derive the raw payload from the parsed response
            try:
                self.__response_data = response.to_dict()
            except AttributeError:
                raise TypeError('response should be an instance of {}'.format(Response.__name__))
        else:
            self.__response_data = None

    @classmethod
    def from_result(cls, res, request_cls=None, logger=None, service=None, action=None, session=None):
        """ From requests result """
        response_cls = get_response_cls(request_cls)
        try:
            data = res.json()
        except ValueError:
            # Body is not JSON - synthesize a meta section from the raw HTTP result
            service = service or (request_cls._service if request_cls else 'unknown')
            action = action or (request_cls._action if request_cls else 'unknown')
            return cls(request_cls=request_cls, meta=ResponseMeta.from_raw_data(
                status_code=res.status_code, text=res.text, endpoint='%(service)s.%(action)s' % locals()))
        if 'meta' not in data:
            raise ValueError('Missing meta section in response payload')
        try:
            meta = ResponseMeta(**data['meta'])
            # TODO: validate meta?
            # meta.validate()
        except Exception as ex:
            raise ValueError('Failed parsing meta section in response payload (data=%s, error=%s)' % (data, ex))

        response = None
        response_data = None
        try:
            response_data = data.get('data', {})
            if response_cls:
                response = response_cls(**response_data)
                # TODO: validate response?
                # response.validate()
        except Exception as e:
            # Parsing errors are non-fatal: the raw response_data is still returned
            if logger:
                logger.warning('Failed parsing response: %s' % str(e))
        return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)

    def ok(self):
        # True when the call completed successfully (HTTP 200)
        return self.meta.result_code == 200

    def ready(self):
        """
        Poll an asynchronous call once.
        Returns True when a final result is available (and caches it), otherwise
        falls through returning None (still pending).
        """
        if not self.async_accepted:
            return True
        session = self.__session
        res = session.send_request(service='async', action='result', json=dict(id=self.meta.id), async_enable=False)
        if res.status_code != session._async_status_code:
            # Any non-202 status means the call finished (successfully or not)
            self.__async_result = CallResult.from_result(res=res, request_cls=self.request_cls, logger=session._logger)
            return True

    def result(self):
        """
        Return the final CallResult of an asynchronous call.
        :raises ResultNotReadyError: if the asynchronous call has not completed yet
        """
        if not self.async_accepted:
            return self
        if self.__async_result is None:
            # NOTE(review): message says 'Timeout expired' but this is the
            # not-ready case - looks like a copy/paste of the wait() message; confirm
            raise ResultNotReadyError(self._format_msg('Timeout expired'), call_id=self.meta.id)
        return self.__async_result

    def wait(self, timeout=None, poll_interval=5, verbose=False):
        """
        Block until an asynchronous call completes.
        :param timeout: seconds to wait; falsy means wait forever
        :param poll_interval: seconds between polls (minimum 1)
        :param verbose: log progress through the session logger
        :raises TimeoutExpiredError: if the timeout elapsed before completion
        """
        if not self.async_accepted:
            return self
        session = self.__session
        poll_interval = max(1, poll_interval)
        # Falsy timeout (None/0) effectively waits forever
        remaining = max(0, timeout) if timeout else sys.maxsize
        while remaining > 0:
            if not self.ready():
                # Still pending, log and continue
                if verbose and session._logger:
                    # NOTE(review): 'timeout is False' is never true for the
                    # default timeout=None, so the 'waiting forever' branch
                    # appears unreachable - confirm intended condition
                    progress = ('waiting forever'
                                if timeout is False
                                else '%.1f/%.1f seconds remaining' % (remaining, float(timeout or 0)))
                    session._logger.info('Waiting for asynchronous call %s (%s)'
                                         % (self.request_cls.__name__, progress))
                time.sleep(poll_interval)
                remaining -= poll_interval
                continue
            # We've got something (good or bad, we don't know), create a call result and return
            return self.result()

        # Timeout expired, return the asynchronous call's result (we've got nothing better to report)
        raise TimeoutExpiredError(self._format_msg('Timeout expired'), call_id=self.meta.id)

    def _format_msg(self, msg):
        # Append the request class name and call id for error context
        return msg + ' for call %s (%s)' % (self.request_cls.__name__, self.meta.id)
|
||||
1
trains_agent/backend_api/session/client/__init__.py
Normal file
1
trains_agent/backend_api/session/client/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .client import APIClient, StrictSession, APIError
|
||||
530
trains_agent/backend_api/session/client/client.py
Normal file
530
trains_agent/backend_api/session/client/client.py
Normal file
@@ -0,0 +1,530 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import abc
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from functools import reduce, wraps, WRAPPER_ASSIGNMENTS
|
||||
from importlib import import_module
|
||||
from itertools import chain
|
||||
from operator import itemgetter
|
||||
from types import ModuleType
|
||||
from typing import Dict, Text, Tuple, Type, Any, Sequence
|
||||
|
||||
import six
|
||||
from ... import services as api_services
|
||||
from ....backend_api.session import CallResult
|
||||
from ....backend_api.session import Session, Request as APIRequest
|
||||
from ....backend_api.session.response import ResponseMeta
|
||||
from ....backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR
|
||||
|
||||
SERVICE_TO_ENTITY_CLASS_NAMES = {"storage": "StorageItem"}
|
||||
|
||||
|
||||
def entity_class_name(service):
    # type: (ModuleType) -> Text
    """Return the entity class name for *service*, honoring explicit overrides."""
    name = api_entity_name(service)
    override = SERVICE_TO_ENTITY_CLASS_NAMES.get(name.lower())
    return name if override is None else override
|
||||
|
||||
|
||||
def api_entity_name(service):
    """Singular entity name for *service*: module name with one trailing 's' removed."""
    name = module_name(service)
    # The original used name.rstrip("s"), which strips *all* trailing 's'
    # characters (e.g. "Address" -> "Addre"); remove at most one plural 's'.
    return name[:-1] if name.endswith("s") else name
|
||||
|
||||
|
||||
@six.python_2_unicode_compatible
class APIError(Exception):
    """
    Class for representing an API error.

    self.data - ``dict`` of all returned JSON data
    self.code - HTTP response code
    self.subcode - server response subcode
    self.codes - (self.code, self.subcode) tuple
    self.message - result message sent from server
    """

    def __init__(self, response, extra_info=None):
        """
        Create a new APIError from a server response
        """
        super(APIError, self).__init__()
        self._response = response  # type: CallResult
        # Optional extra context prefixed to the string representation
        self.extra_info = extra_info
        self.data = response.response_data  # type: Dict
        self.meta = response.meta  # type: ResponseMeta
        self.code = response.meta.result_code  # type: int
        self.subcode = response.meta.result_subcode  # type: int
        self.message = response.meta.result_msg  # type: Text
        self.codes = (self.code, self.subcode)  # type: Tuple[int, int]

    def get_traceback(self):
        """
        Return server traceback for error, or None if doesn't exist.
        """
        try:
            return self.meta.error_stack
        except AttributeError:
            return None

    def __str__(self):
        # Build "<ClassName>: [extra_info: ] code X[/subcode][: message]"
        message = "{}: ".format(type(self).__name__)
        if self.extra_info:
            message += "{}: ".format(self.extra_info)
        if not self.meta:
            message += "no meta available"
            return message
        if not self.code:
            message += "no error code available"
            return message
        message += "code {0.code}".format(self)
        if self.subcode:
            message += "/{.subcode}".format(self)
        if self.message:
            message += ": {.message}".format(self)
        return message
|
||||
|
||||
|
||||
class StrictSession(Session):

    """
    Session that raises exceptions on errors, and be configured with explicit ``config_file`` path.
    """

    def __init__(self, config_file=None, initialize_logging=False, *args, **kwargs):
        """
        :param config_file: configuration file to use, else use the default
        :type config_file: Path | Text
        """

        def init():
            # Deferred so the base __init__ runs under the (possibly) overridden env var
            super(StrictSession, self).__init__(
                initialize_logging=initialize_logging, *args, **kwargs
            )

        if not config_file:
            init()
            return

        # Temporarily point the config-file override env var at config_file,
        # restoring the previous value (or removing it) afterwards
        original = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None)
        try:
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = str(config_file)
            init()
        finally:
            if original is None:
                os.environ.pop(LOCAL_CONFIG_FILE_OVERRIDE_VAR, None)
            else:
                os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = original

    def send(self, request, *args, **kwargs):
        """
        Send a request; raise APIError on any non-OK result or missing response.
        """
        result = super(StrictSession, self).send(request, *args, **kwargs)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result
|
||||
|
||||
|
||||
class Response(object):

    """
    Proxy object for API result data.
    Exposes "meta" of the original result.
    """

    def __init__(self, result, dest=None):
        """
        :param result: result of endpoint call
        :type result: CallResult
        :param dest: if all of a response's data is contained in one field, use that field
        :type dest: Text
        """
        self._result = result
        # Unwrap to the parsed response if present, else use the result itself
        response = getattr(result, "response", result)
        if dest:
            response = getattr(response, dest)
        self.response = response

    def __getattr__(self, attr):
        # Delegate unknown attributes to the wrapped response
        return getattr(self.response, attr)

    @property
    def meta(self):
        # Metadata of the original call result
        return self._result.meta

    def __repr__(self):
        return repr(self.response)

    def __dir__(self):
        # Expose the wrapped response's properties in dir(), hiding "response" itself
        fields = [
            name
            for name in dir(self.response)
            if isinstance(getattr(type(self.response), name, None), property)
        ]
        return list(set(chain(super(Response, self).__dir__(), fields)) - {"response"})
|
||||
|
||||
|
||||
@six.python_2_unicode_compatible
class TableResponse(Response):

    """
    Representation of result containing an array of entities
    """

    def __init__(
        self,
        service,  # type: Service
        entity,  # type: Type[entity]
        fields=None,  # type: Sequence[Text]
        *args,
        **kwargs
    ):
        """
        :param service: service of entity
        :param entity: class representing entity
        :param fields: entity attributes requested by client
        """
        super(TableResponse, self).__init__(*args, **kwargs)
        self.service = service
        self.entity = entity
        self.fields = fields or ("id", "name")
        # Wrap each raw item in the entity class (iterates the base response)
        self.response = [entity(service, item) for item in self]

    def __repr__(self, fields=None):
        return self._format_table(fields=fields)

    __str__ = __repr__

    def _format_table(self, fields=None):
        """
        Display <fields> attributes of each element in a table
        :param fields:
        """

        def getter(obj, attr):
            # Resolve dotted attribute paths ("a.b.c"), mapping missing values to ""
            result = reduce(
                lambda x, name: x if x is None else getattr(x, name, None),
                attr.split("."),
                obj,
            )
            return "" if result is None else result

        fields = fields or self.fields
        from trains_agent.helper.base import create_table
        return create_table(
            (tuple(getter(item, attr) for attr in fields) for item in self),
            titles=fields, headers=True,
        )

    def display(self, fields=None):
        # Print the formatted table to stdout
        print(self._format_table(fields=fields))

    def where(self, predicate=None, **kwargs):
        """
        Filter items.
        <predicate> is a callable from a single item to a boolean. Items for which <predicate> is True will be returned.
        Keyword arguments are interpreted as attribute equivalence, meaning:
        >>> datasets.where(name='foo')
        will return only datasets whose name is "foo".

        Giving more than one condition (predicate and keyword arguments) establishes an "and" relation.
        """

        def compare_enum(x, y):
            # Treat an Enum member as equal to its raw value
            return x == y or isinstance(x, Enum) and x.value == y

        # The filtered list is passed positionally as the base Response "result"
        return TableResponse(
            self.service,
            self.entity,
            self.fields,
            [
                item
                for item in self
                if (not predicate or predicate(item))
                and all(
                    compare_enum(getattr(item, key), value)
                    for key, value in kwargs.items()
                )
            ],
        )

    def __getitem__(self, item):
        return self.response[item]

    def __iter__(self):
        return iter(self.response)

    def __len__(self):
        return len(self.response)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class Entity(object):

    """
    Represent a server object.
    Enables calls like:
    >>> entity = client.service.get_by_id(entity_id)
    >>> entity.action(**kwargs)
    instead of:
    >>> client.service.action(id=entity_id, **kwargs)
    """

    @abc.abstractproperty
    def entity_name(self):  # type: () -> Text
        """
        Singular name of entity
        """
        pass

    @abc.abstractproperty
    def get_by_id_request(self):  # type: () -> Type[APIRequest]
        """
        get_by_id request class
        """
        pass

    def __init__(self, service, data):
        self._service = service
        # Unwrap nested payloads (e.g. response.task) when present
        self.data = getattr(data, self.entity_name, data)
        self.__doc__ = self.data.__doc__

    def fetch(self):
        """
        Update the entity data from the server.
        """
        result = self._service.session.send(self.get_by_id_request(self.data.id))
        self.data = getattr(result.response, self.entity_name)

    def _get_default_kwargs(self):
        # Keyword arguments implicitly injected into every proxied service call
        return {self.entity_name: self.data.id}

    def __getattr__(self, attr):
        """
        Inject the entity's ID to the method call.
        All missing properties are assumed to be functions.
        """
        try:
            # Prefer plain data attributes
            return getattr(self.data, attr)
        except AttributeError:
            pass

        func = getattr(self._service, attr)

        @wrap_request_class(func)
        def new_func(*args, **kwargs):
            # Explicit caller kwargs override the injected defaults
            kwargs = dict(self._get_default_kwargs(), **kwargs)
            return func(*args, **kwargs)

        return new_func

    def __dir__(self):
        """
        Add ``self._service``'s methods to ``dir()`` result.
        """
        try:
            dir_ = super(Entity, self).__dir__
        except AttributeError:
            # Python 2: object has no __dir__; fall back to the instance dict
            base = self.__dict__
        else:
            base = dir_()
        return list(set(base).union(dir(self._service), dir(self.data)))

    def __repr__(self):
        """
        Display entity type, ID, and - if available - name.
        """
        parts = (type(self).__name__, ": ", "id={}".format(self.data.id))
        try:
            parts += (", ", 'name="{}"'.format(self.data.name))
        except AttributeError:
            pass
        return "<{}>".format("".join(parts))
|
||||
|
||||
|
||||
def wrap_request_class(cls):
|
||||
return wraps(cls, assigned=WRAPPER_ASSIGNMENTS + ("from_dict",))
|
||||
|
||||
|
||||
def make_action(service, request_cls):
    """
    Build a Service method for *request_cls*.
    Generic actions return a Response proxy; get_by_id/create return a generated
    Entity subclass instance; get_all/get_all_ex return a TableResponse.
    """
    action = request_cls._action
    try:
        get_by_id_request = service.GetByIdRequest
    except AttributeError:
        # Service has no get-by-id request; generated entities cannot fetch()
        get_by_id_request = None

    wrap = wrap_request_class(request_cls)

    if action not in ["get_all", "get_all_ex", "get_by_id", "create"]:

        @wrap
        def new_func(self, *args, **kwargs):
            # Generic action: send the request and wrap the result
            return Response(self.session.send(request_cls(*args, **kwargs)))

        new_func.__name__ = new_func.__qualname__ = action
        return new_func

    # Entity-producing actions: generate a dedicated Entity subclass for this service
    entity_name = api_entity_name(service)
    class_name = entity_class_name(service).capitalize()
    properties = {
        "__module__": __name__,
        "entity_name": entity_name.lower(),
        "get_by_id_request": get_by_id_request,
    }
    entity = type(str(class_name), (Entity,), properties)

    if action == "get_by_id":

        @wrap
        def get(self, *args, **kwargs):
            return entity(
                self, self.session.send(request_cls(*args, **kwargs)).response
            )

    elif action == "create":

        @wrap
        def get(self, *args, **kwargs):
            # Only the new object's id is returned; wrap it in a minimal namespace
            return entity(
                self,
                Namespace(
                    id=self.session.send(request_cls(*args, **kwargs)).response.id
                ),
            )

    elif action in ["get_all", "get_all_ex"]:
        # The single data property of the response holds the entity array
        dest = service.response_mapping[request_cls]._get_data_props().popitem()[0]

        @wrap
        def get(self, *args, **kwargs):
            return TableResponse(
                service=self,
                entity=entity,
                result=self.session.send(request_cls(*args, **kwargs)),
                dest=dest,
                fields=kwargs.pop("only_fields", None),
            )

    else:
        assert False

    get.__name__ = get.__qualname__ = action

    return get
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class Service(object):

    """
    Superclass for action-grouping classes.
    """

    # Concrete service classes (generated by make_service_class) must provide both
    name = abc.abstractproperty()
    __doc__ = abc.abstractproperty()

    def __init__(self, session):
        # Session used to issue all of this service's requests
        self.session = session
|
||||
|
||||
|
||||
def get_requests(service):
    """Ordered mapping of attribute name to request class for *service*, sorted by name."""

    def is_request_class(obj):
        return isinstance(obj, type) and issubclass(obj, APIRequest) and obj._action

    members = sorted(vars(service).items(), key=itemgetter(0))
    return OrderedDict((name, obj) for name, obj in members if is_request_class(obj))
|
||||
|
||||
|
||||
def make_service_class(module):
    # type: (...) -> Type[Service]
    """
    Create a service class from service module.
    """
    properties = OrderedDict()
    properties["__module__"] = __name__
    properties["__doc__"] = module.__doc__
    properties["name"] = module_name(module)
    for action_func in (make_action(module, req) for req in get_requests(module).values()):
        properties[action_func.__name__] = action_func
    # noinspection PyTypeChecker
    return type(str(module_name(module)), (Service,), properties)
|
||||
|
||||
|
||||
def module_name(module):
    """CamelCase name derived from a module object or dotted module-name string."""
    name = getattr(module, "__name__", module)
    last_segment = name.rsplit(".", 1)[-1]
    return "".join(part.capitalize() for part in last_segment.split("_"))
|
||||
|
||||
|
||||
class Version(Entity):
    # Dataset version entity; identified by both version id and dataset id
    entity_name = "version"
    get_by_id_request = None

    def fetch(self):
        """Refresh version data via the service's get_versions call."""
        try:
            published = self.data.status == "published"
        except AttributeError:
            published = False

        self.data = self._service.get_versions(
            dataset=self.dataset, only_published=published, versions=[self.id]
        )[0].data

    def _get_default_kwargs(self):
        # Version calls also require the owning dataset id
        return dict(
            super(Version, self)._get_default_kwargs(), **{"dataset": self.data.dataset}
        )
|
||||
|
||||
|
||||
class APIClient(object):
    """
    Top-level API client: one attribute per service (auth, tasks, workers, ...),
    each a generated Service instance bound to a single session.
    """

    # Service attributes populated dynamically in __init__ (declared for IDEs)
    auth = None  # type: Any
    debug = None  # type: Any
    queues = None  # type: Any
    tasks = None  # type: Any
    workers = None  # type: Any

    def __init__(self, session=None, api_version=None):
        """
        :param session: session to use; defaults to a new StrictSession
        :param api_version: explicit API version (e.g. "2.4") to load service
            modules from; defaults to the package-level default services
        """
        self.session = session or StrictSession()

        def import_(*args, **kwargs):
            # Missing version-specific service modules are silently skipped
            try:
                return import_module(*args, **kwargs)
            except ImportError:
                return None

        if api_version:
            # "2.4" -> "v2_4" module package name
            api_version = "v{}".format(str(api_version).replace(".", "_"))
            services = OrderedDict(
                (name, mod)
                for name, mod in (
                    (
                        name,
                        import_(".".join((api_services.__name__, api_version, name))),
                    )
                    for name in api_services.__all__
                )
                if mod
            )
        else:
            services = OrderedDict(
                (name, getattr(api_services, name)) for name in api_services.__all__
            )
        # Attach one generated Service instance per module, sharing the session
        self.__dict__.update(
            dict(
                {
                    name: make_service_class(module)(self.session)
                    for name, module in services.items()
                },
            )
        )
|
||||
148
trains_agent/backend_api/session/datamodel.py
Normal file
148
trains_agent/backend_api/session/datamodel.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import keyword
|
||||
|
||||
import enum
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
|
||||
import jsonschema
|
||||
from enum import Enum
|
||||
|
||||
import six
|
||||
|
||||
|
||||
def format_date(obj):
    """json.dumps ``default`` helper: stringify datetimes, return None otherwise."""
    return str(obj) if isinstance(obj, datetime) else None
|
||||
|
||||
|
||||
class SchemaProperty(property):
    """Property that additionally records the schema field name it maps to."""

    def __init__(self, name=None, *args, **kwargs):
        property.__init__(self, *args, **kwargs)
        self.name = name

    def setter(self, fset):
        # Rebuild with the same schema name so @prop.setter keeps it intact
        return type(self)(self.name, self.fget, fset, self.fdel, self.__doc__)


def schema_property(name):
    """Decorator factory producing a SchemaProperty bound to schema field *name*."""

    def init(*args, **kwargs):
        return SchemaProperty(name, *args, **kwargs)

    return init
|
||||
|
||||
|
||||
class DataModel(object):
    """ Data Model"""
    # JSON schema used by validate(); set by generated subclasses
    _schema = None
    # Per-class cache of property-name -> schema-field-name mapping
    _data_props_list = None

    @classmethod
    def _get_data_props(cls):
        """Return {python property name: schema field name} for this class (cached)."""
        props = cls._data_props_list
        if props is None:
            props = {}
            # NOTE(review): iterating cls.__mro__ subclass-first means a base
            # class entry overwrites a subclass override of the same property
            # name - confirm this ordering is intended
            for c in cls.__mro__:
                props.update({k: getattr(v, 'name', k) for k, v in vars(c).items()
                              if isinstance(v, property)})
            cls._data_props_list = props
        return props.copy()

    @classmethod
    def _to_base_type(cls, value):
        # Recursively convert models/enums/lists to JSON-serializable values
        if isinstance(value, DataModel):
            return value.to_dict()
        elif isinstance(value, enum.Enum):
            return value.value
        elif isinstance(value, list):
            return [cls._to_base_type(model) for model in value]
        return value

    def to_dict(self, only=None, except_=None):
        """
        Serialize to a dict keyed by schema field names.
        :param only: when given, include only these field names
        :param except_: when given, exclude these field names
        None values are always omitted.
        """
        prop_values = {v: getattr(self, k) for k, v in self._get_data_props().items()}
        return {
            k: self._to_base_type(v)
            for k, v in prop_values.items()
            if v is not None and (not only or k in only) and (not except_ or k not in except_)
        }

    def validate(self, schema=None):
        """Validate the serialized model against *schema* (or the class schema)."""
        # NOTE(review): the 'types' keyword was removed in jsonschema 4.x -
        # this assumes an older jsonschema; confirm the pinned version
        jsonschema.validate(
            self.to_dict(),
            schema or self._schema,
            types=dict(array=(list, tuple), integer=six.integer_types),
        )

    def __repr__(self):
        return '<{}.{}: {}>'.format(
            self.__module__.split('.')[-1],
            type(self).__name__,
            json.dumps(
                self.to_dict(),
                indent=4,
                default=format_date,
            )
        )

    @staticmethod
    def assert_isinstance(value, field_name, expected, is_array=False):
        """Raise TypeError unless value (or each element, when is_array) is of *expected* type."""
        if not is_array:
            if not isinstance(value, expected):
                raise TypeError("Expected %s of type %s, got %s" % (field_name, expected, type(value).__name__))
            return

        if not all(isinstance(x, expected) for x in value):
            raise TypeError(
                "Expected %s of type list[%s], got %s" % (
                    field_name,
                    expected,
                    ", ".join(set(type(x).__name__ for x in value)),
                )
            )

    @staticmethod
    def normalize_key(prop_key):
        # Make a schema field name usable as a python identifier/keyword-safe name
        if keyword.iskeyword(prop_key):
            prop_key += '_'
        return prop_key.replace('.', '__')

    @classmethod
    def from_dict(cls, dct, strict=False):
        """
        Create an instance from a dictionary while ignoring unnecessary keys
        """
        allowed_keys = cls._get_data_props().values()
        invalid_keys = set(dct).difference(allowed_keys)
        if strict and invalid_keys:
            raise ValueError("Invalid keys %s" % tuple(invalid_keys))
        return cls(**{cls.normalize_key(key): value for key, value in dct.items() if key not in invalid_keys})
|
||||
|
||||
|
||||
class UnusedKwargsWarning(UserWarning):
    """Warning category for keyword arguments a data model did not consume."""
    pass


class NonStrictDataModelMixin(object):
    """
    NonStrictDataModelMixin

    :summary: supplies an __init__ method that warns about unused keywords
    """
    def __init__(self, **kwargs):
        # unexpected = [key for key in kwargs if not key.startswith('_')]
        # if unexpected:
        #     message = '{}: unused keyword argument(s) {}' \
        #         .format(type(self).__name__, unexpected)
        #     warnings.warn(message, UnusedKwargsWarning)

        # ignore extra data warnings
        pass


class NonStrictDataModel(DataModel, NonStrictDataModelMixin):
    # Data model that silently ignores unknown fields
    pass


class StringEnum(Enum):
    # Enum whose str() is its raw (string) value

    def __str__(self):
        return self.value
|
||||
10
trains_agent/backend_api/session/defs.py
Normal file
10
trains_agent/backend_api/session/defs.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from ...backend_config import EnvEntry
|
||||
|
||||
|
||||
# Environment-variable overrides for the api.* configuration section.
# Each EnvEntry is looked up under the listed variable name(s).
ENV_HOST = EnvEntry("TRAINS_API_HOST", "TRAINS_API_HOST")
ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "TRAINS_WEB_HOST")
ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "TRAINS_FILES_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
# Verbose API logging (boolean, defaults to off)
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
# SSL certificate verification toggle (boolean, defaults to on)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", type=bool, default=True)
|
||||
17
trains_agent/backend_api/session/errors.py
Normal file
17
trains_agent/backend_api/session/errors.py
Normal file
@@ -0,0 +1,17 @@
|
||||
class SessionError(Exception):
    """Base class for all session-level errors."""
    pass


class AsyncError(SessionError):
    """Session error carrying extra keyword context (e.g. ``call_id``) as attributes."""

    def __init__(self, msg, *args, **kwargs):
        super(AsyncError, self).__init__(msg, *args)
        # Attach any keyword context as instance attributes
        for k, v in kwargs.items():
            setattr(self, k, v)


class TimeoutExpiredError(AsyncError):
    """
    Raised when waiting for an asynchronous call exceeded the timeout.

    Subclasses AsyncError (not plain SessionError) so the ``call_id=...``
    keyword passed by CallResult.wait() is accepted and stored as an
    attribute; with plain Exception.__init__ it would raise TypeError.
    """
    pass


class ResultNotReadyError(AsyncError):
    """
    Raised when an asynchronous call's result is requested before it is ready.

    Subclasses AsyncError for the same ``call_id=...`` keyword support
    used by CallResult.result().
    """
    pass
|
||||
76
trains_agent/backend_api/session/request.py
Normal file
76
trains_agent/backend_api/session/request.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import abc
|
||||
|
||||
import jsonschema
|
||||
import six
|
||||
|
||||
from .apimodel import ApiModel
|
||||
from .datamodel import DataModel
|
||||
|
||||
|
||||
class Request(ApiModel):
    # HTTP method used when issuing this request
    _method = 'get'

    def __init__(self, **kwargs):
        """Reject any keyword arguments not consumed by a subclass' __init__."""
        if kwargs:
            raise ValueError('Unsupported keyword arguments: %s' % ', '.join(kwargs.keys()))
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class BatchRequest(Request):
    """Request batching multiple homogeneous sub-requests into one API call."""

    # Concrete subclasses must set this to the single-request class being batched
    _batched_request_cls = abc.abstractproperty()

    # All jsonschema validation error types, used by validate()
    _schema_errors = (jsonschema.SchemaError, jsonschema.ValidationError, jsonschema.FormatError,
                      jsonschema.RefResolutionError)

    def __init__(self, requests, validate_requests=False, allow_raw_requests=True, **kwargs):
        """
        :param requests: list/tuple of request instances (or raw dicts when allowed)
        :param validate_requests: when True (and raw requests are not allowed),
            validate() schema-checks each sub-request
        :param allow_raw_requests: when True, plain dicts pass through untouched
        """
        super(BatchRequest, self).__init__(**kwargs)
        self._validate_requests = validate_requests
        self._allow_raw_requests = allow_raw_requests
        self._property_requests = None
        self.requests = requests

    @property
    def requests(self):
        return self._property_requests

    @requests.setter
    def requests(self, value):
        assert issubclass(self._batched_request_cls, Request)
        assert isinstance(value, (list, tuple))
        if not self._allow_raw_requests:
            # Coerce raw dicts into request instances and enforce the batched type
            if any(isinstance(x, dict) for x in value):
                value = [self._batched_request_cls(**x) if isinstance(x, dict) else x for x in value]
            assert all(isinstance(x, self._batched_request_cls) for x in value)

        self._property_requests = value

    def validate(self):
        """Schema-validate each sub-request; no-op when raw requests are allowed."""
        if not self._validate_requests or self._allow_raw_requests:
            return
        for i, req in enumerate(self.requests):
            try:
                req.validate()
            except self._schema_errors as e:
                # Consistency fix: use the class-level _schema_errors tuple instead
                # of repeating the same jsonschema exception list inline
                raise Exception('Validation error in batch item #%d: %s' % (i, str(e)))

    def get_json(self):
        # Raw dicts pass through as-is; request objects are serialized
        return [r if isinstance(r, dict) else r.to_dict() for r in self.requests]
|
||||
|
||||
|
||||
class CompoundRequest(Request):
    # Name of the attribute holding the wrapped item data model
    _item_prop_name = 'item'

    def _get_item(self):
        """Return the wrapped DataModel item, raising if missing."""
        item = getattr(self, self._item_prop_name, None)
        if item is None:
            raise ValueError('Item property is empty or missing')
        assert isinstance(item, DataModel)
        return item

    def to_dict(self):
        # Serialize only the wrapped item
        return self._get_item().to_dict()

    def validate(self):
        # NOTE(review): relies on self._schema being supplied by a subclass - confirm
        return self._get_item().validate(self._schema)
|
||||
57
trains_agent/backend_api/session/response.py
Normal file
57
trains_agent/backend_api/session/response.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import requests
|
||||
|
||||
import six
|
||||
import jsonmodels.models
|
||||
import jsonmodels.fields
|
||||
import jsonmodels.errors
|
||||
|
||||
from .apimodel import ApiModel
|
||||
from .datamodel import NonStrictDataModelMixin
|
||||
|
||||
|
||||
class FloatOrStringField(jsonmodels.fields.BaseField):

    """Field accepting either a float or a string value (e.g. API version numbers)."""

    types = (float, six.string_types,)


class Response(ApiModel, NonStrictDataModelMixin):
    # Base class for all parsed API responses; tolerates unknown fields
    pass


class _ResponseEndpoint(jsonmodels.models.Base):
    # Endpoint descriptor reported in the response meta section
    name = jsonmodels.fields.StringField()
    requested_version = FloatOrStringField()
    actual_version = FloatOrStringField()
|
||||
|
||||
|
||||
class ResponseMeta(jsonmodels.models.Base):
    """Meta section of an API response: call id, endpoint and result code/message."""

    @property
    def is_valid(self):
        # False when synthesized from a non-JSON raw HTTP result (from_raw_data)
        return self._is_valid

    @classmethod
    def from_raw_data(cls, status_code, text, endpoint=None):
        """Build a meta object from a raw (non-JSON) HTTP status and body text."""
        return cls(is_valid=False, result_code=status_code, result_subcode=0, result_msg=text,
                   endpoint=_ResponseEndpoint(name=(endpoint or 'unknown')))

    def __init__(self, is_valid=True, **kwargs):
        super(ResponseMeta, self).__init__(**kwargs)
        self._is_valid = is_valid

    id = jsonmodels.fields.StringField(required=True)
    trx = jsonmodels.fields.StringField(required=True)
    endpoint = jsonmodels.fields.EmbeddedField([_ResponseEndpoint], required=True)
    result_code = jsonmodels.fields.IntField(required=True)
    result_subcode = jsonmodels.fields.IntField()
    result_msg = jsonmodels.fields.StringField(required=True)
    error_stack = jsonmodels.fields.StringField()

    def __str__(self):
        # Compact rendering: success, valid error, or raw-data fallback
        if self.result_code == requests.codes.ok:
            return "<%d: %s/v%s>" % (self.result_code, self.endpoint.name, self.endpoint.actual_version)
        elif self._is_valid:
            return "<%d/%d: %s/v%s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name,
                                             self.endpoint.actual_version, self.result_msg)
        return "<%d/%d: %s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, self.result_msg)
|
||||
551
trains_agent/backend_api/session/session.py
Normal file
551
trains_agent/backend_api/session/session.py
Normal file
@@ -0,0 +1,551 @@
|
||||
import json as json_lib
|
||||
import sys
|
||||
import types
|
||||
from socket import gethostname
|
||||
from six.moves.urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
import six
|
||||
from pyhocon import ConfigTree
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from .callresult import CallResult
|
||||
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
|
||||
from .request import Request, BatchRequest
|
||||
from .token_manager import TokenManager
|
||||
from ..config import load
|
||||
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoginError(Exception):
    """Raised when obtaining an authentication token from the server fails."""
|
||||
|
||||
|
||||
class MaxRequestSizeError(Exception):
    """Raised when a single batch line exceeds the maximum allowed request size."""
|
||||
|
||||
|
||||
class Session(TokenManager):
    """ TRAINS API Session class. """

    # HTTP header names attached to every request
    _AUTHORIZATION_HEADER = "Authorization"
    _WORKER_HEADER = "X-Trains-Worker"
    _ASYNC_HEADER = "X-Trains-Async"
    _CLIENT_HEADER = "X-Trains-Agent"

    # status code the server returns for an accepted asynchronous request
    _async_status_code = 202
    # request counter; incremented per instance via _send_request
    _session_requests = 0
    # (connect, read) timeouts in seconds: short for the very first request,
    # longer afterwards, longest for large write payloads
    _session_initial_timeout = (3.0, 10.)
    _session_timeout = (10.0, 300.)
    # payloads larger than this (bytes) use the write timeout below
    _write_session_data_size = 15000
    _write_session_timeout = (300.0, 300.)

    # may be overwritten at runtime from the server token (see __init__)
    api_version = '2.1'
    # public demo-server defaults, used when nothing is configured
    default_host = "https://demoapi.trainsai.io"
    default_web = "https://demoapp.trainsai.io"
    default_files = "https://demofiles.trainsai.io"
    # NOTE(review): these are the intentionally-public demo credentials
    # (they also appear in docs/trains.conf) -- not a secret leak
    default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
    default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"

    # HTTP status codes the underlying urllib3 Retry policy retries on
    # TODO: add requests.codes.gateway_timeout once we support async commits
    _retry_codes = [
        requests.codes.bad_gateway,
        requests.codes.service_unavailable,
        requests.codes.bandwidth_limit_exceeded,
        requests.codes.too_many_requests,
    ]
|
||||
|
||||
@property
def access_key(self):
    # API credentials access key (read-only)
    return self.__access_key

@property
def secret_key(self):
    # API credentials secret key (read-only)
    return self.__secret_key

@property
def host(self):
    # API server base URL, without trailing slash (read-only)
    return self.__host

@property
def worker(self):
    # worker name reported to the server via the X-Trains-Worker header (read-only)
    return self.__worker
|
||||
|
||||
def __init__(
    self,
    worker=None,
    api_key=None,
    secret_key=None,
    host=None,
    logger=None,
    verbose=None,
    initialize_logging=True,
    client=None,
    config=None,
    **kwargs
):
    """Create an authenticated API session.

    :param worker: worker name; defaults to the machine hostname
    :param api_key: access key; falls back to env var, then config, then demo default
    :param secret_key: secret key; same fallback chain as api_key
    :param host: API server URL; falls back to config/env (see get_api_server_host)
    :param logger: optional logger for verbose/warning output (may stay None)
    :param verbose: verbose logging; defaults to the ENV_VERBOSE setting
    :param initialize_logging: when loading a fresh config, also apply its logging section
    :param client: value of the client identification header; defaults to "api-<version>"
    :param config: pre-loaded configuration object; when None a new one is loaded
    :param kwargs: forwarded to TokenManager.__init__
    :raises ValueError: when credentials, host or max request size cannot be resolved
    """

    if config is not None:
        self.config = config
    else:
        self.config = load()
        if initialize_logging:
            self.config.initialize_logging()

    token_expiration_threshold_sec = self.config.get(
        "auth.token_expiration_threshold_sec", 60
    )

    super(Session, self).__init__(
        token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
    )

    self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
    self._logger = logger

    # credential resolution order: explicit argument > environment > config > demo default
    self.__access_key = api_key or ENV_ACCESS_KEY.get(
        default=(self.config.get("api.credentials.access_key", None) or self.default_key)
    )
    if not self.access_key:
        raise ValueError(
            "Missing access_key. Please set in configuration file or pass in session init."
        )

    self.__secret_key = secret_key or ENV_SECRET_KEY.get(
        default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
    )
    if not self.secret_key:
        raise ValueError(
            "Missing secret_key. Please set in configuration file or pass in session init."
        )

    host = host or self.get_api_server_host(config=self.config)
    if not host:
        raise ValueError("host is required in init or config")

    self.__host = host.strip("/")
    http_retries_config = self.config.get(
        "api.http.retries", ConfigTree()
    ).as_plain_ordered_dict()
    # always retry on the session's designated retryable status codes
    http_retries_config["status_forcelist"] = self._retry_codes
    self.__http_session = get_http_session_with_retry(**http_retries_config)

    self.__worker = worker or gethostname()

    self.__max_req_size = self.config.get("api.http.max_req_size", None)
    if not self.__max_req_size:
        raise ValueError("missing max request size")

    self.client = client or "api-{}".format(__version__)

    # obtain an initial token right away so authentication problems surface here
    self.refresh_token()

    # update api version from server response
    try:
        token_dict = jwt.decode(self.token, verify=False)
        api_version = token_dict.get('api_version')
        if not api_version:
            # token carried no version; assume 2.2 on the hosted ("prod") service
            api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version

        Session.api_version = str(api_version)
    except (jwt.DecodeError, ValueError):
        pass

    # now setup the session reporting, so one consecutive retries will show warning
    # we do that here, so if we have problems authenticating, we see them immediately
    # notice: this is across the board warning omission
    urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3)
|
||||
|
||||
def _send_request(
    self,
    service,
    action,
    version=None,
    method="get",
    headers=None,
    auth=None,
    data=None,
    json=None,
    refresh_token_if_unauthorized=True,
):
    """ Internal implementation for making a raw API request.
    - Constructs the api endpoint name
    - Injects the worker id into the headers
    - Allows custom authorization using a requests auth object
    - Intercepts `Unauthorized` responses and automatically attempts to refresh the session token once in this
      case (only once). This is done since permissions are embedded in the token, and addresses a case where
      server-side permissions have changed but are not reflected in the current token. Refreshing the token will
      generate a token with the updated permissions.

    :param service: service name, forms the "<service>.<action>" endpoint
    :param action: action name
    :param version: optional API version inserted into the URL path
    :param method: HTTP method (default 'get')
    :param headers: extra request headers (copied, not mutated)
    :param auth: optional requests auth object (e.g. HTTPBasicAuth for login)
    :param data: raw request body
    :param json: jsonable request body
    :param refresh_token_if_unauthorized: retry once with a fresh token on 401
    :return: requests Response instance
    """
    host = self.host
    headers = headers.copy() if headers else {}
    headers[self._WORKER_HEADER] = self.worker
    headers[self._CLIENT_HEADER] = self.client

    token_refreshed_on_error = False
    url = (
        "{host}/v{version}/{service}.{action}"
        if version
        else "{host}/{service}.{action}"
    ).format(**locals())
    while True:
        # large payloads get the long write timeout; the very first request a short one
        if data and len(data) > self._write_session_data_size:
            timeout = self._write_session_timeout
        elif self._session_requests < 1:
            timeout = self._session_initial_timeout
        else:
            timeout = self._session_timeout
        res = self.__http_session.request(
            method, url, headers=headers, auth=auth, data=data, json=json, timeout=timeout)

        if (
            refresh_token_if_unauthorized
            and res.status_code == requests.codes.unauthorized
            and not token_refreshed_on_error
        ):
            # it seems we're unauthorized, so we'll try to refresh our token once in case permissions changed since
            # the last time we got the token, and try again
            self.refresh_token()
            token_refreshed_on_error = True
            # try again
            continue
        if (
            res.status_code == requests.codes.service_unavailable
            and self.config.get("api.http.wait_on_maintenance_forever", True)
        ):
            # BUGFIX: self._logger may be None (logger is an optional __init__
            # argument); previously this crashed with AttributeError instead of
            # retrying through server maintenance
            if self._logger:
                self._logger.warning(
                    "Service unavailable: {} is undergoing maintenance, retrying...".format(
                        host
                    )
                )
            continue
        break
    self._session_requests += 1
    return res
|
||||
|
||||
def add_auth_headers(self, headers):
    """Insert the bearer-token Authorization header and return the same dict."""
    auth_value = "Bearer {}".format(self.token)
    headers[self._AUTHORIZATION_HEADER] = auth_value
    return headers
|
||||
|
||||
def send_request(
    self,
    service,
    action,
    version=None,
    method="get",
    headers=None,
    data=None,
    json=None,
    async_enable=False,
):
    """
    Send a raw API request.
    :param service: service name
    :param action: action name
    :param version: version number (default is the preconfigured api version)
    :param method: method type (default is 'get')
    :param headers: request headers (authorization and content type headers will be automatically added)
    :param json: json to send in the request body (jsonable object or builtin types construct. if used,
        content type will be application/json)
    :param data: Dictionary, bytes, or file-like object to send in the request body
    :param async_enable: whether request is asynchronous
    :return: requests Response instance
    """
    # work on a copy so the caller's dict is never mutated
    request_headers = headers.copy() if headers else {}
    request_headers = self.add_auth_headers(request_headers)
    if async_enable:
        request_headers[self._ASYNC_HEADER] = "1"
    return self._send_request(
        service=service,
        action=action,
        version=version,
        method=method,
        headers=request_headers,
        data=data,
        json=json,
    )
|
||||
|
||||
def send_request_batch(
    self,
    service,
    action,
    version=None,
    headers=None,
    data=None,
    json=None,
    method="get",
):
    """
    Send a raw batch API request. Batch requests always use application/json-lines content type.
    :param service: service name
    :param action: action name
    :param version: version number (default is the preconfigured api version)
    :param headers: request headers (authorization and content type headers will be automatically added)
    :param json: iterable of json items (batched items, jsonable objects or builtin types constructs). These will
        be sent as a multi-line payload in the request body.
    :param data: iterable of bytes objects (batched items). These will be sent as a multi-line payload in the
        request body.
    :param method: HTTP method
    :return: list of requests Response instances, one per dispatched chunk
        (dispatch stops at the first non-OK response)
    :raises ValueError: when neither data nor json is provided, or either has the wrong type
    :raises MaxRequestSizeError: when a single line exceeds the maximum request size
    """
    if not all(
        isinstance(x, (list, tuple, type(None), types.GeneratorType))
        for x in (data, json)
    ):
        raise ValueError("Expecting list, tuple or generator in 'data' or 'json'")

    if not data and not json:
        raise ValueError(
            "Missing data (data or json), batch requests are meaningless without it."
        )

    headers = headers.copy() if headers else {}
    headers["Content-Type"] = "application/json-lines"

    if data:
        req_data = "\n".join(data)
    else:
        req_data = "\n".join(json_lib.dumps(x) for x in json)

    cur = 0
    results = []
    while True:
        size = self.__max_req_size
        # renamed from `slice` -- do not shadow the builtin
        chunk = req_data[cur: cur + size]
        if not chunk:
            break
        if len(chunk) < size:
            # this is the remainder, no need to search for newline
            pass
        elif chunk[-1] != "\n":
            # search for the last newline in order to send a coherent request
            size = chunk.rfind("\n") + 1
            # readjust the chunk
            chunk = req_data[cur: cur + size]
            if not chunk:
                # a single line is longer than the allowed request size
                raise MaxRequestSizeError('Error: {}.{} request exceeds limit {} > {} bytes'.format(
                    service, action, len(req_data), self.__max_req_size))
        res = self.send_request(
            method=method,
            service=service,
            action=action,
            data=chunk,
            headers=headers,
            version=version,
        )
        results.append(res)
        if res.status_code != requests.codes.ok:
            break
        cur += size
    return results
|
||||
|
||||
def validate_request(self, req_obj):
    """ Validate an API request against the current version and the request's schema """

    # anything exposing a validate() callable is accepted as a request object;
    # validate() checks required fields and field-specific version restrictions
    validate = getattr(req_obj, "validate", None)
    if validate is None:
        raise TypeError(
            '"req_obj" parameter must be an backend_api.session.Request object'
        )

    validate()
|
||||
|
||||
def send_async(self, req_obj):
    """
    Asynchronously sends an API request using a request object.
    :param req_obj: The request object
    :type req_obj: Request
    :return: CallResult object containing the raw response, response metadata and parsed response object.
    """
    # thin wrapper: identical to send() with the async flag forced on
    return self.send(async_enable=True, req_obj=req_obj)
|
||||
|
||||
def send(self, req_obj, async_enable=False, headers=None):
    """
    Sends an API request using a request object.
    :param req_obj: The request object
    :type req_obj: Request
    :param async_enable: Request this method be executed in an asynchronous manner
    :param headers: Additional headers to send with request
    :return: CallResult object containing the raw response, response metadata and parsed response object.
    :raises NotImplementedError: when async_enable is used with a BatchRequest
    """
    self.validate_request(req_obj)

    if isinstance(req_obj, BatchRequest):
        # TODO: support async for batch requests as well
        if async_enable:
            raise NotImplementedError(
                "Async behavior is currently not implemented for batch requests"
            )

        json_data = req_obj.get_json()
        res = self.send_request_batch(
            service=req_obj._service,
            action=req_obj._action,
            version=req_obj._version,
            json=json_data,
            method=req_obj._method,
            headers=headers,
        )
        # TODO: handle multiple results in this case
        # surface the first failing chunk response; when all succeeded use the first
        try:
            res = next(r for r in res if r.status_code != 200)
        except StopIteration:
            # all are 200
            res = res[0]
    else:
        res = self.send_request(
            service=req_obj._service,
            action=req_obj._action,
            version=req_obj._version,
            json=req_obj.to_dict(),
            method=req_obj._method,
            async_enable=async_enable,
            headers=headers,
        )

    call_result = CallResult.from_result(
        res=res,
        request_cls=req_obj.__class__,
        logger=self._logger,
        service=req_obj._service,
        action=req_obj._action,
        session=self,
    )

    return call_result
|
||||
|
||||
@classmethod
def get_api_server_host(cls, config=None):
    """Resolve the API server URL.

    Resolution order: host env var > api.api_server > api.host > demo default.
    """
    if not config:
        # fall back to the globally loaded configuration object
        from ...config import config_obj
        config = config_obj
    fallback = (
        config.get("api.api_server", None)
        or config.get("api.host", None)
        or cls.default_host
    )
    return ENV_HOST.get(default=fallback)
|
||||
|
||||
@classmethod
def get_app_server_host(cls, config=None):
    """Resolve the web application server URL.

    Resolution order: web-host env var / api.web_server config, then heuristics
    derived from the API server host (demoapi->demoapp, api.->app., port 8008->8080).
    :raises ValueError: when no heuristic matches the configured API host
    """
    if not config:
        from ...config import config_obj
        config = config_obj

    # get from config/environment
    web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
    if web_host:
        return web_host

    # return default
    host = cls.get_api_server_host(config)
    if host == cls.default_host:
        return cls.default_web

    # compose ourselves
    if '://demoapi.' in host:
        return host.replace('://demoapi.', '://demoapp.', 1)
    if '://api.' in host:
        return host.replace('://api.', '://app.', 1)

    # standard self-hosted layout: API on 8008, web app on 8080
    parsed = urlparse(host)
    if parsed.port == 8008:
        return host.replace(':8008', ':8080', 1)

    raise ValueError('Could not detect TRAINS web application server')
|
||||
|
||||
@classmethod
def get_files_server_host(cls, config=None):
    """Resolve the files server URL.

    Resolution order: files-host env var / api.files_server config, then heuristics
    derived from the app server host (explicit port -> 8081,
    demoapp. -> demofiles., app. -> files., otherwise append :8081).
    """
    if not config:
        from ...config import config_obj
        config = config_obj
    # get from config/environment
    files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
    if files_host:
        return files_host

    # return default
    host = cls.get_api_server_host(config)
    if host == cls.default_host:
        return cls.default_files

    # compose ourselves
    app_host = cls.get_app_server_host(config)
    parsed = urlparse(app_host)
    if parsed.port:
        parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
    elif parsed.netloc.startswith('demoapp.'):
        parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
    elif parsed.netloc.startswith('app.'):
        parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
    else:
        # no port and no recognized prefix: assume same host, default files port
        parsed = parsed._replace(netloc=parsed.netloc + ':8081')

    return urlunparse(parsed)
|
||||
|
||||
@classmethod
|
||||
def check_min_api_version(cls, min_api_version):
    """
    Return True if Session.api_version is greater or equal >= to min_api_version
    """
    def as_triple(ver):
        # normalize "X[.Y[.Z]]" into a comparable tuple, padding with zeros
        parts = [int(p) for p in str(ver).split(".")]
        while len(parts) < 3:
            parts.append(0)
        return tuple(parts)

    return as_triple(cls.api_version) >= as_triple(min_api_version)
|
||||
|
||||
def _do_refresh_token(self, old_token, exp=None):
    """ TokenManager abstract method implementation.
    Here we ignore the old token and simply obtain a new token.

    :param old_token: previous token (unused; login always issues a new one)
    :param exp: requested token expiration in seconds, when provided
    :return: the new token string
    :raises LoginError: on server-rejected login or unexpected errors
    :raises ValueError: when the api_server appears misconfigured
    """
    verbose = self._verbose and self._logger
    if verbose:
        self._logger.info(
            "Refreshing token from {} (access_key={}, exp={})".format(
                self.host, self.access_key, exp
            )
        )

    # login authenticates with basic-auth credentials, not the (expired) bearer token
    auth = HTTPBasicAuth(self.access_key, self.secret_key)
    res = None
    try:
        data = {"expiration_sec": exp} if exp else {}
        # refresh_token_if_unauthorized=False: avoid recursing back into refresh
        res = self._send_request(
            service="auth",
            action="login",
            auth=auth,
            json=data,
            refresh_token_if_unauthorized=False,
        )
        try:
            resp = res.json()
        except ValueError:
            # non-JSON body; fall back to HTTP status/reason handling below
            resp = {}
        if res.status_code != 200:
            msg = resp.get("meta", {}).get("result_msg", res.reason)
            raise LoginError(
                "Failed getting token (error {} from {}): {}".format(
                    res.status_code, self.host, msg
                )
            )
        if verbose:
            self._logger.info("Received new token")
        return resp["data"]["token"]
    except LoginError:
        six.reraise(*sys.exc_info())
    except KeyError as ex:
        # check if this is a misconfigured api server (getting 200 without the data section)
        if res and res.status_code == 200:
            raise ValueError('It seems *api_server* is misconfigured. '
                             'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
        else:
            raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
                             "exception: {}".format(res, ex))
    except Exception as ex:
        # wrap anything unexpected so callers only need to handle LoginError
        raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
|
||||
|
||||
def __str__(self):
    """Session summary with the secret key masked after its first 5 characters."""
    masked_secret = self.secret_key[:5] + "*" * (len(self.secret_key) - 5)
    return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
        self=self, secret_key=masked_secret
    )
|
||||
95
trains_agent/backend_api/session/token_manager.py
Normal file
95
trains_agent/backend_api/session/token_manager.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import sys
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from time import time
|
||||
|
||||
import jwt
|
||||
import six
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
class TokenManager(object):
    """Base class managing a JWT-style access token.

    Caches the token, tracks its expiration (decoded from the token itself)
    and transparently refreshes it -- via the abstract _do_refresh_token --
    whenever it is within the expiration threshold.
    """

    @property
    def token_expiration_threshold_sec(self):
        # safety margin (seconds): a token this close to expiring is refreshed
        return self.__token_expiration_threshold_sec

    @token_expiration_threshold_sec.setter
    def token_expiration_threshold_sec(self, value):
        self.__token_expiration_threshold_sec = value

    @property
    def req_token_expiration_sec(self):
        """ Token expiration sec requested when refreshing token """
        return self.__req_token_expiration_sec

    @req_token_expiration_sec.setter
    def req_token_expiration_sec(self, value):
        assert isinstance(value, (type(None), int))
        self.__req_token_expiration_sec = value

    @property
    def token_expiration_sec(self):
        # absolute expiration time (epoch seconds) of the cached token
        return self.__token_expiration_sec

    @property
    def token(self):
        # valid token; refreshed on access when needed
        return self._get_token()

    @property
    def raw_token(self):
        # cached token without triggering a refresh (may be expired or None)
        return self.__token

    def __init__(
        self,
        token=None,
        req_token_expiration_sec=None,
        token_history=None,
        token_expiration_threshold_sec=60,
        **kwargs
    ):
        """
        :param token: initial token, if already available
        :param req_token_expiration_sec: expiration (seconds) requested on refresh
        :param token_history: must be a dict when provided; otherwise unused here
        :param token_expiration_threshold_sec: refresh margin before expiration
        :param kwargs: ignored (allows cooperative multiple inheritance)
        """
        super(TokenManager, self).__init__()
        assert isinstance(token_history, (type(None), dict))
        self.token_expiration_threshold_sec = token_expiration_threshold_sec
        self.req_token_expiration_sec = req_token_expiration_sec
        self._set_token(token)

    def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
        # remaining useful lifetime (seconds) after subtracting the safety margin;
        # 0 when the token is missing, malformed or about to expire
        if token:
            try:
                exp = exp or self._get_token_exp(token)
                if at_least_sec:
                    at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec)
                else:
                    at_least_sec = self.token_expiration_threshold_sec
                return max(0, (exp - time() - at_least_sec))
            except Exception:
                # any decode/arithmetic failure is treated as "expired"
                pass
        return 0

    @classmethod
    def _get_token_exp(cls, token):
        """ Get token expiration time. If not present, assume forever """
        return jwt.decode(token, verify=False).get('exp', sys.maxsize)

    def _set_token(self, token):
        # cache the token together with its decoded expiration timestamp
        if token:
            self.__token = token
            self.__token_expiration_sec = self._get_token_exp(token)
        else:
            self.__token = None
            self.__token_expiration_sec = 0

    def get_token_valid_period_sec(self):
        # remaining useful lifetime (seconds) of the cached token
        return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec)

    def _get_token(self):
        # refresh on demand, then return the cached token
        if self.get_token_valid_period_sec() <= 0:
            self.refresh_token()
        return self.__token

    @abstractmethod
    def _do_refresh_token(self, old_token, exp=None):
        # subclasses obtain and return a fresh token here
        pass

    def refresh_token(self):
        # obtain a new token (requesting the configured expiration) and cache it
        self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec))
|
||||
132
trains_agent/backend_api/utils.py
Normal file
132
trains_agent/backend_api/utils.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import logging
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util import Retry
|
||||
from urllib3 import PoolManager
|
||||
import six
|
||||
|
||||
from .session.defs import ENV_HOST_VERIFY_CERT
|
||||
|
||||
if six.PY3:
|
||||
from functools import lru_cache
|
||||
elif six.PY2:
|
||||
# python 2 support
|
||||
from backports.functools_lru_cache import lru_cache
|
||||
|
||||
|
||||
__disable_certificate_verification_warning = 0
|
||||
|
||||
|
||||
class _RetryFilter(logging.Filter):
|
||||
last_instance = None
|
||||
|
||||
def __init__(self, total, warning_after=5):
|
||||
super(_RetryFilter, self).__init__()
|
||||
self.total = total
|
||||
self.display_warning_after = warning_after
|
||||
_RetryFilter.last_instance = self
|
||||
|
||||
def filter(self, record):
|
||||
if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry):
|
||||
left = (record.args[0].total, record.args[0].connect, record.args[0].read,
|
||||
record.args[0].redirect, record.args[0].status)
|
||||
left = [l for l in left if isinstance(l, int)]
|
||||
if left:
|
||||
retry_left = max(left) - min(left)
|
||||
return retry_left >= self.display_warning_after
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def urllib_log_warning_setup(total_retries=10, display_warning_after=5):
    """Install a _RetryFilter on urllib3's connection-pool loggers so retry
    warnings are only emitted once `display_warning_after` retries were consumed.

    :param total_retries: total retry budget configured on the HTTP session
    :param display_warning_after: number of consumed retries before warnings show
    """
    # renamed the loop variable from the ambiguous `l` (PEP 8 E741)
    for logger_name in ('urllib3.connectionpool', 'requests.packages.urllib3.connectionpool'):
        urllib3_log = logging.getLogger(logger_name)
        if urllib3_log:
            # drop any previously installed filter (no-op when none was installed)
            urllib3_log.removeFilter(_RetryFilter.last_instance)
            urllib3_log.addFilter(_RetryFilter(total_retries, display_warning_after))
|
||||
|
||||
|
||||
class TLSv1HTTPAdapter(HTTPAdapter):
    """HTTPAdapter that pins all pooled connections to TLS v1.2."""

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        # same as the parent implementation except for forcing ssl_version;
        # NOTE(review): pool_kwargs is accepted but not forwarded -- confirm intended
        self.poolmanager = PoolManager(num_pools=connections,
                                       maxsize=maxsize,
                                       block=block,
                                       ssl_version=ssl.PROTOCOL_TLSv1_2)
|
||||
|
||||
|
||||
def get_http_session_with_retry(
        total=0,
        connect=None,
        read=None,
        redirect=None,
        status=None,
        status_forcelist=None,
        backoff_factor=0,
        backoff_max=None,
        pool_connections=None,
        pool_maxsize=None,
        config=None):
    """Build a requests.Session with TLS v1.2, connection pooling and a Retry policy.

    :param total/connect/read/redirect/status: urllib3 Retry counters (int or None)
    :param status_forcelist: HTTP status codes that always trigger a retry
    :param backoff_factor: urllib3 exponential backoff factor
    :param backoff_max: cap on backoff time; mutates Retry.BACKOFF_MAX when set
    :param pool_connections/pool_maxsize: pool sizes; default 512 or config values
    :param config: optional config used for pool sizes and certificate verification
    :return: configured requests.Session
    :raises ValueError: on non-int retry counters or status codes
    """
    if not config:
        config = {}
    global __disable_certificate_verification_warning
    if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
        raise ValueError('Bad configuration. All retry count values must be null or int')

    if status_forcelist and not all(isinstance(x, int) for x in status_forcelist):
        raise ValueError('Bad configuration. Retry status_forcelist must be null or list of ints')

    pool_maxsize = (
        pool_maxsize
        if pool_maxsize is not None
        else config.get('api.http.pool_maxsize', 512)
    )

    pool_connections = (
        pool_connections
        if pool_connections is not None
        else config.get('api.http.pool_connections', 512)
    )

    session = requests.Session()

    if backoff_max is not None:
        # NOTE(review): this mutates a Retry *class* attribute, affecting every
        # Retry instance in the process -- confirm this global effect is intended
        Retry.BACKOFF_MAX = backoff_max

    retry = Retry(
        total=total, connect=connect, read=read, redirect=redirect, status=status,
        status_forcelist=status_forcelist, backoff_factor=backoff_factor)

    adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    # update verify host certificate
    session.verify = ENV_HOST_VERIFY_CERT.get(default=config.get('api.verify_certificate', True))
    if not session.verify and __disable_certificate_verification_warning < 2:
        # show warning (the module-level counter limits it to a single display)
        __disable_certificate_verification_warning += 1
        logging.getLogger('TRAINS').warning(
            msg='InsecureRequestWarning: Certificate verification is disabled! Adding '
                'certificate verification is strongly advised. See: '
                'https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings')
        # make sure we only do not see the warning
        import urllib3
        # noinspection PyBroadException
        try:
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        except Exception:
            pass
    return session
|
||||
|
||||
|
||||
def get_response_cls(request_cls):
    """ Extract a request's response class using the mapping found in the module defining the request's service """
    # walk the MRO looking for a module that declares one of the two mapping styles
    for ancestor in request_cls.mro():
        module = sys.modules[ancestor.__module__]
        mapping = getattr(module, 'action_mapping', None)
        if mapping is not None:
            return mapping[(request_cls._action, request_cls._version)][1]
        mapping = getattr(module, 'response_mapping', None)
        if mapping is not None:
            return mapping[ancestor]
    raise TypeError('no response class!')
|
||||
4
trains_agent/backend_config/__init__.py
Normal file
4
trains_agent/backend_config/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .defs import Environment
|
||||
from .config import Config, ConfigEntry
|
||||
from .errors import ConfigurationError
|
||||
from .environment import EnvEntry
|
||||
340
trains_agent/backend_config/config.py
Normal file
340
trains_agent/backend_config/config.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from fnmatch import fnmatch
|
||||
from os.path import expanduser
|
||||
from typing import Any
|
||||
|
||||
import pyhocon
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
from pyhocon import ConfigTree
|
||||
from pyparsing import (
|
||||
ParseFatalException,
|
||||
ParseException,
|
||||
RecursiveGrammarException,
|
||||
ParseSyntaxException,
|
||||
)
|
||||
|
||||
from .defs import (
|
||||
Environment,
|
||||
DEFAULT_CONFIG_FOLDER,
|
||||
LOCAL_CONFIG_PATHS,
|
||||
ENV_CONFIG_PATHS,
|
||||
LOCAL_CONFIG_FILES,
|
||||
LOCAL_CONFIG_FILE_OVERRIDE_VAR,
|
||||
ENV_CONFIG_PATH_OVERRIDE_VAR,
|
||||
)
|
||||
from .defs import is_config_file
|
||||
from .entry import Entry, NotSet
|
||||
from .errors import ConfigurationError
|
||||
from .log import initialize as initialize_log, logger
|
||||
from .utils import get_options
|
||||
|
||||
try:
|
||||
from typing import Text
|
||||
except ImportError:
|
||||
# windows conda-less hack
|
||||
Text = Any
|
||||
|
||||
|
||||
log = logger(__file__)
|
||||
|
||||
|
||||
class ConfigEntry(Entry):
    """An Entry whose values are looked up in a Config instance."""

    logger = None

    def __init__(self, config, *keys, **kwargs):
        # type: (Config, Text, Any) -> None
        super(ConfigEntry, self).__init__(*keys, **kwargs)
        self.config = config

    def _get(self, key):
        # type: (Text) -> Any
        # return NotSet (not None) so a stored None value remains distinguishable
        return self.config.get(key, NotSet)

    def error(self, message):
        # type: (Text) -> None
        log.error(message.capitalize())
|
||||
|
||||
|
||||
class Config(object):
|
||||
"""
|
||||
Represents a server configuration.
|
||||
If watch=True, will watch configuration folders for changes and reload itself.
|
||||
NOTE: will not watch folders that were created after initialization.
|
||||
"""
|
||||
|
||||
# used in place of None in Config.get as default value because None is a valid value
|
||||
_MISSING = object()
|
||||
|
||||
def __init__(
    self,
    config_folder=None,
    env=None,
    verbose=True,
    relative_to=None,
    app=None,
    is_server=False,
    **_
):
    """
    :param config_folder: folder name holding configuration (default DEFAULT_CONFIG_FOLDER)
    :param env: environment name; falls back to the TRAINS_ENV environment variable
    :param verbose: print configuration progress messages
    :param relative_to: module path(s) to anchor configuration roots next to;
        triggers an immediate load when provided
    :param app: application name, attached as log-record `extra` by initialize_logging
    :param is_server: server deployments also read the environment config paths
    :raises ValueError: when the environment is missing or not a known option
    """
    self._app = app
    self._verbose = verbose
    self._folder_name = config_folder or DEFAULT_CONFIG_FOLDER
    self._roots = []
    self._config = ConfigTree()
    self._env = env or os.environ.get("TRAINS_ENV", Environment.default)
    self.config_paths = set()
    self.is_server = is_server

    if self._verbose:
        print("Config env:%s" % str(self._env))

    if not self._env:
        raise ValueError(
            "Missing environment in either init of environment variable"
        )
    if self._env not in get_options(Environment):
        raise ValueError("Invalid environment %s" % env)
    if relative_to is not None:
        self.load_relative_to(relative_to)
|
||||
|
||||
@property
def root(self):
    # first configuration root, or None when no roots were set
    return self.roots[0] if self.roots else None

@property
def roots(self):
    # list of configuration root folders (populated by load_relative_to)
    return self._roots

@roots.setter
def roots(self, value):
    self._roots = value

@property
def env(self):
    # active environment name (read-only)
    return self._env
|
||||
|
||||
def logger(self, path=None):
    # convenience accessor delegating to the module-level logger factory
    return logger(path)
|
||||
|
||||
def load_relative_to(self, *module_paths):
    """Anchor the configuration roots next to the given module paths and reload.

    For each module path, the root is a sibling folder named after the
    configured folder name (e.g. /pkg/module.py -> /pkg/<folder_name>).
    """
    folder_name = self._folder_name
    anchored_roots = []
    for module_path in module_paths:
        anchored_roots.append(
            Path(os.path.abspath(str(module_path))).with_name(folder_name)
        )
    self.roots = anchored_roots
    self.reload()
|
||||
|
||||
def _reload(self):
    """Assemble the effective configuration tree.

    Merge order (later sources win): current tree -> roots and environment
    config paths -> local config paths -> local config files.
    Returns the merged ConfigTree; does not install it (see reload()).
    """
    env = self._env
    config = self._config.copy()

    # environment config paths apply only to server deployments
    if self.is_server:
        env_config_paths = ENV_CONFIG_PATHS
    else:
        env_config_paths = []

    # an explicit override replaces all environment config paths
    env_config_path_override = os.environ.get(ENV_CONFIG_PATH_OVERRIDE_VAR)
    if env_config_path_override:
        env_config_paths = [expanduser(env_config_path_override)]

    # merge configuration from root and other environment config paths
    if self.roots or env_config_paths:
        config = functools.reduce(
            lambda cfg, path: ConfigTree.merge_configs(
                cfg,
                self._read_recursive_for_env(path, env, verbose=self._verbose),
                copy_trees=True,
            ),
            self.roots + env_config_paths,
            config,
        )

    # merge configuration from local configuration paths
    if LOCAL_CONFIG_PATHS:
        config = functools.reduce(
            lambda cfg, path: ConfigTree.merge_configs(
                cfg, self._read_recursive(path, verbose=self._verbose), copy_trees=True
            ),
            LOCAL_CONFIG_PATHS,
            config,
        )

    # an explicit override replaces the default local configuration files
    local_config_files = LOCAL_CONFIG_FILES
    local_config_override = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR)
    if local_config_override:
        local_config_files = [expanduser(local_config_override)]

    # merge configuration from local configuration files
    if local_config_files:
        config = functools.reduce(
            lambda cfg, file_path: ConfigTree.merge_configs(
                cfg,
                self._read_single_file(file_path, verbose=self._verbose),
                copy_trees=True,
            ),
            local_config_files,
            config,
        )

    # record the active environment inside the tree itself
    config["env"] = env
    return config
|
||||
|
||||
def replace(self, config):
    """Install ``config`` as the active configuration tree."""
    self._config = config
|
||||
|
||||
def reload(self):
    """Recompute the merged configuration and make it the active one."""
    fresh_config = self._reload()
    self.replace(fresh_config)
|
||||
|
||||
def initialize_logging(self):
    """Apply the "logging" section of the configuration, if present.

    Repairs incomplete handler definitions before handing the config to the
    logging system: handlers with no class, or file handlers with no
    filename, are dropped (and references to them removed from all loggers);
    file handlers whose target file is missing get the file pre-created.

    :return: True when a logging config was found and applied, else False.
    """
    logging_config = self._config.get("logging", None)
    if not logging_config:
        return False

    # handle incomplete file handlers
    deleted = []
    handlers = logging_config.get("handlers", {})
    # iterate over a snapshot so entries can be deleted while looping
    for name, handler in list(handlers.items()):
        cls = handler.get("class", None)
        is_file = cls and "FileHandler" in cls
        if cls is None or (is_file and "filename" not in handler):
            deleted.append(name)
            del handlers[name]
        elif is_file:
            # pre-create the log file so the FileHandler can open it
            file = Path(handler.get("filename"))
            if not file.is_file():
                file.parent.mkdir(parents=True, exist_ok=True)
                file.touch()

    # remove dependency in deleted handlers
    root_logger = logging_config.get("root", None)
    loggers = list(logging_config.get("loggers", {}).values()) + (
        [root_logger] if root_logger else []
    )
    for logger in loggers:
        # NOTE: "handlers" here intentionally rebinds the name used above --
        # from this point on it refers to a single logger's handler list.
        handlers = logger.get("handlers", None)
        if not handlers:
            continue
        logger["handlers"] = [h for h in handlers if h not in deleted]

    extra = None
    if self._app:
        # tag every record with the application name
        extra = {"app": self._app}
    initialize_log(logging_config, extra=extra)
    return True
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return self._config[key]
|
||||
except:
|
||||
return None
|
||||
|
||||
def __getattr__(self, key):
|
||||
c = self.__getattribute__('_config')
|
||||
if key.split('.')[0] in c:
|
||||
try:
|
||||
return c[key]
|
||||
except Exception:
|
||||
return None
|
||||
return getattr(c, key)
|
||||
|
||||
def get(self, key, default=_MISSING):
    """Return the configuration value for ``key``.

    :param key: Configuration key to look up.
    :param default: Value returned when the key is missing. When omitted,
        a missing key raises KeyError instead of returning a value.
    :raises KeyError: if the key is missing and no default was provided.
    """
    value = self._config.get(key, default)
    # BUG FIX: the original condition was
    #     `value is self._MISSING and not default`
    # but the `_MISSING` sentinel is a plain (truthy) object(), so
    # `not default` was always False when the default was omitted and the
    # KeyError below could never be raised. `value is self._MISSING` alone
    # already implies no caller-supplied default was used.
    if value is self._MISSING:
        raise KeyError(
            "Unable to find value for key '{}' and default value was not provided.".format(
                key
            )
        )
    return value
|
||||
|
||||
def to_dict(self):
    """Return the configuration as a plain (ordered) dictionary."""
    return self._config.as_plain_ordered_dict()
|
||||
|
||||
def as_json(self):
    """Serialize the configuration to a pretty-printed JSON string."""
    plain = self.to_dict()
    return json.dumps(plain, indent=2)
|
||||
|
||||
def _read_recursive_for_env(self, root_path_str, env, verbose=True):
    """Load the config tree for an environment under ``root_path_str``.

    Reads the "default" environment folder first, then overlays the
    requested ``env`` folder on top of it (skipped when it is the default
    itself). A missing root yields an empty tree.
    """
    root_path = Path(root_path_str)
    if not root_path.exists():
        return ConfigTree()

    default_config = self._read_recursive(
        root_path / Environment.default, verbose=verbose
    )
    if (root_path / env) == (root_path / Environment.default):
        return default_config

    # None is ok, will return empty config
    env_config = self._read_recursive(root_path / env, verbose=verbose)
    return ConfigTree.merge_configs(default_config, env_config, True)
|
||||
|
||||
def _read_recursive(self, conf_root, verbose=True):
    """Recursively load every config file under ``conf_root``.

    Each ``*.conf`` file becomes a sub-tree whose key is the file's path
    relative to the root, with directory separators turned into dots, plus
    the file stem (``sub/dir/name.conf`` -> ``sub.dir.name``).

    :param conf_root: Root directory to scan; falsy or missing -> empty tree.
    :param verbose: When True, print which directories/files are loaded.
    :return: the merged ConfigTree.
    """
    conf = ConfigTree()
    if not conf_root:
        return conf
    conf_root = Path(conf_root)

    if not conf_root.exists():
        if verbose:
            print("No config in %s" % str(conf_root))
        return conf

    if verbose:
        print("Loading config from %s" % str(conf_root))
    for root, dirs, files in os.walk(str(conf_root)):

        rel_dir = str(Path(root).relative_to(conf_root))
        if rel_dir == ".":
            rel_dir = ""
        # BUG FIX: was rel_dir.replace("/", ".") -- a hard-coded POSIX
        # separator, so nested directories produced broken keys on Windows.
        # os.sep is "/" on POSIX, keeping behavior identical there.
        prefix = rel_dir.replace(os.sep, ".")

        for filename in files:
            if not is_config_file(filename):
                continue

            if prefix != "":
                key = prefix + "." + Path(filename).stem
            else:
                key = Path(filename).stem

            file_path = str(Path(root) / filename)

            conf.put(key, self._read_single_file(file_path, verbose=verbose))

    return conf
|
||||
|
||||
@staticmethod
def _read_single_file(file_path, verbose=True):
    """Parse a single HOCON config file into a ConfigTree.

    Missing/invalid paths yield an empty tree. Parse errors are re-raised
    as ConfigurationError while preserving the original traceback; any
    other failure is logged and propagated unchanged.

    :param file_path: Path of the file to parse (may be falsy).
    :param verbose: When True, print which file is being loaded.
    :raises ConfigurationError: on HOCON/pyparsing parse failures.
    """
    if not file_path or not Path(file_path).is_file():
        return ConfigTree()

    if verbose:
        print("Loading config from file %s" % file_path)

    try:
        return pyhocon.ConfigFactory.parse_file(file_path)
    except ParseSyntaxException as ex:
        # Syntax errors carry a location - include it in the message.
        msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
            file_path, ex
        )
        # six.reraise keeps the original traceback attached to the new error
        six.reraise(
            ConfigurationError,
            ConfigurationError(msg, file_path=file_path),
            sys.exc_info()[2],
        )
    except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
        msg = "Failed parsing {0} ({1.__class__.__name__}): {1}".format(
            file_path, ex
        )
        six.reraise(ConfigurationError, ConfigurationError(msg), sys.exc_info()[2])
    except Exception as ex:
        # Unknown failure: surface it but do not mask the original error.
        print("Failed loading %s: %s" % (file_path, ex))
        raise
|
||||
53
trains_agent/backend_config/converters.py
Normal file
53
trains_agent/backend_config/converters.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import base64
|
||||
from distutils.util import strtobool
|
||||
from typing import Union, Optional, Any, TypeVar, Callable, Tuple
|
||||
|
||||
import six
|
||||
|
||||
try:
|
||||
from typing import Text
|
||||
except ImportError:
|
||||
# windows conda-less hack
|
||||
Text = Any
|
||||
|
||||
|
||||
ConverterType = TypeVar("ConverterType", bound=Callable[[Any], Any])
|
||||
|
||||
|
||||
def base64_to_text(value):
    # type: (Any) -> Text
    """Decode a base64-encoded value (str or bytes) into UTF-8 text."""
    raw = base64.b64decode(value)
    return raw.decode("utf-8")
|
||||
|
||||
|
||||
def text_to_bool(value):
    # type: (Text) -> bool
    """Convert a textual truth value to a bool.

    Accepts (case-insensitively) "y", "yes", "t", "true", "on", "1" as True
    and "n", "no", "f", "false", "off", "0" as False.

    :raises ValueError: for any other value.
    """
    # FIX: distutils.util.strtobool was deprecated (PEP 632) and removed in
    # Python 3.12; its exact semantics are replicated inline here.
    normalized = str(value).lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (value,))
|
||||
|
||||
|
||||
def any_to_bool(value):
    # type: (Optional[Union[int, float, Text]]) -> bool
    """Convert any supported value to bool.

    Text values are parsed with :func:`text_to_bool`; everything else
    (numbers, None, ...) is evaluated by Python truthiness.
    """
    return text_to_bool(value) if isinstance(value, six.text_type) else bool(value)
|
||||
|
||||
|
||||
def or_(*converters, **kwargs):
    # type: (ConverterType, Tuple[Exception, ...]) -> ConverterType
    """Combine converters into an "optional converter".

    The returned callable tries each converter in turn; the configured
    exception types are swallowed and the next converter is attempted.
    When every converter fails, the original value is returned unchanged.

    :param converters: Converter callables, tried in order.
    :param exceptions: Tuple of exception types to ignore
        (default: ``(ValueError, TypeError)``).
    """
    # noinspection PyUnresolvedReferences
    ignored = kwargs.get("exceptions", (ValueError, TypeError))

    def combined(value):
        for convert in converters:
            try:
                return convert(value)
            except ignored:
                continue
        return value

    return combined
|
||||
53
trains_agent/backend_config/defs.py
Normal file
53
trains_agent/backend_config/defs.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from os.path import expanduser
|
||||
from pathlib2 import Path
|
||||
|
||||
ENV_VAR = 'TRAINS_ENV'
""" Name of system environment variable that can be used to specify the config environment name """


DEFAULT_CONFIG_FOLDER = 'config'
""" Default config folder to search for when loading relative to a given path """


# Intentionally empty in this distribution; entries would be scanned only in
# server mode (see the config loader's _reload).
ENV_CONFIG_PATHS = [
]


""" Environment-related config paths """


LOCAL_CONFIG_PATHS = [
    # '/etc/opt/trains',  # used by servers for docker-generated configuration
    # expanduser('~/.trains/config'),
]
""" Local config paths, not related to environment """


LOCAL_CONFIG_FILES = [
    expanduser('~/trains.conf'),  # used for workstation configuration (end-users, workers)
]
""" Local config files (not paths) """


LOCAL_CONFIG_FILE_OVERRIDE_VAR = 'TRAINS_CONFIG_FILE'
""" Local config file override environment variable. If this is set, no other local config files will be used. """


ENV_CONFIG_PATH_OVERRIDE_VAR = 'TRAINS_CONFIG_PATH'
"""
Environment-related config path override environment variable. If this is set, no other env config path will be used.
"""
|
||||
|
||||
|
||||
class Environment(object):
    """ Supported environment names

    Each value is the name of a per-environment config sub-folder; the
    loader merges the chosen environment's folder over ``default``.
    """
    default = 'default'
    demo = 'demo'
    local = 'local'
|
||||
|
||||
|
||||
CONFIG_FILE_EXTENSION = '.conf'


def is_config_file(path):
    """Return True when ``path`` carries the recognized config extension."""
    suffix = Path(path).suffix
    return suffix == CONFIG_FILE_EXTENSION
|
||||
103
trains_agent/backend_config/entry.py
Normal file
103
trains_agent/backend_config/entry.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import abc
|
||||
from typing import Optional, Any, Tuple, Callable, Dict
|
||||
|
||||
import six
|
||||
|
||||
from .converters import any_to_bool
|
||||
|
||||
try:
|
||||
from typing import Text
|
||||
except ImportError:
|
||||
# windows conda-less hack
|
||||
Text = Any
|
||||
|
||||
|
||||
# Sentinel distinguishing "no value found" from legitimate falsy values
# (None, "", 0) throughout the Entry API.
NotSet = object()

# Type alias for a value-conversion callable applied to raw entry values.
Converter = Callable[[Any], Any]
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class Entry(object):
    """
    Configuration entry definition

    An entry has one or more lookup keys, a declared value type and an
    optional converter/default. Subclasses implement the actual storage
    access via the abstract ``_get`` (and optionally ``_set`` / ``error``).
    """

    @classmethod
    def default_conversions(cls):
        # type: () -> Dict[Any, Converter]
        # Maps a declared entry type to the converter used when no explicit
        # converter was supplied at construction.
        return {
            bool: any_to_bool,
            six.text_type: lambda s: six.text_type(s).strip(),
        }

    def __init__(self, key, *more_keys, **kwargs):
        # type: (Text, Text, Any) -> None
        """
        :param key: Entry's key (at least one).
        :param more_keys: More alternate keys for this entry.
        :param type: Value type. If provided, will be used choosing a default conversion or
            (if none exists) for casting the environment value.
        :param converter: Value converter. If provided, will be used to convert the environment value.
        :param default: Default value. If provided, will be used as the default value on calls to get() and get_pair()
            in case no value is found for any key and no specific default value was provided in the call.
            Default value is None.
        :param help: Help text describing this entry
        """
        self.keys = (key,) + more_keys
        # kwargs.pop with defaults keeps unknown kwargs silently ignored,
        # matching the original keyword protocol.
        self.type = kwargs.pop("type", six.text_type)
        self.converter = kwargs.pop("converter", None)
        self.default = kwargs.pop("default", None)
        self.help = kwargs.pop("help", None)

    def __str__(self):
        return str(self.key)

    @property
    def key(self):
        # The primary (first) key of this entry.
        return self.keys[0]

    def convert(self, value, converter=None):
        # type: (Any, Converter) -> Optional[Any]
        """Convert ``value`` using the explicit, instance, or type-default converter."""
        converter = converter or self.converter
        if not converter:
            # Fall back to the per-type default conversion; the type itself
            # acts as a cast when no default conversion is registered.
            converter = self.default_conversions().get(self.type, self.type)
        return converter(value)

    def get_pair(self, default=NotSet, converter=None):
        # type: (Any, Converter) -> Optional[Tuple[Text, Any]]
        """Return ``(key, value)`` for the first key that yields a value.

        When no key has a value -- or conversion of the found value fails
        (note: the ``break`` below deliberately skips the ``return`` and
        falls through to the default) -- the primary key is returned with
        ``default`` (or the entry's own default when none was given).
        """
        for key in self.keys:
            value = self._get(key)
            if value is NotSet:
                continue
            try:
                value = self.convert(value, converter)
            except Exception as ex:
                self.error("invalid value {key}={value}: {ex}".format(**locals()))
                break
            return key, value
        result = self.default if default is NotSet else default
        return self.key, result

    def get(self, default=NotSet, converter=None):
        # type: (Any, Converter) -> Optional[Any]
        """Return the entry's (converted) value; see :meth:`get_pair`."""
        return self.get_pair(default=default, converter=converter)[1]

    def set(self, value):
        # type: (Any, Any) -> (Text, Any)
        """Store ``str(value)`` under whichever key currently resolves."""
        key, _ = self.get_pair(default=None, converter=None)
        self._set(key, str(value))

    def _set(self, key, value):
        # type: (Text, Text) -> None
        # Optional hook: writable backends override this; default is no-op.
        pass

    @abc.abstractmethod
    def _get(self, key):
        # type: (Text) -> Any
        # Return the raw stored value for ``key``, or NotSet when absent.
        pass

    @abc.abstractmethod
    def error(self, message):
        # type: (Text) -> None
        # Report an entry-related problem (e.g. a failed conversion).
        pass
|
||||
25
trains_agent/backend_config/environment.py
Normal file
25
trains_agent/backend_config/environment.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from os import getenv, environ
|
||||
|
||||
from .converters import text_to_bool
|
||||
from .entry import Entry, NotSet
|
||||
|
||||
|
||||
class EnvEntry(Entry):
    """Configuration entry backed by OS environment variables."""

    @classmethod
    def default_conversions(cls):
        # Environment values are always strings, so booleans must be parsed
        # from text rather than by plain truthiness.
        conversions = super(EnvEntry, cls).default_conversions().copy()
        conversions[bool] = text_to_bool
        return conversions

    def _get(self, key):
        # Empty/whitespace-only variables count as unset.
        raw = getenv(key, "").strip()
        if raw:
            return raw
        return NotSet

    def _set(self, key, value):
        environ[key] = value

    def __str__(self):
        return "env:{}".format(super(EnvEntry, self).__str__())

    def error(self, message):
        print("Environment configuration: {}".format(message))
|
||||
5
trains_agent/backend_config/errors.py
Normal file
5
trains_agent/backend_config/errors.py
Normal file
@@ -0,0 +1,5 @@
|
||||
class ConfigurationError(Exception):
    """Raised when configuration content cannot be loaded or parsed.

    Keeps the standard Exception args and remembers which file (if any)
    triggered the failure.
    """

    def __init__(self, msg, file_path=None, *args):
        super(ConfigurationError, self).__init__(msg, *args)
        self.file_path = file_path
|
||||
30
trains_agent/backend_config/log.py
Normal file
30
trains_agent/backend_config/log.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging.config
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
|
||||
def logger(path=None):
    """Return a logger in the "trains" namespace.

    Without a path, the root "trains" logger is returned. With a module
    path, the logger name is derived from the file stem; stems beginning
    with an underscore (e.g. ``__init__``) use the parent package name.
    """
    if not path:
        return logging.getLogger("trains")
    p = Path(path)
    anchor = p.parent if p.stem.startswith('_') else p
    return logging.getLogger("trains.%s" % anchor.stem)
|
||||
|
||||
|
||||
def initialize(logging_config=None, extra=None):
    """Initialize application logging.

    :param logging_config: Optional ``dictConfig``-style mapping handed to
        ``logging.config.dictConfig``.
    :param extra: Optional dict merged into every record's ``extra``.
        NOTE(review): installing it replaces the logging manager's logger
        class, which affects loggers created afterwards process-wide.
    """
    if extra is not None:
        from logging import Logger

        class _Logger(Logger):
            # Copied at class-creation time so later mutation of the caller's
            # dict cannot change what gets attached to records.
            __extra = extra.copy()

            def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
                # Merge the fixed extra fields into each record's extra
                # (per-call extras win only for keys not in the fixed set).
                extra = extra or {}
                extra.update(self.__extra)
                super(_Logger, self)._log(level, msg, args, exc_info=exc_info, extra=extra, **kwargs)

        Logger.manager.loggerClass = _Logger

    if logging_config is not None:
        # dict() copy protects the caller's mapping from dictConfig mutation
        logging.config.dictConfig(dict(logging_config))
|
||||
9
trains_agent/backend_config/utils.py
Normal file
9
trains_agent/backend_config/utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
def get_items(cls):
    """Return the key/value members of an enum-like class.

    Members whose name starts with an underscore (including dunders added
    by the class machinery) are excluded.
    """
    return {
        name: member
        for name, member in vars(cls).items()
        if not name.startswith('_')
    }
|
||||
|
||||
|
||||
def get_options(cls):
    """Return only the member values of an enum-like class."""
    items = get_items(cls)
    return items.values()
|
||||
3
trains_agent/commands/__init__.py
Normal file
3
trains_agent/commands/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from __future__ import print_function
|
||||
|
||||
from .worker import Worker
|
||||
401
trains_agent/commands/base.py
Normal file
401
trains_agent/commands/base.py
Normal file
@@ -0,0 +1,401 @@
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
import copy
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from functools import wraps
|
||||
from operator import attrgetter
|
||||
from traceback import print_exc
|
||||
from typing import Text
|
||||
|
||||
from trains_agent.helper.console import ListFormatter, print_text
|
||||
from trains_agent.helper.dicts import filter_keys
|
||||
|
||||
try:
|
||||
from typing import NoReturn
|
||||
except ImportError:
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
import six
|
||||
from trains_agent.backend_api import services
|
||||
|
||||
from trains_agent.errors import APIError, CommandFailedError
|
||||
from trains_agent.helper.base import Singleton, return_list, print_parameters, dump_yaml, load_yaml, error, warning
|
||||
from trains_agent.interface.base import ObjectID
|
||||
from trains_agent.session import Session
|
||||
|
||||
|
||||
class NameResolutionError(CommandFailedError):
    """Raised when an object name cannot be resolved to a unique ID.

    ``suggestions`` is free-form text (e.g. a bullet list of near matches)
    appended to the message when the error is displayed.
    """

    def __init__(self, message, suggestions=''):
        super(NameResolutionError, self).__init__(message)
        self.message = message
        self.suggestions = suggestions

    def __str__(self):
        return '{}{}'.format(self.message, self.suggestions)
|
||||
|
||||
|
||||
def resolve_names(func):
    """Decorator: resolve ObjectID arguments to object IDs before calling.

    Scalars (ObjectID) are resolved strictly - resolution errors propagate.
    Lists/tuples of ObjectID are resolved leniently: a single-element list
    re-raises its error, while in longer lists failed names are reported
    via ``command.warning`` and passed through unresolved.
    """
    def safe_resolve(command, arg):
        # Resolve one ObjectID, returning (result, None) on success or
        # (original name, sys.exc_info() triple) on resolution failure.
        try:
            result = command._resolve_name(arg.name, arg.service)
            return result, None
        except NameResolutionError:
            return arg.name, sys.exc_info()

    def _resolve_single(command, arg):
        if isinstance(arg, ObjectID):
            return command._resolve_name(arg.name, arg.service)
        elif isinstance(arg, (list, tuple)) and all(isinstance(x, ObjectID) for x in arg):
            result = [safe_resolve(command, x) for x in arg]
            if len(result) == 1:
                # Single element: fail hard, preserving the traceback.
                name, ex = result[0]
                if ex:
                    six.reraise(*ex)
                return [name]
            # Multiple elements: warn about failures, keep going.
            for _, ex in result:
                if ex:
                    command.warning(ex[1].message)
            return [name for (name, _) in result]
        # Non-ObjectID arguments pass through untouched.
        return arg

    @wraps(func)
    def newfunc(self, *args, **kwargs):
        args = [_resolve_single(self, arg) for arg in args]
        kwargs = {key: _resolve_single(self, value) for key, value in kwargs.items()}
        return func(self, *args, **kwargs)
    return newfunc
|
||||
|
||||
|
||||
class BaseCommandSection(object):
    """
    Base class for command sections which do not interact with the allegro API.
    Has basic utilities for user interaction.
    """
    # Expose the module-level helpers as static methods so subclasses can
    # call/override self.warning / self.error uniformly.
    warning = staticmethod(warning)
    error = staticmethod(error)

    @staticmethod
    def log(message, *args):
        # printf-style formatting, prefixed with the tool name
        print("trains-agent: {}".format(message % args))

    @classmethod
    def exit(cls, message, code=1):  # type: (Text, int) -> NoReturn
        # Print an error message and terminate the process.
        cls.error(message)
        sys.exit(code)
|
||||
|
||||
|
||||
@six.add_metaclass(Singleton)
class ServiceCommandSection(BaseCommandSection):
    """
    Base class for command sections which interact with the allegro API.
    Contains API functionality which is common across services.
    """

    _worker_name = None
    # Max number of entries shown in name-resolution suggestion lists.
    MAX_SUGGESTIONS = 10

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded to the Session factory.
        super(ServiceCommandSection, self).__init__()
        self._session = self._get_session(*args, **kwargs)
        self._list_formatter = ListFormatter(self.service)

    @staticmethod
    def _get_session(*args, **kwargs):
        return Session(*args, **kwargs)

    @property
    @abstractmethod
    def service(self):
        """ The name of the REST service used by this command """
        pass

    def get(self, endpoint, *args, **kwargs):
        # GET request against this section's service.
        return self._session.get(service=self.service, action=endpoint, *args, **kwargs)

    def post(self, endpoint, *args, **kwargs):
        # POST request against this section's service.
        return self._session.post(service=self.service, action=endpoint, *args, **kwargs)

    def get_with_act_as(self, endpoint, *args, **kwargs):
        # GET performed with act-as (impersonation) semantics.
        return self._session.get_with_act_as(service=self.service, action=endpoint, *args, **kwargs)

    @property
    def name(self):
        # Title-cased service name, e.g. "tasks" -> "Tasks".
        return self.service.title()

    @property
    def name_single(self):
        # Singular display name (naive plural stripping).
        return self.name.rstrip('s')

    @property
    def service_single(self):
        # Singular service name, e.g. "tasks" -> "task".
        return self.service.rstrip('s')

    @resolve_names
    def __info(self, id=None, yaml=None, **kwargs):
        # Fetch full info for one or more object ids; returns {id: info}.
        # NOTE: ``id`` / ``yaml`` shadow a builtin / module name - kept as-is
        # because they map directly to CLI argument names.
        ids = return_list(id)
        if not ids:
            return

        yaml_dump = {}

        for i in ids:
            get_fields = {self.service_single: i}
            try:
                info = self.get('get_by_id', **get_fields)
                yaml_dump[i] = info[self.service_single]
            except APIError:
                # Best-effort: report and continue with the remaining ids.
                self.error('Failed retrieving info for {} {}'.format(self.service_single, i))

        self.output_info(yaml_dump, yaml_path=yaml, **kwargs)
        return yaml_dump

    @resolve_names
    def _info(self, *args, **kwargs):
        # Thin wrapper over the private __info, discarding the result.
        self.__info(*args, **kwargs)

    @staticmethod
    def output_info(entries, quiet=False, yaml_path=None, **_):
        # Print entries (unless quiet) and optionally persist them to yaml.
        if not quiet and entries:
            print_parameters(entries, indent=4)

        if yaml_path:
            print('Storing entries to [{}]'.format(yaml_path))
            dump_yaml(entries, yaml_path)

    @staticmethod
    def _make_query(json, table, sort=None, projection_from_table=False, extra_fields=None):
        # Build a (payload, field-list) pair for a query request.
        # NOTE: ``json`` shadows the module name - it is the payload dict.
        json = json.copy()
        if isinstance(table, six.string_types):
            # '#'-separated string of field names
            table = table.split('#')

        if extra_fields:
            table.extend(extra_fields)

        if projection_from_table:
            json['only_fields'] = table

        if sort:
            # does nothing if 'order_by' is not in get_fields
            json['order_by'] = sort.split('#')[0]
        return json, table

    def _get_all(self, endpoint, json, retpoint=None):
        # Query the endpoint and return the list stored under the service
        # (or explicit ``retpoint``) key of the response.
        return self.get(endpoint, **json).get(retpoint or self.service, [])

    @resolve_names
    def _update(self,
                endpoint='update',
                send_diff=False,
                quiet=False,
                primary_key='id',
                override=None,
                model_desc=None,
                yaml=None,
                **kwargs):
        """Update one or more objects from kwargs and/or a yaml file.

        ``send_diff`` sends only the change-set relative to the server copy;
        ``override`` applies dotted-key=value overrides to the first entry;
        ``model_desc`` injects a prototxt file into the execution section.
        """
        if not yaml and primary_key not in kwargs:
            raise ValueError('Update must supply either yaml file or %s-id' % self.service_single)

        data_entries = {}
        original_data_entries = {}
        if yaml:
            data_entries = load_yaml(yaml)

        if send_diff or (not yaml and primary_key in kwargs):
            # Determine the target id (explicit or the first yaml entry) and
            # fetch the server-side copy for diffing / id reconciliation.
            i = kwargs.get(primary_key) or next(iter(data_entries))
            original_info = self.__info(id=i, quiet=True)[i]

            if send_diff:
                original_data_entries[i] = copy.deepcopy(original_info)
            if yaml and i not in data_entries:
                # The yaml entry id differs from the requested id - only a
                # single-entry file can be re-keyed (with --force).
                if len(data_entries) > 1:
                    raise ValueError(
                        'Error: yaml file [%s] contains more than one task id' % yaml)
                first_key = next(iter(data_entries))
                if first_key != i:
                    if kwargs.get('force'):
                        if not quiet:
                            print('Warning: overriding yaml task id [%s] with id=%s' % (first_key, i))
                    else:
                        raise ValueError(
                            'Error: yaml task id [%s] != id [%s], use --force to override' % (first_key, i))
                data_entries = {i: data_entries[first_key]}
                data_entries[i][primary_key] = i
            elif not yaml:
                data_entries[i] = kwargs

        if model_desc:
            # Inject the prototxt model description into the first entry.
            first_key = next(iter(data_entries))
            with open(model_desc) as f:
                proto_data = f.read()
            info = data_entries[first_key]
            info['execution']['model_desc']['prototxt'] = proto_data

        if override:
            # Apply "dotted.key=value" overrides, creating intermediate dicts.
            first_key = next(iter(data_entries))
            info = data_entries[first_key]
            for p in override:
                key, val = p.split('=') if isinstance(p, six.string_types) else p
                info_key = info
                keys = key.split('.')
                for k in keys[:-1]:
                    if not info_key.get(k):
                        info_key[k] = dict()
                    info_key = info_key[k]
                info_key[keys[-1]] = val

            # always make sure tags is a list of strings
            # split string to tokens ':'
            # examples tags='auto_generated:draft'
            if info.get('tags') is not None:
                # remove empty strings from list
                info['tags'] = [t for t in info['tags'].split(':') if t]

        # send only change set
        if send_diff:
            # only send the values that changed
            for i, info in data_entries.items():
                org_info = original_data_entries[i]
                out_info = {}
                recursive_diff(org_info, info, out_info)
                data_entries[i] = out_info

        for i, info in data_entries.items():
            # Skip empty change-sets (or ones containing only the id).
            if not info or len(info) == 0 or list(info) == [primary_key]:
                if not quiet:
                    print('Skipping: nothing to update for %s id [%s]' % (self.service_single, i))
                continue
            if not quiet:
                print('Updating %s id [%s]' % (self.service_single, i))
            info[self.service_single] = i
            result = self.get(endpoint, **info)
            if not result['updated']:
                raise ValueError('Failed updating %s id [%s]' % (self.service_single, i))
            if not quiet:
                print('%s [%s] updated fields: %s' % (self.name_single, i, result.get('fields', '')))

    @resolve_names
    def remove(self, ids, **kwargs):
        # Delete the given objects via the service's DeleteRequest.
        return self._apply_command(
            request_cls=getattr(services, self.service).DeleteRequest,
            object_ids=ids,
            response_validation_field='deleted',
            **kwargs
        )

    def _apply_command(self, request_cls, object_ids, response_validation_field=None, **kwargs):
        """Apply ``request_cls`` to each object id, reporting per-id results.

        Logs a summary; exits the process when not all calls succeeded.
        """
        object_ids = return_list(object_ids)

        def call_one(object_id):
            # Returns True on success, False on API/validation failure.
            error_message = '[{object_id}]: failed'.format(**locals())
            try:
                response = self._session.send_api(request_cls(object_id, **kwargs))
            except APIError as e:
                if not self._session.debug_mode:
                    self.error('{}: {}'.format(error_message, e))
                else:
                    # Debug mode: dump the server traceback plus our own.
                    traceback = e.format_traceback()
                    if traceback:
                        print(traceback)
                        print('Own traceback:')
                    print_exc()
                return False

            # A validation field (e.g. "deleted") must equal 1 when present.
            if not response_validation_field or getattr(response, response_validation_field) == 1:
                return True
            else:
                self.error(error_message)
                return False

        succeeded = [call_one(object_id) for object_id in object_ids].count(True)
        message = '{}/{} succeeded'.format(succeeded, len(object_ids))
        # Full success -> log; partial failure -> exit with an error code.
        (self.log if succeeded == len(object_ids) else self.exit)(message)

    def get_service(self, service_class):
        # Instantiate a helper service bound to this session's config.
        return service_class(config=self._session.config)

    def _resolve_name(self, name, service=None):
        """
        Resolve an object name to an object ID.
        Operation:
        - If the argument "looks like" an ID, return it.
        - Else, get all object with names containing the argument
        - if an object with the argument as its name exists, return the object's ID
        - Else, print a list of suggestions and exit
        :param str name: ID (returned unmodified) or Name to resolve
        :param str service: Service to resolve from (type of object). Defaults to service represented by the class
        :return: ID of object
        :rtype: str
        """
        service = service or self.service
        # Hex-ish strings of 30+ chars are assumed to already be IDs.
        if re.match(r'^[-a-f0-9]{30,}$', name):
            return name

        try:
            request_cls = getattr(services, service).GetAllRequest
        except AttributeError:
            raise NameResolutionError('Name resolution unavailable for {}'.format(service))

        request = request_cls.from_dict(dict(name=name, only_fields=['name', 'id']))
        # from_dict will ignore unrecognised keyword arguments - not all GetAll's have only_fields
        response = getattr(self._session.send_api(request), service)
        matches = [db_object for db_object in response if name.lower() == db_object.name.lower()]

        def truncated_bullet_list(format_string, elements, callback, **kwargs):
            # Render at most MAX_SUGGESTIONS elements as a bullet list,
            # annotating the header/footer when the list was truncated.
            if len(elements) > self.MAX_SUGGESTIONS:
                kwargs.update(
                    dict(details=' (showing {}/{})'.format(self.MAX_SUGGESTIONS, len(elements)), suffix='\n...'))
            else:
                kwargs.update(dict(details='', suffix=''))
            bullet_list = '\n'.join('* {}'.format(callback(item)) for item in elements[:self.MAX_SUGGESTIONS])
            return format_string.format(bullet_list, **kwargs)

        if len(matches) == 1:
            return matches.pop().id
        elif len(matches) > 1:
            # Ambiguous name: show all candidate IDs and abort.
            message = truncated_bullet_list(
                'Found multiple {service} with name "{name}"{details}:\n{}{suffix}',
                matches,
                callback=attrgetter('id'),
                **locals())
            self.exit(message)

        message = 'Could not find {} with name "{}"'.format(service.rstrip('s'), name)

        if not response:
            raise NameResolutionError(message)

        # Partial matches exist - raise with a "did you mean" suggestion list.
        suggestions = truncated_bullet_list(
            '. Did you mean this?{details}\n{}{suffix}',
            sorted(response, key=attrgetter('name')),
            lambda db_object: '({}) {}'.format(db_object.id, db_object.name)
        )
        raise NameResolutionError(message, suggestions)
|
||||
|
||||
|
||||
def recursive_diff(org, upd, out):
    """Write the change-set between ``org`` and ``upd`` into ``out``.

    Dicts: keys new in ``upd`` or with differing values are copied into
    ``out``. Nested dicts recurse; a changed nested dict that contains no
    deeper nesting is copied into ``out`` wholesale. ``None`` values in
    ``upd`` are skipped. Lists: items of ``upd`` absent from ``org`` are
    appended to ``out``.

    :return: True when a nested dict was descended into, False for flat
        dict/list diffs, None when there was nothing comparable to diff.
    """
    if isinstance(upd, dict) and isinstance(org, dict):
        changed_keys = [key for key in upd if key not in org or org[key] != upd[key]]
        if not changed_keys:
            return None
        found_nested = False
        for key in changed_keys:
            new_value = upd[key]
            if isinstance(new_value, dict):
                found_nested = True
                sub_out = {}
                out[key] = sub_out
                went_deeper = recursive_diff(org.get(key, {}), new_value, sub_out)
                if not went_deeper:
                    # No deeper nesting: send the whole updated sub-tree.
                    out[key] = new_value
            elif new_value is not None:
                out[key] = new_value
        return found_nested
    if isinstance(upd, list) and isinstance(org, list):
        out.extend(item for item in upd if item not in org)
        return False
    return None
|
||||
215
trains_agent/commands/config.py
Normal file
215
trains_agent/commands/config.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import input
|
||||
from pyhocon import ConfigFactory
|
||||
from pathlib2 import Path
|
||||
from six.moves.urllib.parse import urlparse
|
||||
|
||||
from trains_agent.backend_api.session.defs import ENV_HOST
|
||||
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES
|
||||
|
||||
|
||||
# Prompt shown when asking the user for API credentials; the placeholder is
# filled with the web server URL.
description = """
Please create new credentials using the web app: {}/profile
In the Admin page, press "Create new credentials", then press "Copy to clipboard"

Paste credentials here: """

# Default web server; overridden by the host environment variable (ENV_HOST)
# when set and readable.
def_host = 'http://localhost:8080'
try:
    def_host = ENV_HOST.get(default=def_host) or def_host
except Exception:
    # Best-effort: keep the hard-coded default on any failure.
    pass

# Preamble printed before asking for the web application host URL.
host_description = """
Editing configuration file: {CONFIG_FILE}
Enter the url of the trains-server's Web service, for example: {HOST}
""".format(
    CONFIG_FILE=LOCAL_CONFIG_FILES[0],
    HOST=def_host,
)
|
||||
|
||||
|
||||
def main():
    """Interactive first-time setup for trains-agent.

    Collects the server addresses (deriving api/web/files hosts from a single
    address using the standard port and subdomain conventions), verifies user
    credentials against the server, optionally records git credentials, and
    writes the resulting trains.conf configuration file.

    All input/output goes through stdin/stdout; returns None in every case.
    """
    print('TRAINS-AGENT setup process')
    conf_file = Path(LOCAL_CONFIG_FILES[0]).absolute()
    if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0:
        # never clobber an existing, non-empty configuration
        print('Configuration file already exists: {}'.format(str(conf_file)))
        print('Leaving setup, feel free to edit the configuration file.')
        return

    print(host_description)
    web_host = input_url('Web Application Host', '')
    # input_url only returns values accepted by verify_url, so parsed_host
    # is a valid urlparse result here
    parsed_host = verify_url(web_host)

    # Derive api/web/files addresses from the single address the user gave,
    # following the trains-server conventions: ports 8008/8080/8081 or
    # subdomains api./app./files. (demo* for the hosted demo server).
    if parsed_host.port == 8008:
        # BUGFIX: this message used to read "Replacing 8080 with 8008",
        # the inverse of what the code below actually does (8008 -> 8080)
        print('Port 8008 is the api port. Replacing 8008 with 8080 for Web application')
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
    elif parsed_host.port == 8080:
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapp.'):
        # this is our demo server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('app.'):
        # this is our application server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapi.'):
        print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('api.'):
        print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
    else:
        # unknown layout: optionally assume the default port scheme
        api_host = ''
        web_host = ''
        files_host = ''
        if not parsed_host.port:
            print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
            replace_port = input().lower()
            if not replace_port or replace_port == 'y' or replace_port == 'yes':
                api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
                web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
                files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
        if not api_host:
            api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path

    # let the user confirm or override the derived addresses
    api_host = input_url('API Host', api_host)
    files_host = input_url('File Store Host', files_host)

    print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
        api_host, web_host, files_host))

    # loop until pasted (or individually typed) credentials verify against
    # the server
    while True:
        print(description.format(web_host), end='')
        parse_input = input()
        # check if these are valid credentials
        credentials = None
        # noinspection PyBroadException
        try:
            parsed = ConfigFactory.parse_string(parse_input)
            if parsed:
                credentials = parsed.get("credentials", None)
        except Exception:
            credentials = None

        if not credentials or set(credentials) != {"access_key", "secret_key"}:
            print('Could not parse user credentials, try again one after the other.')
            credentials = {}
            # parse individual
            print('Enter user access key: ', end='')
            credentials['access_key'] = input()
            print('Enter user secret: ', end='')
            credentials['secret_key'] = input()

        print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
                                                                     credentials['secret_key'], ))

        from trains_agent.backend_api.session import Session
        # noinspection PyBroadException
        try:
            print('Verifying credentials ...')
            # constructing a Session performs the server round-trip that
            # validates the key pair
            Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
            print('Credentials verified!')
            break
        except Exception:
            print('Error: could not verify credentials: host={} access={} secret={}'.format(
                api_host, credentials['access_key'], credentials['secret_key']))

    # get GIT User/Pass for cloning
    print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
    git_user = input()
    if git_user.strip():
        print('Enter password for user \'{}\': '.format(git_user), end='')
        git_pass = input()
        print('Git repository cloning will be using user={} password={}'.format(git_user, git_pass))
    else:
        git_user = None
        git_pass = None

    # read the packaged default agent/sdk configuration sections
    # noinspection PyBroadException
    try:
        conf_folder = Path(__file__).parent.absolute() / '..' / 'backend_api' / 'config' / 'default'
        default_conf = ''
        for conf in ('agent.conf', 'sdk.conf', ):
            conf_file_section = conf_folder / conf
            with open(str(conf_file_section), 'rt') as f:
                # prefix each file's content with its section name (agent/sdk)
                default_conf += conf.split('.')[0] + ' '
                default_conf += f.read()
            default_conf += '\n'
    except Exception:
        print('Error! Could not read default configuration file')
        return
    # write the final configuration: api section, git credentials, defaults
    # noinspection PyBroadException
    try:
        with open(str(conf_file), 'wt') as f:
            header = '# TRAINS-AGENT configuration file\n' \
                     'api {\n' \
                     '  api_server: %s\n' \
                     '  web_server: %s\n' \
                     '  files_server: %s\n' \
                     '  # Credentials are generated in the webapp, %s/profile\n' \
                     '  # Override with os environment: TRAINS_API_ACCESS_KEY / TRAINS_API_SECRET_KEY\n' \
                     '  credentials {"access_key": "%s", "secret_key": "%s"}\n' \
                     '}\n\n' % (api_host, web_host, files_host,
                                web_host, credentials['access_key'], credentials['secret_key'])
            f.write(header)
            git_credentials = '# Set GIT user/pass credentials\n' \
                              '# leave blank for GIT SSH credentials\n' \
                              'agent.git_user=\"{}\"\n' \
                              'agent.git_pass=\"{}\"\n' \
                              '\n'.format(git_user or '', git_pass or '')
            f.write(git_credentials)
            f.write(default_conf)
    except Exception:
        print('Error! Could not write configuration file at: {}'.format(str(conf_file)))
        return

    print('\nNew configuration stored in {}'.format(str(conf_file)))
    print('TRAINS-AGENT setup completed successfully.')
|
||||
|
||||
|
||||
def input_url(host_type, host=None):
    """Prompt until the user confirms the current value or enters a valid URL.

    An empty reply (or "y"/"yes") accepts the existing ``host`` when one is
    set; any other reply must pass ``verify_url`` to be accepted.
    """
    while True:
        print('{} configured to: [{}] '.format(host_type, host), end='')
        answer = input()
        accepts_current = not answer or answer.lower() in ('yes', 'y')
        if host and accepts_current:
            return host
        if answer and verify_url(answer):
            return answer
|
||||
|
||||
|
||||
def verify_url(parse_input):
    """Parse a user-supplied server address into a urlparse result.

    A missing scheme is filled in automatically: "http://" when an explicit
    port is present, "https://" otherwise. Returns None when the address
    cannot be parsed (after printing a re-prompt) or uses a non-HTTP scheme.
    """
    result = None
    try:
        has_scheme = parse_input.startswith('http://') or parse_input.startswith('https://')
        if not has_scheme:
            # if we have a specific port, use http prefix, otherwise assume https
            prefix = 'http://' if ':' in parse_input else 'https://'
            parse_input = prefix + parse_input
        candidate = urlparse(parse_input)
        if candidate.scheme in ('http', 'https'):
            result = candidate
    except Exception:
        print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
    return result
|
||||
|
||||
|
||||
# allow running the setup wizard directly as a script/module
if __name__ == '__main__':
    main()
|
||||
97
trains_agent/commands/events.py
Normal file
97
trains_agent/commands/events.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from future.builtins import super
|
||||
|
||||
from trains_agent.commands.base import ServiceCommandSection
|
||||
from trains_agent.helper.base import return_list
|
||||
|
||||
|
||||
class Events(ServiceCommandSection):
    # Client for the server's `events` service: batches events into
    # size-limited json-lines packets, and converts raw log lines into
    # 'log' events.

    # upper bound (bytes) for a single add_batch POST payload
    max_packet_size = 1024 * 1024
    # upper bound (characters) for a single log event's message
    max_event_size = 64 * 1024

    def __init__(self, *args, **kwargs):
        super(Events, self).__init__(*args, **kwargs)

    @property
    def service(self):
        """ Events command service endpoint """
        return 'events'

    def send_events(self, list_events):
        # Send `list_events` (a list of event dicts) to the server, split
        # into json-lines packets no larger than max_packet_size.
        # Returns the number of events the server acknowledged.
        def send_packet(jsonlines):
            # POST one batch; returns the server-acknowledged event count
            if not jsonlines:
                return 0
            num_lines = len(jsonlines)
            jsonlines = '\n'.join(jsonlines)

            new_events = self.post('add_batch', data=jsonlines, headers={'Content-type': 'application/json-lines'})
            if new_events['added'] != num_lines:
                # partial acceptance: report and return what the server took
                print('Error (%s) sending events only %d of %d registered' %
                      (new_events['errors'], new_events['added'], num_lines))
                return int(new_events['added'])
            # print('Sent %d events' % num_lines)
            return num_lines

        # json every line and push into list of json strings
        count_bytes = 0
        lines = []
        sent_events = 0
        for i, event in enumerate(list_events):
            line = json.dumps(event)
            # +1 accounts for the '\n' separator added when the packet is joined
            line_len = len(line) + 1
            if count_bytes + line_len > self.max_packet_size:
                # flush packet, and restart
                sent_events += send_packet(lines)
                count_bytes = 0
                lines = []

            count_bytes += line_len
            lines.append(line)

        # flush leftovers
        sent_events += send_packet(lines)
        # print('Sending events done: %d / %d events sent' % (sent_events, len(list_events)))
        return sent_events

    def send_log_events(self, worker_id, task_id, lines, level='DEBUG'):
        # Convert raw log lines into 'log' events (splitting lines longer
        # than max_event_size across several events) and send them.
        log_events = []
        base_timestamp = int(time.time() * 1000)
        base_log_items = {
            'type': 'log',
            'level': level,
            'task': task_id,
            'worker': worker_id,
        }

        def get_event(c):
            # snapshot the accumulated `msg` (read from the enclosing scope)
            # with a per-event timestamp offset to preserve ordering
            d = base_log_items.copy()
            d.update(msg=msg, timestamp=base_timestamp + c)
            return d

        # break log lines into event packets
        msg = ''
        count = 0
        for l in return_list(lines):
            # HACK ignore terminal reset ANSI code
            if l == '\x1b[0m':
                continue
            while l:
                if len(msg) + len(l) < self.max_event_size:
                    msg += l
                    l = None
                else:
                    # line overflows the current event: emit what fits and
                    # keep processing the remainder
                    left_over = self.max_event_size - len(msg)
                    msg += l[:left_over]
                    l = l[left_over:]
                    log_events.append(get_event(count))
                    msg = ''
                    count += 1
        if msg:
            log_events.append(get_event(count))

        # now send the events
        return self.send_events(list_events=log_events)
|
||||
1681
trains_agent/commands/worker.py
Normal file
1681
trains_agent/commands/worker.py
Normal file
File diff suppressed because it is too large
Load Diff
87
trains_agent/complete.py
Normal file
87
trains_agent/complete.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Script for generating command-line completion.
|
||||
Called by trains_agent/utilities/complete.sh (or a copy of it) like so:
|
||||
|
||||
python -m trains_agent.complete "current command line"
|
||||
|
||||
And writes line-separated completion targets to stdout.
|
||||
Results are line-separated in order to enable other whitespace in results.
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from trains_agent.interface import get_parser
|
||||
|
||||
|
||||
def is_argument_required(action):
    """Return True when *action* is a plain store action, i.e. consumes a value.

    Relies on the private argparse class: store_true/store_const actions are
    _StoreConstAction subclasses (not _StoreAction) and therefore excluded.
    """
    return isinstance(action, argparse._StoreAction)
|
||||
|
||||
|
||||
def format_option(option, argument_required):
    """Render one completion target.

    Flags that take a value complete as "--opt=" (cursor ready for the value);
    bare flags complete as "--opt " (word finished).

    :param option: flag to format
    :param argument_required: whether the flag consumes an argument
    """
    suffix = '=' if argument_required else ' '
    return option + suffix
|
||||
|
||||
|
||||
def get_options(parser):
    """Return every completion-formatted flag the parser accepts.

    :param parser: argparse.ArgumentParser instance
    :return: list of formatted option strings, in declaration order
    """
    options = []
    for action in parser._actions:
        required = is_argument_required(action)
        options.extend(format_option(option, required) for option in action.option_strings)
    return options
|
||||
|
||||
|
||||
def main():
    # Shell-completion entry point: argv[1] is the full current command line.
    # Walks the typed words down the sub-parser hierarchy, records flags that
    # were already used, and prints the remaining completion targets
    # (newline-separated) to stdout. Returns a process exit code.

    if len(sys.argv) != 2:
        return 1

    # skip the program name itself
    comp_words = iter(sys.argv[1].split()[1:])

    parser = get_parser()

    seen = []
    for word in comp_words:
        # NOTE(review): `parser.choices` / `parser[word]` presumably come from
        # the project's parser wrapper (not plain argparse) -- confirm.
        if word in parser.choices:
            # descend into the matching sub-command's parser
            parser = parser[word]
            continue
        actions = {name: action for action in parser._actions for name in action.option_strings}
        first, _, rest = word.partition('=')
        # "--flag=value" carries the flag and its argument in a single word
        is_one_word_store_action = rest and first in actions
        if is_one_word_store_action:
            word = first
        seen.append(word)
        try:
            action = actions[word]
        except KeyError:
            break
        if isinstance(action, argparse._StoreAction) and not isinstance(action, argparse._StoreConstAction):
            if not is_one_word_store_action:
                # a value-taking flag consumes the next word as its argument
                try:
                    next(comp_words)
                except StopIteration:
                    break

    # sub-command names complete with a trailing space (they take no '=')
    options = list(parser.choices)

    options = [format_option(option, argument_required=False) for option in options]
    options.extend(get_options(parser))
    # do not re-suggest flags already present on the command line
    options = [option for option in options if option.rstrip('= ') not in seen]

    print('\n'.join(options))
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
34
trains_agent/config.py
Normal file
34
trains_agent/config.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from pyhocon import ConfigTree
|
||||
|
||||
import six
|
||||
from trains_agent.helper.base import Singleton
|
||||
|
||||
|
||||
@six.add_metaclass(Singleton)
class Config(object):
    # Process-wide configuration singleton backed by a pyhocon ConfigTree.
    #
    # NOTE: attribute access is aliased to item access at the bottom of the
    # class, therefore __init__ must write through self.__dict__ directly --
    # a plain `self._tree = ...` would be routed into __setitem__ on the
    # not-yet-existing tree.

    def __init__(self, tree=None):
        # underlying ConfigTree holding all configuration values
        self.__dict__['_tree'] = tree or ConfigTree()

    def __getitem__(self, item):
        return self._tree[item]

    def __setitem__(self, key, value):
        return self._tree.__setitem__(key, value)

    def new(self, name):
        # fetch-or-create a named ConfigTree sub-section
        return self._tree.setdefault(name, ConfigTree())

    # expose config entries as attributes: config.foo == config['foo']
    __getattr__ = __getitem__
    __setattr__ = __setitem__
|
||||
|
||||
|
||||
def get_config(name=None):
    """Return the singleton Config, or one of its named sub-trees when *name* is given."""
    config = Config()
    return getattr(config, name) if name else config
|
||||
|
||||
|
||||
def make_config(name):
    """Create (or fetch) the sub-tree called *name* on the singleton configuration."""
    config = get_config()
    return config.new(name)
|
||||
131
trains_agent/definitions.py
Normal file
131
trains_agent/definitions.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from datetime import timedelta
|
||||
from distutils.util import strtobool
|
||||
from enum import IntEnum
|
||||
from os import getenv
|
||||
from typing import Text, Optional, Union, Tuple, Any
|
||||
|
||||
from furl import furl
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
from trains_agent.helper.base import normalize_path
|
||||
|
||||
# program identity, and the argparse prefix char for @file argument lists
PROGRAM_NAME = "trains-agent"
FROM_FILE_PREFIX_CHARS = "@"

# per-user state locations
CONFIG_DIR = normalize_path("~/.trains")
TOKEN_CACHE_FILE = normalize_path("~/.trains.trains_agent.tmp")

# candidate configuration files, searched in order by find_config_path()
CONFIG_FILE_CANDIDATES = ["~/trains.conf"]
|
||||
|
||||
|
||||
def find_config_path():
    # Return the first existing candidate configuration file, or fall back
    # to the default (first) candidate when none exists yet.
    for candidate in CONFIG_FILE_CANDIDATES:
        if Path(candidate).expanduser().exists():
            return candidate
    return CONFIG_FILE_CANDIDATES[0]


# resolved (expanded, normalized) configuration file path used by the agent
CONFIG_FILE = normalize_path(find_config_path())
|
||||
|
||||
|
||||
class EnvironmentConfig(object):
    """A configuration value resolvable from one of several environment variables.

    The first variable (in declaration order) holding a non-empty value wins;
    its raw string is converted according to the requested ``type``.
    """

    # per-type converters; a type without an entry is called directly
    conversions = {
        bool: lambda value: bool(strtobool(value)),
        six.text_type: lambda s: six.text_type(s).strip(),
    }

    def __init__(self, *names, **kwargs):
        # environment variable names to probe, in priority order
        self.vars = names
        # target type of the converted value (defaults to text)
        self.type = kwargs.pop("type", six.text_type)

    def convert(self, value):
        """Convert the raw environment string to this config's declared type."""
        converter = self.conversions.get(self.type, self.type)
        return converter(value)

    def get(self, key=False):  # type: (bool) -> Optional[Union[Any, Tuple[Text, Any]]]
        """Return the first non-empty value found, or None when none is set.

        When *key* is true, return a ``(variable_name, value)`` pair instead.
        """
        for name in self.vars:
            raw = getenv(name)
            if not raw:
                continue
            value = self.convert(raw)
            return (name, value) if key else value
        return None
|
||||
|
||||
|
||||
# mapping of configuration keys to the environment variables that override them
ENVIRONMENT_CONFIG = {
    "api.api_server": EnvironmentConfig("TRAINS_API_HOST", "ALG_API_HOST"),
    "api.credentials.access_key": EnvironmentConfig(
        "TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY"
    ),
    "api.credentials.secret_key": EnvironmentConfig(
        "TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY"
    ),
    "agent.worker_name": EnvironmentConfig("TRAINS_WORKER_NAME", "ALG_WORKER_NAME"),
    "agent.worker_id": EnvironmentConfig("TRAINS_WORKER_ID", "ALG_WORKER_ID"),
    "agent.cuda_version": EnvironmentConfig(
        "TRAINS_CUDA_VERSION", "ALG_CUDA_VERSION", "CUDA_VERSION"
    ),
    "agent.cudnn_version": EnvironmentConfig(
        "TRAINS_CUDNN_VERSION", "ALG_CUDNN_VERSION", "CUDNN_VERSION"
    ),
    "agent.cpu_only": EnvironmentConfig(
        "TRAINS_CPU_ONLY", "ALG_CPU_ONLY", "CPU_ONLY", type=bool
    ),
}

# environment variable naming an alternate configuration file
CONFIG_FILE_ENV = EnvironmentConfig("ALG_CONFIG_FILE")

# env-var names (not EnvironmentConfig objects) forwarded to executed tasks
ENVIRONMENT_SDK_PARAMS = {
    "task_id": ("TRAINS_TASK_ID", "ALG_TASK_ID"),
    # NOTE(review): "TRAINS_CONFIG_FILE" is listed twice here -- harmless
    # duplicate, likely a typo for another variable name; confirm
    "config_file": ("TRAINS_CONFIG_FILE", "ALG_CONFIG_FILE", "TRAINS_CONFIG_FILE"),
    "log_level": ("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL"),
    "log_to_backend": ("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND"),
}

# per-interpreter virtual-environment locations used for task execution
VIRTUAL_ENVIRONMENT_PATH = {
    "python2": normalize_path(CONFIG_DIR, "py2venv"),
    "python3": normalize_path(CONFIG_DIR, "py3venv"),
}

# cache/default locations and server connection defaults
DEFAULT_BASE_DIR = normalize_path(CONFIG_DIR, "data_cache")
DEFAULT_HOST = "https://demoai.trainsai.io"
MAX_DATASET_SOURCES_COUNT = 50000

# (HTTP status, server subcode) pairs for worker-registration errors
INVALID_WORKER_ID = (400, 1001)
WORKER_ALREADY_REGISTERED = (400, 1003)

API_VERSION = "v1.5"
TOKEN_EXPIRATION_SECONDS = int(timedelta(days=2).total_seconds())

# custom HTTP headers attached to API requests
HTTP_HEADERS = {
    "worker": "X-Trains-Worker",
    "act-as": "X-Trains-Act-As",
    "client": "X-Trains-Agent",
}
METADATA_EXTENSION = ".json"

DEFAULT_VENV_UPDATE_URL = (
    "https://raw.githubusercontent.com/Yelp/venv-update/v3.2.2/venv_update.py"
)
WORKING_REPOSITORY_DIR = "task_repository"
DEFAULT_VCS_CACHE = normalize_path(CONFIG_DIR, "vcs-cache")
# extra pip index urls appended to install commands (none by default)
PIP_EXTRA_INDICES = [
]
DEFAULT_PIP_DOWNLOAD_CACHE = normalize_path(CONFIG_DIR, "pip-download-cache")
|
||||
|
||||
|
||||
class FileBuffering(IntEnum):
    """
    File buffering options (presumably mirroring ``open``'s ``buffering``
    argument -- confirm at call sites):
    - UNSET: follows the defaults for the type of file,
      line-buffered for interactive (tty) text files and with a default chunk size otherwise
    - UNBUFFERED: no buffering at all
    - LINE_BUFFERING: per-line buffering, only valid for text files
    - values bigger than 1 indicate the size of the buffer in bytes and are not represented by the enum
    """

    UNSET = -1
    UNBUFFERED = 0
    LINE_BUFFERING = 1
|
||||
86
trains_agent/errors.py
Normal file
86
trains_agent/errors.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from typing import Union, Optional, Text
|
||||
|
||||
import requests
|
||||
import six
|
||||
from .backend_api import CallResult
|
||||
from .backend_api.session.client import APIError as ClientAPIError
|
||||
from .backend_api.session.response import ResponseMeta
|
||||
|
||||
INTERNAL_SERVER_ERROR = 500
|
||||
|
||||
|
||||
# TODO: hack: should NOT inherit from ValueError
class APIError(ClientAPIError, ValueError):
    """
    Class for representing an API error.

    self.data - ``dict`` of all returned JSON data
    self.code - HTTP response code
    self.subcode - server response subcode
    self.codes - (self.code, self.subcode) tuple
    self.message - result message sent from server
    """

    def __init__(self, response, extra_info=None):
        # type: (Union[requests.Response, CallResult], Optional[Text]) -> None
        """
        Create a new APIError from a server response.

        Accepts either a ready CallResult or a raw requests.Response; a raw
        response is first normalized into a CallResult (extracting the 'meta'
        section from the JSON body when present, otherwise synthesizing meta
        from the HTTP status and text).
        """
        if not isinstance(response, CallResult):
            try:
                data = response.json()
            except ValueError:
                # non-JSON body: proceed with an empty payload
                data = {}
            meta = data.get('meta')
            if meta:
                response_meta = ResponseMeta(is_valid=False, **meta)
            else:
                response_meta = ResponseMeta.from_raw_data(response.status_code, response.text)
            response = CallResult(
                meta=response_meta,
                response=response,
                response_data=data,
            )
        super(APIError, self).__init__(response, extra_info=extra_info)

    def format_traceback(self):
        # Only 500-class internal errors carry a server-side traceback worth
        # showing; anything else returns an empty string.
        if self.code != INTERNAL_SERVER_ERROR:
            return ''
        traceback = self.get_traceback()
        if traceback:
            return 'Server traceback:\n{}'.format(traceback)
        else:
            return 'Could not print server traceback'
|
||||
|
||||
|
||||
class CommandFailedError(Exception):
    """Base error for failed trains-agent commands.

    Keeps the human-readable message available as ``.message`` for callers
    that report failures.
    """

    def __init__(self, message=None, *args, **kwargs):
        super(CommandFailedError, self).__init__(message, *args, **kwargs)
        self.message = message
|
||||
|
||||
|
||||
class UsageError(CommandFailedError):
    """
    Used for usage errors that are checked post-argparsing
    """
    pass


class ConfigFileNotFound(CommandFailedError):
    # raised when no configuration file could be located
    pass


class Sigterm(BaseException):
    # derives from BaseException rather than Exception, presumably so a
    # SIGTERM escapes broad `except Exception` handlers -- confirm with callers
    pass
|
||||
|
||||
|
||||
@six.python_2_unicode_compatible
class MissingPackageError(CommandFailedError):
    """Raised when a pip package required for an operation is not installed."""

    def __init__(self, name):
        super(MissingPackageError, self).__init__(name)
        # package name, reused by __str__ for the install hint
        self.name = name

    def __str__(self):
        template = '{self.__class__.__name__}: ' \
                   '"{self.name}" package is required. Please run "pip install {self.name}"'
        return template.format(self=self)
|
||||
0
trains_agent/helper/__init__.py
Normal file
0
trains_agent/helper/__init__.py
Normal file
509
trains_agent/helper/base.py
Normal file
509
trains_agent/helper/base.py
Normal file
@@ -0,0 +1,509 @@
|
||||
""" TRAINS-AGENT Stdout Helper Functions """
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from abc import ABCMeta
|
||||
from collections import OrderedDict
|
||||
from distutils.spawn import find_executable
|
||||
from functools import total_ordering
|
||||
from typing import Text, Dict, Any, Optional, AnyStr, IO, Union
|
||||
|
||||
import attr
|
||||
import furl
|
||||
import pyhocon
|
||||
import yaml
|
||||
from attr import fields_dict
|
||||
from pathlib2 import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
import six
|
||||
from six.moves import reduce
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.dicts import filter_keys
|
||||
|
||||
# global formatting flag; no readers visible in this chunk -- confirm usage
pretty_lines = False

# module-level logger for the helper utilities
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def which(cmd, path=None):
    """Return the full path of executable *cmd*.

    :param cmd: executable name to look up
    :param path: optional os.pathsep-separated search path overriding $PATH
    :raises ValueError: if the command cannot be found
    """
    # shutil.which replaces distutils.spawn.find_executable: distutils is
    # deprecated and removed in Python 3.12, and shutil is already imported
    # by this module
    result = shutil.which(cmd, path=path)
    if not result:
        raise ValueError('command "{}" not found'.format(cmd))
    return result
|
||||
|
||||
|
||||
def select_for_platform(linux, windows):
    """
    Select between multiple values according to the OS
    :param linux: value to return if OS is linux
    :param windows: value to return if OS is Windows
    """
    if is_windows_platform():
        return windows
    return linux
|
||||
|
||||
|
||||
def bash_c():
    """Return the shell-invocation prefix appropriate for the current OS."""
    if is_windows_platform():
        return 'cmd /c'
    return 'bash -c'
|
||||
|
||||
|
||||
def return_list(arg):
    """Wrap a scalar in a list; tuples/lists and falsy values pass through unchanged."""
    if not arg or isinstance(arg, (tuple, list)):
        return arg
    return [arg]
|
||||
|
||||
|
||||
def print_table(entries, columns=(), titles=(), csv=None, headers=True):
    """Render *entries* as a table and print it (or write it to the *csv* path)."""
    rendered = create_table(entries, columns=columns, titles=titles, csv=csv, headers=headers)
    if not csv:
        print(rendered)
        return
    with open(csv, 'w') as output:
        print(rendered, file=output)
|
||||
|
||||
|
||||
def create_table(entries, columns=(), titles=(), csv=None, headers=True):
    # Render a list of dicts either as CSV text (when `csv` is truthy) or as
    # an ASCII-art table. `columns` are dot-separated paths into each entry
    # dict; `titles` positionally override the column headers.
    table = [
        [
            # walk the dot-path; a missing key yields {} which collapses to ''
            reduce(
                lambda obj, key: obj.get(key, {}),
                column.split('.'),
                entry
            ) or ''
            for column in columns
        ]
        for entry in entries
    ]
    if headers:
        # header falls back to the raw column path when no title is supplied
        headers = [titles[i] if i < len(titles) and titles[i] else c for i, c in enumerate(columns)]
    else:
        headers = []
    output = ''
    if csv:
        # NOTE(review): values are not quoted/escaped -- a comma inside a cell
        # would corrupt the CSV; confirm inputs before relying on this path
        if headers:
            output += ','.join(headers) + '\n'
        for entry in table:
            output += ','.join(map(str, entry)) + '\n'
    else:
        # column widths: at least 3, wide enough for header and every row
        # NOTE(review): width computation calls len() on raw cell values, so
        # non-string cells (e.g. ints) would raise TypeError -- confirm cells
        # are always strings
        min_col_width = 3
        col_widths = [max(min_col_width, len(h)+1) for h in (headers or table[0])]
        for e in table:
            col_widths = list(map(max, zip(col_widths, [len(h)+1 for h in e])))

        output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
        if headers:
            output += '| ' + '| '.join(['{: <%d}' % c for c in col_widths]).format(*headers) + ' |' + '\n'

            output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'

        for entry in table:
            line = map(str, entry)
            output += '| ' + '| '.join(['{: <%d}' % c for c in col_widths]).format(*line) + ' |' + '\n'

        output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
    return output
|
||||
|
||||
|
||||
def create_tree(entries, id='id', parent='parent', node_title='%(id)s'):
    """Build a nested OrderedDict tree from flat, parent-linked entries.

    :param entries: iterable of dicts, each with an id field and optionally a parent field
    :param id: key holding the node identifier
    :param parent: key holding the parent identifier (missing/falsy means root)
    :param node_title: %-style format template applied to the whole entry dict
        to produce the node label. Default fixed from '%(id)' -- an invalid
        format (no conversion character) that raised ValueError when used.
    :return: {'': tree} mapping node titles to nested OrderedDicts of children
    """
    tree = OrderedDict()
    all_nodes = dict()
    for t in entries:
        i = t.get(id, None)
        p = t.get(parent, None)
        # NOTE(review): `i not in tree` compares a raw id against title keys;
        # only exact when node_title renders the bare id -- confirm callers
        if not p and i not in tree:
            # push roots; reuse the node dict if a child registered it earlier
            myd = all_nodes.get(i, OrderedDict())
            # add node title
            tree[node_title % t] = myd
            all_nodes[i] = myd
        elif p:
            # child node: attach under the (possibly not-yet-seen) parent
            d = all_nodes.get(p, OrderedDict())
            # get node dictionary
            myd = all_nodes.get(i, OrderedDict())
            # add node title
            d[node_title % t] = myd
            all_nodes[p] = d
            all_nodes[i] = myd
        else:
            # entry with neither id nor parent -- nothing to place
            pass
    return {'': tree}
|
||||
|
||||
|
||||
def print_parameters(param_struct, indent=1):
    """Pretty-print a (nested) parameter structure as block-style YAML."""
    rendered = yaml.safe_dump(param_struct, allow_unicode=True, indent=indent, default_flow_style=False)
    print(rendered)
|
||||
|
||||
|
||||
def get_list_files(basefolder, filext=('.jpg',)):
    """Yield paths of files under *basefolder* whose extension is in *filext*.

    :param basefolder: root directory, walked recursively
    :param filext: iterable of extensions including the dot (case-insensitive).
        Default fixed from ('.jpg') -- a plain string, which was iterated
        per-character and therefore never matched any real extension.
    :return: generator of matching file paths
    """
    filext = [e.lower() for e in filext]
    fileiter = (os.path.join(root, f)
                for root, _, files in os.walk(basefolder)
                for f in files if os.path.splitext(f)[1].lower() in filext)
    return fileiter
|
||||
|
||||
|
||||
def is_windows_platform():
    """True when running on Windows (win32_ver() is all-empty elsewhere)."""
    return any(platform.win32_ver())
|
||||
|
||||
|
||||
def normalize_path(*paths):
    """
    normalize_path

    Joins ``*paths``, expands ``~`` and environment variables, and
    normalizes path separators.

    :param paths: path components to create path from
    """
    joined = os.path.join(*map(str, paths))
    expanded = os.path.expanduser(joined)
    expanded = os.path.expandvars(expanded)
    return os.path.normpath(expanded)
|
||||
|
||||
|
||||
def safe_remove_file(filename, error_message=None):
    """Delete *filename*, swallowing any failure (optionally printing *error_message*)."""
    try:
        os.remove(filename)
        return
    except Exception:
        pass
    if error_message:
        print(error_message)
|
||||
|
||||
|
||||
class Singleton(ABCMeta):
    """Metaclass caching one instance per class (ABC-compatible).

    The first construction's arguments win; later calls return the cached
    instance regardless of their arguments.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super(Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
|
||||
|
||||
|
||||
@total_ordering
|
||||
class CompareAnything(object):
|
||||
"""
|
||||
CompareAnything
|
||||
|
||||
Creates an object which is always the smallest when compared to other objects.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __eq__(_):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __lt__(_):
|
||||
return True
|
||||
|
||||
|
||||
def nonstrict_in_place_sort(lst, reverse, *keys):
    """
    Sort a list of dictionaries in-place, pseudo-lexicographically by ``keys``.
    A dict missing one of ``keys`` is treated as smaller than any present value
    (via CompareAnything).

    :param lst: list of dicts to sort in place
    :type lst: ``[dict]``
    :param reverse: whether to reverse sorting
    :type reverse: ``bool``
    :param keys: key names to sort by, in priority order
    :type keys: ``[str]``
    """
    def sort_key(item):
        return tuple(item.get(key, CompareAnything()) for key in keys)

    lst.sort(key=sort_key, reverse=reverse)
|
||||
|
||||
|
||||
def load_yaml(path):
    """Load a YAML file, returning {} for empty files.

    :raises ValueError: when the file cannot be parsed as YAML
    """
    if isinstance(path, Path):
        path = str(path)
    try:
        with open(path) as data_file:
            content = yaml.safe_load(data_file)
        return content or {}
    except yaml.YAMLError as e:
        raise ValueError('Failed parsing yaml file [{}]: {}'.format(path, e))
|
||||
|
||||
|
||||
def dump_yaml(obj, path=None, dump_all=False, **kwargs):
    """Serialize *obj* to YAML, returning the text or writing it to *path*.

    With ``dump_all`` the AllDumper is used, which stringifies arbitrary
    objects instead of failing. Extra ``kwargs`` override the dump options.
    """
    options = dict(indent=4, allow_unicode=True, default_flow_style=False)
    options.update(kwargs)
    if dump_all:
        options['Dumper'] = AllDumper
        dump_func = yaml.dump
    else:
        dump_func = yaml.safe_dump
    if not path:
        return dump_func(obj, **options)
    with open(str(path), 'w') as output:
        dump_func(obj, output, **options)
|
||||
|
||||
|
||||
def one_value(dct):
    """Return an arbitrary (first-iterated) value of *dct*.

    :raises StopIteration: when the dict is empty
    """
    # plain dict iteration replaces six.itervalues: iter(dct.values()) works
    # identically on both Python 2 and Python 3
    return next(iter(dct.values()))
|
||||
|
||||
|
||||
@attr.s
class RepoInfo(object):
    """VCS metadata for a working copy, as collected by get_repo_info()."""
    type = attr.ib(type=str)    # 'git' or 'hg'
    url = attr.ib(type=str)     # remote origin url
    branch = attr.ib(type=str)  # current branch name
    commit = attr.ib(type=str)  # current revision hash
    root = attr.ib(type=str)    # repository root directory
||||
|
||||
|
||||
def get_repo_info(repo_type, path):
    """Collect url/branch/commit/root for the git or hg working copy at *path*.

    :param repo_type: 'git' or 'hg'
    :param path: directory inside the working copy (used as cwd for VCS commands)
    :return: RepoInfo built from the stripped command outputs
    :raises RuntimeError: for an unknown repository type
    """
    assert repo_type in ['git', 'hg']
    if repo_type == 'git':
        commands = dict(
            url='git remote get-url origin',
            branch='git rev-parse --abbrev-ref HEAD',
            commit='git rev-parse HEAD',
            root='git rev-parse --show-toplevel'
        )
    elif repo_type == 'hg':
        commands = dict(
            url='hg paths --verbose',
            branch='hg --debug id -b',
            commit='hg --debug id -i',
            root='hg root'
        )
    else:
        raise RuntimeError("Unknown repository type '{}'".format(repo_type))

    commands_result = {}
    for name, command in commands.items():
        raw = subprocess.check_output(command.split(), cwd=path)
        commands_result[name] = raw.decode().strip()
    return RepoInfo(type=repo_type, **commands_result)
|
||||
|
||||
|
||||
def reverse_home_folder_expansion(path):
    """
    Replace a leading home-directory prefix in *path* with '~/'.
    Returns the path unchanged on Windows, where '~' is not conventional.
    """
    text = str(path)
    if is_windows_platform():
        return text
    home_prefix = '^{}/'.format(re.escape(str(Path.home())))
    return re.sub(home_prefix, '~/', text)
|
||||
|
||||
|
||||
def represent_ordered_dict(dumper, data):
    """
    YAML representer that serializes an ``OrderedDict`` preserving its order.
    Register it on ``yaml.SafeDumper`` so ``yaml.safe_dump`` accepts ``OrderedDict``s.
    """
    tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    return dumper.represent_mapping(tag, data.items())
|
||||
|
||||
|
||||
def construct_mapping(loader, node):
    """
    YAML constructor that deserializes mappings as ``OrderedDict``s,
    preserving key order from the document.
    """
    loader.flatten_mapping(node)
    pairs = loader.construct_pairs(node)
    return OrderedDict(pairs)
|
||||
|
||||
|
||||
# Make yaml.safe_dump emit OrderedDicts in insertion order, and yaml.safe_load
# return every mapping as an OrderedDict.
yaml.SafeDumper.add_representer(OrderedDict, represent_ordered_dict)
yaml.SafeLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
|
||||
|
||||
|
||||
class AllDumper(yaml.SafeDumper):
    # SafeDumper variant used by dump_yaml(dump_all=True); the multi-representer
    # registered below lets it serialize arbitrary objects.
    pass


# fall back to str() for any type the safe dumper does not know how to represent
AllDumper.add_multi_representer(object, lambda dumper, data: dumper.represent_str(str(data)))
|
||||
|
||||
|
||||
def error(message):
    """Print *message* to stdout as a trains_agent error, padded with blank lines."""
    text = '\ntrains_agent: ERROR: {}\n'.format(message)
    print(text)
|
||||
|
||||
|
||||
def warning(message):
    """Print *message* to stdout as a trains_agent warning."""
    text = 'trains_agent: Warning: {}'.format(message)
    print(text)
|
||||
|
||||
|
||||
class TqdmStream(object):
    # File-like adapter used by TqdmLog below: progress-bar control whitespace
    # is stripped on write, and flush() emits a newline, so tqdm output comes
    # out as discrete clean lines (suitable for log files).

    def __init__(self, file_object):
        # underlying stream; only its write() method is used
        self.buffer = file_object

    def write(self, data):
        # drop leading/trailing whitespace (tqdm's \r redraw characters)
        self.buffer.write(data.strip())

    def flush(self):
        # intentionally writes a newline rather than flushing the stream:
        # terminates the current progress line
        self.buffer.write('\n')
|
||||
|
||||
|
||||
class TqdmLog(tqdm):
    # tqdm progress bar that writes through TqdmStream, producing line-oriented
    # output instead of in-place terminal redraws.

    def __init__(self, iterable=None, file=None, **kwargs):
        # default to stderr, matching tqdm's own convention
        super(TqdmLog, self).__init__(iterable, file=TqdmStream(file or sys.stderr), **kwargs)
|
||||
|
||||
|
||||
def url_join(first, *rest):
    """
    Join url parts similarly to Path.join
    """
    joined = furl.furl(first).path.add(rest)
    return str(joined).lstrip('/')
|
||||
|
||||
|
||||
class LowercaseFormatter(logging.Formatter):
    """Logging formatter that lower-cases the record's level name before formatting."""

    def format(self, record, *args, **kwargs):
        # mutate the record in place so '%(levelname)s' renders lower-case
        record.levelname = record.levelname.lower()
        formatted = super(LowercaseFormatter, self).format(record, *args, **kwargs)
        return formatted
|
||||
|
||||
|
||||
def mkstemp(
    open_kwargs=None,  # type: Optional[Dict[Text, Any]]
    text=True,  # type: bool
    name_only=False,  # type: bool
    *args,
    **kwargs):
    # type: (...) -> Union[(IO[AnyStr], Text), Text]
    """
    Create a temporary file via ``tempfile.mkstemp`` and open it with ``io.open``.

    WARNING: the returned file object is strict about its input type,
    make sure to feed it binary/text input in correspondence to the ``text`` argument

    :param open_kwargs: keyword arguments for ``io.open``
    :param text: open in text mode
    :param name_only: close the file and return its name
    :param args: tempfile.mkstemp args
    :param kwargs: tempfile.mkstemp kwargs
    :return: ``(file object, path)`` tuple, or just the path when *name_only*
    """
    fd, name = tempfile.mkstemp(text=text, *args, **kwargs)
    if name_only:
        os.close(fd)
        return name
    mode = 'w+' if text else 'w+b'
    return io.open(fd, mode, **open_kwargs or {}), name
|
||||
|
||||
|
||||
def named_temporary_file(*args, **kwargs):
    """
    ``tempfile.NamedTemporaryFile`` wrapper that translates the python-3
    ``buffering`` keyword into python-2's ``bufsize``.
    """
    if six.PY2:
        buf = kwargs.pop('buffering', None)
        if buf:
            kwargs['bufsize'] = buf
    return tempfile.NamedTemporaryFile(*args, **kwargs)
|
||||
|
||||
|
||||
def parse_override(string):
    """Parse a HOCON override string into a plain ordered dict."""
    config = pyhocon.ConfigFactory.parse_string(string)
    return config.as_plain_ordered_dict()
|
||||
|
||||
|
||||
def chain_map(*args):
    """Merge the given mappings into one new dict; later mappings override earlier ones."""
    merged = {}
    for mapping in args:
        merged.update(mapping)
    return merged
|
||||
|
||||
|
||||
def check_directory_path(path):
    """
    Validate that *path* can be used as a directory and create it (with parents).

    On non-Windows platforms the path may not contain whitespace.

    :param path: directory path; environment variables and '~' are expanded
    :raises CommandFailedError: on an invalid path or failure to create the directory
    """
    message = 'Could not create directory "{}": {}'
    if not is_windows_platform():
        match = re.search(r'\s', path)
        if match:
            # report the offending character and its actual position.
            # BUGFIX: the previous code used match.endpos, which is always
            # the end of the searched string, not the whitespace position
            raise CommandFailedError(
                'directories may not contain whitespace (char: {!r}, position: {})'.format(match.group(0),
                                                                                           match.start()))
    try:
        Path(os.path.expandvars(path)).expanduser().mkdir(parents=True, exist_ok=True)
    except OSError as e:
        raise CommandFailedError(message.format(path, e.strerror))
    except Exception as e:
        raise CommandFailedError(message.format(path, e))
|
||||
|
||||
|
||||
def create_file_if_not_exists(path):
    """
    Create an empty file at *path* (after '~' and env-var expansion) if missing.

    :param path: file path; '~' and environment variables are expanded
    """
    # BUGFIX: expand once and use the expanded path for both the existence
    # check and the creation; previously only the check was expanded, so a
    # path containing '~' or '$VAR' was created literally
    expanded = os.path.expanduser(os.path.expandvars(path))
    if not os.path.exists(expanded):
        open(expanded, "w").close()
|
||||
|
||||
|
||||
def rm_tree(root):  # type: (Union[Path, Text]) -> None
    """
    A version of shutil.rmtree that handles access errors, specifically hidden files on Windows
    """
    def handle_error(func, target, _):
        # best effort: make the entry writable, then retry the failed operation
        try:
            if os.path.exists(target) and not os.access(target, os.W_OK):
                os.chmod(target, stat.S_IWUSR)
            func(target)
        except Exception:
            pass

    expanded = os.path.expanduser(os.path.expandvars(Text(root)))
    return shutil.rmtree(expanded, onerror=handle_error)
|
||||
|
||||
|
||||
def is_conda(config):
    """Return True when the configured package manager type is conda (case-insensitive)."""
    manager_type = config['agent.package_manager.type']
    return manager_type.lower() == 'conda'
|
||||
|
||||
|
||||
class NonStrictAttrs(object):
    # Mixin for attrs-decorated classes: construct an instance from a dict
    # while silently dropping keys that are not declared attributes.

    @classmethod
    def from_dict(cls, kwargs):
        # fields_dict() returns the attrs-declared attributes of cls;
        # filter_keys keeps only the matching entries of kwargs
        fields = fields_dict(cls)
        return cls(**filter_keys(lambda key: key in fields, kwargs))
|
||||
|
||||
|
||||
def python_version_string():
    """Return the running interpreter's version as '<major>.<minor>'."""
    info = sys.version_info
    return '{}.{}'.format(info.major, info.minor)
|
||||
|
||||
|
||||
join_lines = '\n'.join
|
||||
|
||||
|
||||
class HOCONEncoder(json.JSONEncoder):
    """
    pyhocon bugs:
    1. "\\t" is dumped as "\t" instead of "\\t", which is read as the character "\t".
    2. parsed config trees have dummy `pyhocon.config_tree.NoneValue` in them.
    (see: https://github.com/chimpler/pyhocon/issues/111)
    Workaround: dump HOCON to JSON, of which it is a subset, taking care of `NoneValue`s.
    """

    def default(self, o):
        """
        If o is `pyhocon.config_tree.NoneValue`, encode it the same way as `None`.
        """
        if isinstance(o, pyhocon.config_tree.NoneValue):
            # BUGFIX: return None so the encoder emits a JSON null literal;
            # the previous `super().encode(None)` returned the *string* 'null',
            # which would itself be serialized as the quoted string "null"
            return None
        return super(HOCONEncoder, self).default(o)
|
||||
|
||||
|
||||
# attrs field: string defaulting to "", whitespace-stripped by its converter
nullable_string = attr.ib(default="", converter=lambda x: x.strip())
# attrs field: path string normalized via normalize_path (left as-is when falsy)
normal_path = attr.ib(default="", converter=lambda p: p and normalize_path(p))
|
||||
|
||||
|
||||
@attr.s
class ExecutionInfo(NonStrictAttrs):
    # Task execution details extracted from a task's `script` section;
    # every field is a stripped string ("" when missing).
    repository = nullable_string
    entry_point = normal_path
    working_dir = normal_path
    branch = nullable_string
    version_num = nullable_string
    tag = nullable_string

    @classmethod
    def from_task(cls, task_info):
        # type: (...) -> ExecutionInfo
        """
        extract ExecutionInfo tuple from task parameters

        :param task_info: task object carrying a `script` attribute
        :raises CommandFailedError: when the task has no script information
        """
        if not task_info.script:
            raise CommandFailedError("can not run task without script information")
        execution = cls.from_dict(task_info.script.to_dict())
        if not execution.entry_point:
            log.warning("notice: `script.entry_point` is empty")
        if not execution.working_dir:
            # no explicit working_dir: split an "entry_point:working_dir" value
            # (NOTE(review): presumed legacy notation - confirm with callers)
            entry_point, _, working_dir = execution.entry_point.partition(":")
            execution.entry_point = entry_point
            execution.working_dir = working_dir or ""

        return execution
|
||||
60
trains_agent/helper/check_update.py
Normal file
60
trains_agent/helper/check_update.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
import requests
|
||||
import json
|
||||
from threading import Thread
|
||||
from semantic_version import Version
|
||||
from ..version import __version__
|
||||
|
||||
__check_update_thread = None
|
||||
|
||||
|
||||
def start_check_update_daemon():
    """Spawn the background update-check thread once; subsequent calls are no-ops."""
    global __check_update_thread
    if __check_update_thread:
        return
    worker = Thread(target=_check_update_daemon)
    worker.daemon = True
    __check_update_thread = worker
    worker.start()
|
||||
|
||||
|
||||
def _check_new_version_available():
    """
    Query the updates server for a newer trains-agent release.

    :return: None when up to date (or on server error/partial answer), otherwise
        a tuple of (latest version string, is_patch_upgrade, release-notes lines)
    """
    cur_version = __version__
    update_server_releases = requests.get('https://updates.trainsai.io/updates',
                                          data=json.dumps({"versions": {"trains-agent": str(cur_version)}}),
                                          timeout=3.0)
    if update_server_releases.ok:
        update_server_releases = update_server_releases.json()
    else:
        return None
    trains_answer = update_server_releases.get("trains-agent", {})
    latest_version = trains_answer.get("version")
    # ROBUSTNESS: a partial server answer previously raised TypeError here
    if not latest_version:
        return None
    cur_version = Version(cur_version)
    latest_version = Version(latest_version)
    if cur_version >= latest_version:
        return None
    patch_upgrade = latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
    # description may be absent; don't crash with AttributeError on None
    description = trains_answer.get("description") or ""
    return str(latest_version), patch_upgrade, description.split("\r\n")
|
||||
|
||||
|
||||
def _check_update_daemon():
    """
    Background loop: once a day, check for a newer trains-agent release and
    print an upgrade recommendation. Never raises (runs in a daemon thread).
    """
    while True:
        # noinspection PyBroadException
        try:
            latest_version = _check_new_version_available()
            # only print when we begin
            if latest_version:
                if latest_version[1]:
                    # patch release: include the release notes
                    sep = os.linesep
                    print('TRAINS-AGENT new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format(
                        latest_version[0], sep.join(latest_version[2])))
                else:
                    # BUGFIX: this daemon reports trains-agent releases;
                    # the message previously said "TRAINS-SERVER"
                    print('TRAINS-AGENT new version available: upgrade to v{} is recommended!'.format(
                        latest_version[0]))
        except Exception:
            # never let the update check kill the daemon thread
            pass
        # sleep until the next day
        sleep(60 * 60 * 24)
|
||||
100
trains_agent/helper/console.py
Normal file
100
trains_agent/helper/console.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
import csv
|
||||
import sys
|
||||
from collections import Iterable
|
||||
from typing import List, Dict, Text, Any
|
||||
|
||||
from attr import attrs, attrib
|
||||
|
||||
import six
|
||||
from six import binary_type, text_type
|
||||
from trains_agent.helper.base import nonstrict_in_place_sort, create_tree
|
||||
|
||||
|
||||
def print_text(text, newline=True):
    # Write *text* to stdout, encoding to the console's encoding and replacing
    # unencodable characters, so unicode output never raises.
    if newline:
        text += '\n'
    data = text.encode(sys.stdout.encoding or 'utf8', errors='replace')
    try:
        # python 3: write bytes directly to the underlying binary buffer
        sys.stdout.buffer.write(data)
    except AttributeError:
        # python 2 (or a stdout without .buffer): write the encoded str
        sys.stdout.write(data)
|
||||
|
||||
|
||||
def ensure_text(s, encoding='utf-8', errors='strict'):
    """Coerce *s* to six.text_type.
    For Python 2:
      - `unicode` -> `unicode`
      - `str` -> `unicode`
    For Python 3:
      - `str` -> `str`
      - `bytes` -> decoded to `str`
    """
    if isinstance(s, text_type):
        return s
    if isinstance(s, binary_type):
        return s.decode(encoding, errors)
    raise TypeError("not expecting type '%s'" % type(s))
|
||||
|
||||
|
||||
def ensure_binary(s, encoding='utf-8', errors='strict'):
    """Coerce **s** to six.binary_type.
    For Python 2:
      - `unicode` -> encoded to `str`
      - `str` -> `str`
    For Python 3:
      - `str` -> encoded to `bytes`
      - `bytes` -> `bytes`
    """
    if isinstance(s, binary_type):
        return s
    if isinstance(s, text_type):
        return s.encode(encoding, errors)
    raise TypeError("not expecting type '%s'" % type(s))
|
||||
|
||||
|
||||
class ListFormatter(object):
    # Helpers for presenting lists of result entries (dicts) as tables / CSV.

    @attrs(init=False)
    class Table(object):
        # entries: list of row dicts; columns: ordered column names
        entries = attrib(type=List[Dict])
        columns = attrib(type=List[Text])

        def __init__(self, entries, columns):
            self.entries = entries
            # columns may be given as a single '#'-separated string
            if isinstance(columns, str):
                columns = columns.split('#')
            self.columns = columns

        def as_rows(self):  # type: () -> Iterable[Iterable[Any]]
            # lazily yield each entry's values in column order (missing -> None)
            return (
                map(entry.get, self.columns)
                for entry in self.entries
            )

    def __init__(self, service_name):
        # service_name is only used for the summary line in get_total()
        self.service_name = service_name

    def get_total(self, entries):
        # summary line, e.g. "Total tasks 7"
        return '\nTotal {} {}'.format(self.service_name, len(entries))

    @classmethod
    def write_csv(cls, entries, columns, dest, headers=True):
        # write entries to a CSV file at dest; entry keys outside the columns
        # are ignored (extrasaction='ignore')
        table = cls.Table(entries, columns)
        with open(dest, 'w') as output:
            writer = csv.DictWriter(output, fieldnames=table.columns, extrasaction='ignore')
            if headers:
                writer.writeheader()
            writer.writerows(table.entries)

    @staticmethod
    def sort_in_place(entries, key, reverse=None):
        # key may be a '#'-separated field-name string or a key function
        if isinstance(key, six.string_types):
            nonstrict_in_place_sort(entries, reverse, *key.split('#'))
        elif callable(key):
            entries.sort(key=key, reverse=reverse)
        else:
            raise ValueError('"sort" argument must be either a string or a callable object')
|
||||
5
trains_agent/helper/dicts.py
Normal file
5
trains_agent/helper/dicts.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from typing import Callable, Dict, Any
|
||||
|
||||
|
||||
def filter_keys(filter_, dct): # type: (Callable[[Any], bool], Dict) -> Dict
|
||||
return {key: value for key, value in dct.items() if filter_(key)}
|
||||
0
trains_agent/helper/gpu/__init__.py
Normal file
0
trains_agent/helper/gpu/__init__.py
Normal file
385
trains_agent/helper/gpu/gpustat.py
Normal file
385
trains_agent/helper/gpu/gpustat.py
Normal file
@@ -0,0 +1,385 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Implementation of gpustat
|
||||
@author Jongwook Choi
|
||||
@url https://github.com/wookayin/gpustat
|
||||
|
||||
@ copied from gpu-stat 0.6.0
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
import os.path
|
||||
import platform
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import psutil
|
||||
from ..gpu import pynvml as N
|
||||
|
||||
NOT_SUPPORTED = 'Not Supported'
|
||||
MB = 1024 * 1024
|
||||
|
||||
|
||||
class GPUStat(object):
|
||||
|
||||
def __init__(self, entry):
|
||||
if not isinstance(entry, dict):
|
||||
raise TypeError(
|
||||
'entry should be a dict, {} given'.format(type(entry))
|
||||
)
|
||||
self.entry = entry
|
||||
|
||||
def keys(self):
|
||||
return self.entry.keys()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.entry[key]
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
"""
|
||||
Returns the index of GPU (as in nvidia-smi).
|
||||
"""
|
||||
return self.entry['index']
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
"""
|
||||
Returns the uuid returned by nvidia-smi,
|
||||
e.g. GPU-12345678-abcd-abcd-uuid-123456abcdef
|
||||
"""
|
||||
return self.entry['uuid']
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""
|
||||
Returns the name of GPU card (e.g. Geforce Titan X)
|
||||
"""
|
||||
return self.entry['name']
|
||||
|
||||
@property
|
||||
def memory_total(self):
|
||||
"""
|
||||
Returns the total memory (in MB) as an integer.
|
||||
"""
|
||||
return int(self.entry['memory.total'])
|
||||
|
||||
@property
|
||||
def memory_used(self):
|
||||
"""
|
||||
Returns the occupied memory (in MB) as an integer.
|
||||
"""
|
||||
return int(self.entry['memory.used'])
|
||||
|
||||
@property
|
||||
def memory_free(self):
|
||||
"""
|
||||
Returns the free (available) memory (in MB) as an integer.
|
||||
"""
|
||||
v = self.memory_total - self.memory_used
|
||||
return max(v, 0)
|
||||
|
||||
@property
|
||||
def memory_available(self):
|
||||
"""
|
||||
Returns the available memory (in MB) as an integer.
|
||||
Alias of memory_free.
|
||||
"""
|
||||
return self.memory_free
|
||||
|
||||
@property
|
||||
def temperature(self):
|
||||
"""
|
||||
Returns the temperature (in celcius) of GPU as an integer,
|
||||
or None if the information is not available.
|
||||
"""
|
||||
v = self.entry['temperature.gpu']
|
||||
return int(v) if v is not None else None
|
||||
|
||||
@property
|
||||
def fan_speed(self):
|
||||
"""
|
||||
Returns the fan speed percentage (0-100) of maximum intended speed
|
||||
as an integer, or None if the information is not available.
|
||||
"""
|
||||
v = self.entry['fan.speed']
|
||||
return int(v) if v is not None else None
|
||||
|
||||
@property
|
||||
def utilization(self):
|
||||
"""
|
||||
Returns the GPU utilization (in percentile),
|
||||
or None if the information is not available.
|
||||
"""
|
||||
v = self.entry['utilization.gpu']
|
||||
return int(v) if v is not None else None
|
||||
|
||||
@property
|
||||
def power_draw(self):
|
||||
"""
|
||||
Returns the GPU power usage in Watts,
|
||||
or None if the information is not available.
|
||||
"""
|
||||
v = self.entry['power.draw']
|
||||
return int(v) if v is not None else None
|
||||
|
||||
@property
|
||||
def power_limit(self):
|
||||
"""
|
||||
Returns the (enforced) GPU power limit in Watts,
|
||||
or None if the information is not available.
|
||||
"""
|
||||
v = self.entry['enforced.power.limit']
|
||||
return int(v) if v is not None else None
|
||||
|
||||
@property
|
||||
def processes(self):
|
||||
"""
|
||||
Get the list of running processes on the GPU.
|
||||
"""
|
||||
return self.entry['processes']
|
||||
|
||||
def jsonify(self):
|
||||
o = dict(self.entry)
|
||||
if self.entry['processes'] is not None:
|
||||
o['processes'] = [{k: v for (k, v) in p.items() if k != 'gpu_uuid'}
|
||||
for p in self.entry['processes']]
|
||||
else:
|
||||
o['processes'] = '({})'.format(NOT_SUPPORTED)
|
||||
return o
|
||||
|
||||
|
||||
class GPUStatCollection(object):
|
||||
global_processes = {}
|
||||
_initialized = False
|
||||
_device_count = None
|
||||
_gpu_device_info = {}
|
||||
|
||||
def __init__(self, gpu_list, driver_version=None):
|
||||
self.gpus = gpu_list
|
||||
|
||||
# attach additional system information
|
||||
self.hostname = platform.node()
|
||||
self.query_time = datetime.now()
|
||||
self.driver_version = driver_version
|
||||
|
||||
@staticmethod
|
||||
def clean_processes():
|
||||
for pid in list(GPUStatCollection.global_processes.keys()):
|
||||
if not psutil.pid_exists(pid):
|
||||
del GPUStatCollection.global_processes[pid]
|
||||
|
||||
@staticmethod
|
||||
def new_query(shutdown=False, per_process_stats=False, get_driver_info=False):
|
||||
"""Query the information of all the GPUs on local machine"""
|
||||
|
||||
if not GPUStatCollection._initialized:
|
||||
N.nvmlInit()
|
||||
GPUStatCollection._initialized = True
|
||||
|
||||
def _decode(b):
|
||||
if isinstance(b, bytes):
|
||||
return b.decode() # for python3, to unicode
|
||||
return b
|
||||
|
||||
def get_gpu_info(index, handle):
|
||||
"""Get one GPU information specified by nvml handle"""
|
||||
|
||||
def get_process_info(nv_process):
|
||||
"""Get the process information of specific pid"""
|
||||
process = {}
|
||||
if nv_process.pid not in GPUStatCollection.global_processes:
|
||||
GPUStatCollection.global_processes[nv_process.pid] = \
|
||||
psutil.Process(pid=nv_process.pid)
|
||||
ps_process = GPUStatCollection.global_processes[nv_process.pid]
|
||||
process['username'] = ps_process.username()
|
||||
# cmdline returns full path;
|
||||
# as in `ps -o comm`, get short cmdnames.
|
||||
_cmdline = ps_process.cmdline()
|
||||
if not _cmdline:
|
||||
# sometimes, zombie or unknown (e.g. [kworker/8:2H])
|
||||
process['command'] = '?'
|
||||
process['full_command'] = ['?']
|
||||
else:
|
||||
process['command'] = os.path.basename(_cmdline[0])
|
||||
process['full_command'] = _cmdline
|
||||
# Bytes to MBytes
|
||||
process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
|
||||
process['cpu_percent'] = ps_process.cpu_percent()
|
||||
process['cpu_memory_usage'] = \
|
||||
round((ps_process.memory_percent() / 100.0) *
|
||||
psutil.virtual_memory().total)
|
||||
process['pid'] = nv_process.pid
|
||||
return process
|
||||
|
||||
if not GPUStatCollection._gpu_device_info.get(index):
|
||||
name = _decode(N.nvmlDeviceGetName(handle))
|
||||
uuid = _decode(N.nvmlDeviceGetUUID(handle))
|
||||
GPUStatCollection._gpu_device_info[index] = (name, uuid)
|
||||
|
||||
name, uuid = GPUStatCollection._gpu_device_info[index]
|
||||
|
||||
try:
|
||||
temperature = N.nvmlDeviceGetTemperature(
|
||||
handle, N.NVML_TEMPERATURE_GPU
|
||||
)
|
||||
except N.NVMLError:
|
||||
temperature = None # Not supported
|
||||
|
||||
try:
|
||||
fan_speed = N.nvmlDeviceGetFanSpeed(handle)
|
||||
except N.NVMLError:
|
||||
fan_speed = None # Not supported
|
||||
|
||||
try:
|
||||
memory = N.nvmlDeviceGetMemoryInfo(handle) # in Bytes
|
||||
except N.NVMLError:
|
||||
memory = None # Not supported
|
||||
|
||||
try:
|
||||
utilization = N.nvmlDeviceGetUtilizationRates(handle)
|
||||
except N.NVMLError:
|
||||
utilization = None # Not supported
|
||||
|
||||
try:
|
||||
power = N.nvmlDeviceGetPowerUsage(handle)
|
||||
except N.NVMLError:
|
||||
power = None
|
||||
|
||||
try:
|
||||
power_limit = N.nvmlDeviceGetEnforcedPowerLimit(handle)
|
||||
except N.NVMLError:
|
||||
power_limit = None
|
||||
|
||||
try:
|
||||
nv_comp_processes = \
|
||||
N.nvmlDeviceGetComputeRunningProcesses(handle)
|
||||
except N.NVMLError:
|
||||
nv_comp_processes = None # Not supported
|
||||
try:
|
||||
nv_graphics_processes = \
|
||||
N.nvmlDeviceGetGraphicsRunningProcesses(handle)
|
||||
except N.NVMLError:
|
||||
nv_graphics_processes = None # Not supported
|
||||
|
||||
if not per_process_stats or (nv_comp_processes is None and nv_graphics_processes is None):
|
||||
processes = None
|
||||
else:
|
||||
processes = []
|
||||
nv_comp_processes = nv_comp_processes or []
|
||||
nv_graphics_processes = nv_graphics_processes or []
|
||||
for nv_process in nv_comp_processes + nv_graphics_processes:
|
||||
try:
|
||||
process = get_process_info(nv_process)
|
||||
processes.append(process)
|
||||
except psutil.NoSuchProcess:
|
||||
# TODO: add some reminder for NVML broken context
|
||||
# e.g. nvidia-smi reset or reboot the system
|
||||
pass
|
||||
|
||||
# TODO: Do not block if full process info is not requested
|
||||
time.sleep(0.1)
|
||||
for process in processes:
|
||||
pid = process['pid']
|
||||
cache_process = GPUStatCollection.global_processes[pid]
|
||||
process['cpu_percent'] = cache_process.cpu_percent()
|
||||
|
||||
index = N.nvmlDeviceGetIndex(handle)
|
||||
gpu_info = {
|
||||
'index': index,
|
||||
'uuid': uuid,
|
||||
'name': name,
|
||||
'temperature.gpu': temperature,
|
||||
'fan.speed': fan_speed,
|
||||
'utilization.gpu': utilization.gpu if utilization else None,
|
||||
'power.draw': power // 1000 if power is not None else None,
|
||||
'enforced.power.limit': power_limit // 1000
|
||||
if power_limit is not None else None,
|
||||
# Convert bytes into MBytes
|
||||
'memory.used': memory.used // MB if memory else None,
|
||||
'memory.total': memory.total // MB if memory else None,
|
||||
'processes': processes,
|
||||
}
|
||||
if per_process_stats:
|
||||
GPUStatCollection.clean_processes()
|
||||
return gpu_info
|
||||
|
||||
# 1. get the list of gpu and status
|
||||
gpu_list = []
|
||||
if GPUStatCollection._device_count is None:
|
||||
GPUStatCollection._device_count = N.nvmlDeviceGetCount()
|
||||
|
||||
for index in range(GPUStatCollection._device_count):
|
||||
handle = N.nvmlDeviceGetHandleByIndex(index)
|
||||
gpu_info = get_gpu_info(index, handle)
|
||||
gpu_stat = GPUStat(gpu_info)
|
||||
gpu_list.append(gpu_stat)
|
||||
|
||||
# 2. additional info (driver version, etc).
|
||||
if get_driver_info:
|
||||
try:
|
||||
driver_version = _decode(N.nvmlSystemGetDriverVersion())
|
||||
except N.NVMLError:
|
||||
driver_version = None # N/A
|
||||
else:
|
||||
driver_version = None
|
||||
|
||||
# no need to shutdown:
|
||||
if shutdown:
|
||||
N.nvmlShutdown()
|
||||
GPUStatCollection._initialized = False
|
||||
|
||||
return GPUStatCollection(gpu_list, driver_version=driver_version)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.gpus)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.gpus)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.gpus[index]
|
||||
|
||||
def __repr__(self):
|
||||
s = 'GPUStatCollection(host=%s, [\n' % self.hostname
|
||||
s += '\n'.join(' ' + str(g) for g in self.gpus)
|
||||
s += '\n])'
|
||||
return s
|
||||
|
||||
# --- Printing Functions ---
|
||||
def jsonify(self):
|
||||
return {
|
||||
'hostname': self.hostname,
|
||||
'query_time': self.query_time,
|
||||
"gpus": [g.jsonify() for g in self]
|
||||
}
|
||||
|
||||
def print_json(self, fp=sys.stdout):
|
||||
def date_handler(obj):
|
||||
if hasattr(obj, 'isoformat'):
|
||||
return obj.isoformat()
|
||||
else:
|
||||
raise TypeError(type(obj))
|
||||
|
||||
o = self.jsonify()
|
||||
json.dump(o, fp, indent=4, separators=(',', ': '),
|
||||
default=date_handler)
|
||||
fp.write('\n')
|
||||
fp.flush()
|
||||
|
||||
|
||||
def new_query(shutdown=False, per_process_stats=False, get_driver_info=False):
|
||||
'''
|
||||
Obtain a new GPUStatCollection instance by querying nvidia-smi
|
||||
to get the list of GPUs and running process information.
|
||||
'''
|
||||
return GPUStatCollection.new_query(shutdown=shutdown, per_process_stats=per_process_stats,
|
||||
get_driver_info=get_driver_info)
|
||||
1877
trains_agent/helper/gpu/pynvml.py
Normal file
1877
trains_agent/helper/gpu/pynvml.py
Normal file
File diff suppressed because it is too large
Load Diff
0
trains_agent/helper/package/__init__.py
Normal file
0
trains_agent/helper/package/__init__.py
Normal file
107
trains_agent/helper/package/base.py
Normal file
107
trains_agent/helper/package/base.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import abc
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Iterable, Union
|
||||
|
||||
import six
|
||||
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines
|
||||
from trains_agent.helper.process import Executable, Argv, PathLike
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class PackageManager(object):
|
||||
"""
|
||||
ABC for classes providing python package management interface
|
||||
"""
|
||||
|
||||
_selected_manager = None
|
||||
|
||||
@abc.abstractproperty
|
||||
def bin(self):
|
||||
# type: () -> PathLike
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def create(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def remove(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def install_from_file(self, path):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def freeze(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_requirements(self, requirements):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def install_packages(self, *packages):
|
||||
# type: (Iterable[Text]) -> None
|
||||
"""
|
||||
Install packages, upgrading depends on config
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _install(self, *packages):
|
||||
# type: (Iterable[Text]) -> None
|
||||
"""
|
||||
Run install command
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def uninstall_packages(self, *packages):
|
||||
# type: (Iterable[Text]) -> None
|
||||
pass
|
||||
|
||||
def upgrade_pip(self):
|
||||
return self._install("pip", "--upgrade")
|
||||
|
||||
def get_python_command(self, extra=()):
|
||||
# type: (...) -> Executable
|
||||
return Argv(self.bin, *extra)
|
||||
|
||||
@contextmanager
|
||||
def temp_file(self, prefix, contents, suffix=".txt"):
|
||||
# type: (Union[Text, Iterable[Text]], Iterable[Text], Text) -> Text
|
||||
"""
|
||||
Write contents to a temporary file, yielding its path. Finally, delete it.
|
||||
:param prefix: file name prefix
|
||||
:param contents: text lines to write
|
||||
:param suffix: file name suffix
|
||||
"""
|
||||
f, temp_path = mkstemp(suffix=suffix, prefix=prefix)
|
||||
with f:
|
||||
f.write(
|
||||
contents
|
||||
if isinstance(contents, six.text_type)
|
||||
else join_lines(contents)
|
||||
)
|
||||
try:
|
||||
yield temp_path
|
||||
finally:
|
||||
if not self.session.debug_mode:
|
||||
safe_remove_file(temp_path)
|
||||
|
||||
def set_selected_package_manager(self):
|
||||
# set this instance as the selected package manager
|
||||
# this is helpful when we want out of context requirement installations
|
||||
PackageManager._selected_manager = self
|
||||
|
||||
@classmethod
|
||||
def out_of_scope_install_package(cls, package_name):
|
||||
if PackageManager._selected_manager is not None:
|
||||
try:
|
||||
return PackageManager._selected_manager._install(package_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
365
trains_agent/helper/package/conda_api.py
Normal file
365
trains_agent/helper/package/conda_api.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
from distutils.spawn import find_executable
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any
|
||||
|
||||
import yaml
|
||||
from attr import attrs, attrib, Factory
|
||||
from pathlib2 import Path
|
||||
from semantic_version import Version
|
||||
from requirements import parse
|
||||
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform
|
||||
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
|
||||
from trains_agent.session import Session
|
||||
from .base import PackageManager
|
||||
from .pip_api.venv import VirtualenvPip
|
||||
from .requirements import RequirementsManager, MarkerRequirement
|
||||
|
||||
package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")
|
||||
|
||||
|
||||
def package_set(packages):
|
||||
return set(map(package_normalize, packages))
|
||||
|
||||
|
||||
def _package_diff(path, packages):
|
||||
# type: (Union[Path, Text], Iterable[Text]) -> Set[Text]
|
||||
return package_set(Path(path).read_text().splitlines()) - package_set(packages)
|
||||
|
||||
|
||||
class CondaPip(VirtualenvPip):
|
||||
def __init__(self, source=None, *args, **kwargs):
|
||||
super(CondaPip, self).__init__(*args, **kwargs)
|
||||
self.source = source
|
||||
|
||||
def run_with_env(self, command, output=False, **kwargs):
|
||||
if not self.source:
|
||||
return super(CondaPip, self).run_with_env(command, output=output, **kwargs)
|
||||
command = CommandSequence(self.source, Argv("pip", *command))
|
||||
return (command.get_output if output else command.check_call)(
|
||||
stdin=DEVNULL, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class CondaAPI(PackageManager):

    """
    A programmatic interface for controlling conda
    """

    # oldest conda release this wrapper is known to work with
    MINIMUM_VERSION = Version("4.3.30", partial=True)

    def __init__(self, session, path, python, requirements_manager):
        # type: (Session, PathLike, float, RequirementsManager) -> None
        """
        Create a conda environment manager rooted at ``path``.

        :param session: agent session (configuration and logging)
        :param path: path of env
        :param python: base python version to use (e.g python3.6)
        :param requirements_manager: requirements substitution manager
        :raises CommandFailedError: if conda cannot be found, its version cannot
            be determined, or it is older than MINIMUM_VERSION
        """
        self.session = session
        self.python = python
        self.source = None
        self.requirements_manager = requirements_manager
        self.path = path
        self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
        # NOTE: self.source is still None here -- create() fills it in for both
        # this object and the CondaPip helper below
        self.pip = CondaPip(
            session=self.session,
            source=self.source,
            python=self.python,
            requirements_manager=self.requirements_manager,
            path=self.path,
        )
        # locate the conda executable, falling back to the platform's which/where
        self.conda = (
            find_executable("conda")
            or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
        )
        try:
            output = Argv(self.conda, "--version").get_output()
        except subprocess.CalledProcessError as ex:
            raise CommandFailedError(
                "Unable to determine conda version: {ex}, output={ex.output}".format(
                    ex=ex
                )
            )
        self.conda_version = self.get_conda_version(output)
        if Version(self.conda_version, partial=True) < self.MINIMUM_VERSION:
            raise CommandFailedError(
                "conda version '{}' is smaller than minimum supported conda version '{}'".format(
                    self.conda_version, self.MINIMUM_VERSION
                )
            )

    @staticmethod
    def get_conda_version(output):
        """
        Extract a dotted version number from ``conda --version`` output.

        :raises CommandFailedError: if no version-like string is found
        """
        match = re.search(r"(\d+\.){0,2}\d+", output)
        if not match:
            raise CommandFailedError("Unidentified conda version string:", output)
        return match.group(0)

    @property
    def bin(self):
        # python interpreter inside the environment (delegated to the pip helper)
        return self.pip.bin

    def upgrade_pip(self):
        """Upgrade pip inside the conda environment."""
        return self.pip.upgrade_pip()

    def create(self):
        """
        Create a new environment
        """
        output = Argv(
            self.conda,
            "create",
            "--yes",
            "--mkdir",
            "--prefix",
            self.path,
            "python={}".format(self.python),
        ).get_output(stderr=DEVNULL)
        # conda prints the activation hint; capture it so later commands can run
        # inside the environment
        match = re.search(
            r"\W*(.*activate) ({})".format(re.escape(str(self.path))), output
        )
        self.source = self.pip.source = (
            tuple(match.group(1).split()) + (match.group(2),)
            if match
            else ("activate", self.path)
        )
        # modern conda needs conda.sh sourced before 'conda activate' works
        conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
        if conda_env.is_file():
            self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)

        # install cuda toolkit (best effort: a missing/invalid cuda config is ignored)
        try:
            cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
            if cuda_version > 0:
                self._install('cudatoolkit={:.1f}'.format(cuda_version))
        except Exception:
            pass
        return self

    def remove(self):
        """
        Delete a conda environment.
        Use 'conda env remove', then 'rm_tree' to be safe.

        Conda seems to load "vcruntime140.dll" from all its environment on startup.
        This means environment have to be deleted using 'conda env remove'.
        If necessary, conda can be fooled into deleting a partially-deleted environment by creating an empty file
        in '<ENV>\conda-meta\history' (value found in 'conda.gateways.disk.test.PREFIX_MAGIC_FILE').
        Otherwise, it complains that said directory is not a conda environment.

        See: https://github.com/conda/conda/issues/7682
        """
        try:
            self._run_command(("env", "remove", "-p", self.path))
        except Exception:
            # best effort -- the rm_tree below removes whatever is left
            pass
        rm_tree(self.path)

    def _install_from_file(self, path):
        """
        Install packages from requirement file.
        """
        self._install("--file", path)

    def _install(self, *args):
        # type: (*PathLike) -> ()
        """Run 'conda install' with the configured extra channels."""
        channels_args = tuple(
            chain.from_iterable(("-c", channel) for channel in self.extra_channels)
        )
        self._run_command(("install", "-p", self.path) + channels_args + args)

    def _get_pip_packages(self, packages):
        # type: (Iterable[Text]) -> Sequence[Text]
        """
        Return subset of ``packages`` which are not available on conda
        """
        pips = []
        while True:
            with self.temp_file("conda_reqs", packages) as path:
                try:
                    self._install_from_file(path)
                except PackageNotFoundError as e:
                    # conda 4.3 reports one missing package per failure:
                    # record it and retry with the remainder
                    pips.append(e.pkg)
                    packages = _package_diff(path, {e.pkg})
                else:
                    break
        return pips

    def install_packages(self, *packages):
        # type: (*Text) -> ()
        """Install the given package specs with conda."""
        return self._install(*packages)

    def uninstall_packages(self, *packages):
        # type: (*Text) -> ()
        """
        Uninstall the given package specs from the environment.
        """
        # BUGFIX: the `packages` argument used to be dropped entirely, so the
        # requested packages were never passed to 'conda uninstall'
        return self._run_command(("uninstall", "-p", self.path) + packages)

    def install_from_file(self, path):
        """
        Try to install packages from conda. Install packages which are not available from conda with pip.
        """
        try:
            self._install_from_file(path)
            return
        except PackageNotFoundError as e:
            pip_packages = [e.pkg]
        except PackagesNotFoundError as e:
            pip_packages = package_set(e.packages)
        # retry conda without the missing packages, then hand those to pip
        with self.temp_file("conda_reqs", _package_diff(path, pip_packages)) as reqs:
            self.install_from_file(reqs)
        with self.temp_file("pip_reqs", pip_packages) as reqs:
            self.pip.install_from_file(reqs)

    def freeze(self):
        """
        Snapshot the environment's installed packages.

        conda's own 'env export' output proved unreliable here, so report the
        pip view of the environment instead.
        """
        return self.pip.freeze()

    def load_requirements(self, requirements):
        """
        Install ``requirements['pip']`` specs, preferring conda and falling
        back to pip for whatever conda cannot satisfy.
        """
        # create new environment file
        conda_env = dict()
        conda_env['channels'] = self.extra_channels
        reqs = [MarkerRequirement(next(parse(r))) for r in requirements['pip']]
        pip_requirements = []

        while reqs:
            # conda pins versions with a single '='
            conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs]
            with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
                print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
                result = self._run_command(
                    ("env", "update", "-p", self.path, "--file", name)
                )
            # check if we need to remove specific packages
            bad_req = self._parse_conda_result_bad_packges(result)
            if not bad_req:
                break

            solved = False
            for bad_r in bad_req:
                name = bad_r.split('[')[0].split('=')[0]
                # look for name in requirements; break right after the removal,
                # so mutating `reqs` during iteration is safe
                for r in reqs:
                    if r.name.lower() == name.lower():
                        pip_requirements.append(r)
                        reqs.remove(r)
                        solved = True
                        break

            # we couldn't remove even one package,
            # nothing we can do but try pip
            if not solved:
                pip_requirements.extend(reqs)
                break

        if pip_requirements:
            try:
                pip_req_str = [r.tostr() for r in pip_requirements]
                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
                self.pip.load_requirements('\n'.join(pip_req_str))
            except Exception as e:
                print(e)
                # re-raise with the original traceback intact
                raise

        self.requirements_manager.post_install()
        return True

    def _parse_conda_result_bad_packges(self, result_dict):
        """
        Extract the unsatisfiable package specs from a conda JSON result.

        :return: list of offending spec strings, or None when no dependency
            error is present
        """
        if not result_dict:
            return None

        if 'bad_deps' in result_dict and result_dict['bad_deps']:
            return result_dict['bad_deps']

        if result_dict.get('error'):
            error_lines = result_dict['error'].split('\n')
            if error_lines[0].strip().lower().startswith("unsatisfiableerror:"):
                empty_lines = [i for i, l in enumerate(error_lines) if not l.strip()]
                if len(empty_lines) >= 2:
                    # the offending specs are listed between the first two blank lines
                    deps = error_lines[empty_lines[0]+1:empty_lines[1]]
                    try:
                        # safe_load: the error text is plain YAML; no arbitrary
                        # object construction is needed (or wanted)
                        return yaml.safe_load('\n'.join(deps))
                    except Exception:
                        return None
        return None

    def _run_command(self, command, raw=False, **kwargs):
        # type: (Iterable[Text], bool, Any) -> Union[Dict, Text]
        """
        Run a conda command, returning JSON output.
        The command is prepended with 'conda' and run with JSON output flags.
        :param command: command to run
        :param raw: return text output and don't change command
        :param kwargs: kwargs for Argv.get_output()
        :return: JSON output or text output
        """
        command = Argv(*command)  # type: Executable
        if not raw:
            command = (self.conda,) + command + ("--quiet", "--json")
        try:
            print('Executing Conda: {}'.format(command.serialize()))
            result = command.get_output(stdin=DEVNULL, **kwargs)
        except Exception as e:
            if raw:
                raise
            # non-raw callers inspect the JSON error payload instead
            result = e.output if hasattr(e, 'output') else ''
        if raw:
            return result
        result = json.loads(result) if result else {}
        if result.get('success', False):
            print('Pass')
        elif result.get('error'):
            print('Conda error: {}'.format(result.get('error')))
        return result

    def get_python_command(self, extra=()):
        """Return a command that runs python inside the activated environment."""
        return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
|
||||
|
||||
|
||||
# enable hashing with cmp=False because pdb fails on unhashable exceptions
# (class decorator shared by all conda exception classes below)
exception = attrs(str=True, cmp=False)
|
||||
|
||||
|
||||
@exception
class CondaException(Exception, NonStrictAttrs):
    """Base class for errors raised while driving the conda executable."""
    command = attrib()  # the conda command that failed
    message = attrib(default=None)  # optional human-readable description
|
||||
|
||||
|
||||
@exception
class UnknownCondaError(CondaException):
    """A conda failure that could not be classified further."""
    data = attrib(default=Factory(dict))  # raw JSON payload returned by conda
|
||||
|
||||
|
||||
@exception
class PackagesNotFoundError(CondaException):
    """
    Conda 4.5 exception - this reports all missing packages.
    """

    packages = attrib(default=())  # all package specs conda could not find
|
||||
|
||||
|
||||
@exception
class PackageNotFoundError(CondaException):
    """
    Conda 4.3 exception - this reports one missing package at a time,
    as a singleton YAML list.
    """

    # The missing package spec (spaces stripped). The incoming value is a YAML
    # singleton list such as '- foo ==1.0'.
    # BUGFIX: the converter previously called yaml.load() unconditionally, so
    # instantiating with the default "" crashed (yaml.load('') is None, and
    # None[0] raises TypeError); it also used the unsafe full YAML loader.
    pkg = attrib(
        default="",
        converter=lambda val: yaml.safe_load(val)[0].replace(" ", "") if val else "",
    )
|
||||
25
trains_agent/helper/package/cython_req.py
Normal file
25
trains_agent/helper/package/cython_req.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Text
|
||||
|
||||
from .base import PackageManager
|
||||
from .requirements import SimpleSubstitution
|
||||
|
||||
|
||||
class CythonRequirement(SimpleSubstitution):
    """Requirement handler that installs Cython eagerly, ahead of everything else."""

    name = "cython"

    def __init__(self, *args, **kwargs):
        super(CythonRequirement, self).__init__(*args, **kwargs)

    def match(self, req):
        # Case-insensitive match: covers both "Cython" and "cython".
        return req.name.lower() == self.name

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Install Cython up-front so packages that build against it can compile.
        PackageManager.out_of_scope_install_package(str(req))
        return Text(req)
|
||||
32
trains_agent/helper/package/horovod_req.py
Normal file
32
trains_agent/helper/package/horovod_req.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Text
|
||||
|
||||
from .base import PackageManager
|
||||
from .requirements import SimpleSubstitution
|
||||
|
||||
|
||||
class HorovodRequirement(SimpleSubstitution):
    """Defers horovod installation to the post-install hook (it must build after
    the deep-learning frameworks it binds to are already installed)."""

    name = "horovod"

    def __init__(self, *args, **kwargs):
        super(HorovodRequirement, self).__init__(*args, **kwargs)
        # Requirement captured by replace(); installed later by post_install().
        self.post_install_req = None

    def match(self, req):
        # Case-insensitive name comparison.
        return req.name.lower() == self.name

    def post_install(self):
        # Guard clause: nothing was deferred, nothing to do.
        if not self.post_install_req:
            return
        PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False))
        self.post_install_req = None

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Remember the requirement for the post-install hook and emit an empty
        # line so the regular install pass skips it.
        self.post_install_req = req
        return Text('')
|
||||
0
trains_agent/helper/package/pip_api/__init__.py
Normal file
0
trains_agent/helper/package/pip_api/__init__.py
Normal file
92
trains_agent/helper/package/pip_api/system.py
Normal file
92
trains_agent/helper/package/pip_api/system.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import sys
|
||||
from itertools import chain
|
||||
from typing import Text
|
||||
|
||||
from trains_agent.definitions import PIP_EXTRA_INDICES, PROGRAM_NAME
|
||||
from trains_agent.helper.package.base import PackageManager
|
||||
from trains_agent.helper.process import Argv, DEVNULL
|
||||
|
||||
|
||||
class SystemPip(PackageManager):
    """Package manager that drives the system (or a given) python's pip module."""

    # class-level cache default; install_flags() sets an instance attribute
    # that shadows this on first use
    indices_args = None

    def __init__(self, interpreter=None):
        # type: (Text) -> ()
        """
        Program interface to the system pip.

        :param interpreter: python interpreter to run pip with
            (defaults to the current interpreter)
        """
        self._bin = interpreter or sys.executable

    @property
    def bin(self):
        # python interpreter used for all pip invocations
        return self._bin

    def create(self):
        # system environment always exists -- nothing to create
        pass

    def remove(self):
        # system environment is never removed
        pass

    def install_from_file(self, path):
        """Install all requirements listed in the file at ``path``."""
        self.run_with_env(('install', '-r', path) + self.install_flags())

    def install_packages(self, *packages):
        """Install the given package specs."""
        self._install(*(packages + self.install_flags()))

    def _install(self, *args):
        # thin wrapper: 'pip install <args>'
        self.run_with_env(('install',) + args)

    def uninstall_packages(self, *packages):
        """Uninstall the given packages without prompting."""
        self.run_with_env(('uninstall', '-y') + packages)

    def download_package(self, package, cache_dir):
        """Download a single package (no dependencies) into ``cache_dir``."""
        self.run_with_env(
            (
                'download',
                package,
                '--dest', cache_dir,
                '--no-deps',
            ) + self.install_flags()
        )

    def load_requirements(self, requirements):
        """
        Install requirements given either as a {'pip': ...} dict or as the
        raw requirements text/sequence; no-op when empty.
        """
        requirements = requirements.get('pip') if isinstance(requirements, dict) else requirements
        if not requirements:
            return
        with self.temp_file('cached-reqs', requirements) as path:
            self.install_from_file(path)

    def uninstall(self, package):
        """Uninstall a single package without prompting."""
        self.run_with_env(('uninstall', '-y', package))

    def freeze(self):
        """
        pip freeze to all install packages except the running program
        :return: Dict contains pip as key and pip's packages to install
        :rtype: Dict[str, List[str]]
        """
        packages = self.run_with_env(('freeze',), output=True).splitlines()
        # NOTE(review): this is a substring filter -- any requirement line that
        # merely contains PROGRAM_NAME is dropped as well
        packages_without_program = [package for package in packages if PROGRAM_NAME not in package]
        return {'pip': packages_without_program}

    def run_with_env(self, command, output=False, **kwargs):
        """
        Run a shell command using environment from a virtualenv script
        :param command: command to run
        :type command: Iterable[Text]
        :param output: return output
        :param kwargs: kwargs for get_output/check_output command
        """
        command = self._make_command(command)
        return (command.get_output if output else command.check_call)(stdin=DEVNULL, **kwargs)

    def _make_command(self, command):
        # always invoke pip as a module of the chosen interpreter
        return Argv(self.bin, '-m', 'pip', *command)

    def install_flags(self):
        """Return the cached '--extra-index-url' argument pairs."""
        if self.indices_args is None:
            # computed once; stored on the instance (shadows the class attribute)
            self.indices_args = tuple(
                chain.from_iterable(('--extra-index-url', x) for x in PIP_EXTRA_INDICES)
            )
        return self.indices_args
|
||||
77
trains_agent/helper/package/pip_api/venv.py
Normal file
77
trains_agent/helper/package/pip_api/venv.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains_agent.helper.base import select_for_platform, rm_tree
|
||||
from trains_agent.helper.package.base import PackageManager
|
||||
from trains_agent.helper.process import Argv, PathLike
|
||||
from trains_agent.session import Session
|
||||
from ..pip_api.system import SystemPip
|
||||
from ..requirements import RequirementsManager
|
||||
|
||||
|
||||
class VirtualenvPip(SystemPip, PackageManager):
    def __init__(self, session, python, requirements_manager, path, interpreter=None):
        # type: (Session, float, RequirementsManager, PathLike, PathLike) -> ()
        """
        Program interface to virtualenv pip.
        Must be given either path to virtualenv or source command.
        Either way, ``self.source`` is exposed.

        :param session: agent session, used to build logged commands
        :param python: python version
        :param requirements_manager: requirements substitution manager
        :param path: path of virtual environment to create/manipulate
        :param interpreter: path of python interpreter
        """
        super(VirtualenvPip, self).__init__(
            interpreter
            or Path(
                path,
                select_for_platform(linux="bin/python", windows="scripts/python.exe"),
            )
        )
        self.session = session
        self.path = path
        self.requirements_manager = requirements_manager
        self.python = "python{}".format(python)

    def _make_command(self, command):
        # route pip invocations through the session so they are logged/configured
        return self.session.command(self.bin, "-m", "pip", *command)

    def load_requirements(self, requirements):
        """Apply requirement substitutions to the 'pip' section, then install."""
        if isinstance(requirements, dict) and requirements.get("pip"):
            requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
        super(VirtualenvPip, self).load_requirements(requirements)
        # run deferred installs registered by the substitutions (e.g. horovod)
        self.requirements_manager.post_install()

    def create_flags(self):
        """
        Configurable environment creation arguments
        """
        return Argv.conditional_flag(
            self.session.config["agent.package_manager.system_site_packages"],
            "--system-site-packages",
        )

    def install_flags(self):
        """
        Configurable package installation creation arguments
        """
        return super(VirtualenvPip, self).install_flags() + Argv.conditional_flag(
            self.session.config["agent.package_manager.force_upgrade"], "--upgrade"
        )

    def create(self):
        """
        Create virtualenv.
        Only valid if instantiated with path.
        Use self.python as self.bin does not exist.
        """
        self.session.command(
            self.python, "-m", "virtualenv", self.path, *self.create_flags()
        ).check_call()
        return self

    def remove(self):
        """
        Delete virtualenv.
        Only valid if instantiated with path.
        """
        rm_tree(self.path)
|
||||
98
trains_agent/helper/package/poetry_api.py
Normal file
98
trains_agent/helper/package/poetry_api.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from functools import wraps
|
||||
|
||||
import attr
|
||||
from pathlib2 import Path
|
||||
from trains_agent.helper.process import Argv, DEVNULL
|
||||
from trains_agent.session import Session, POETRY
|
||||
|
||||
|
||||
def prop_guard(prop, log_prop=None):
    """
    Build a method decorator that only runs the wrapped method when ``prop``
    (a property of the same class) evaluates truthy on the instance.

    When the guard is closed the wrapped method is skipped and None is
    returned. If ``log_prop`` (a property returning a logger) is given, every
    call is logged with the guard's current value.
    """
    assert isinstance(prop, property)
    assert not log_prop or isinstance(log_prop, property)

    def decorator(func):
        # pre-built lazy-%s log template: "<Class>:<maybe ' not'> calling <func>, <prop> = <value>"
        message = "%s:%s calling {}, {} = %s".format(
            func.__name__, prop.fget.__name__
        )

        @wraps(func)
        def guarded(self, *args, **kwargs):
            guard_value = prop.fget(self)
            if log_prop:
                log_prop.fget(self).debug(
                    message,
                    type(self).__name__,
                    "" if guard_value else " not",
                    guard_value,
                )
            if not guard_value:
                return None
            return func(self, *args, **kwargs)

        return guarded

    return decorator
|
||||
|
||||
|
||||
class PoetryConfig:
    """Runs and configures the poetry executable for the agent."""

    def __init__(self, session):
        # type: (Session) -> ()
        self.session = session
        self._log = session.get_logger(__name__)

    @property
    def log(self):
        return self._log

    @property
    def enabled(self):
        # poetry is only used when explicitly selected in the agent config
        return self.session.config["agent.package_manager.type"] == POETRY

    # decorator: skip the method (and log) unless poetry is the selected manager
    _guard_enabled = prop_guard(enabled, log)

    def run(self, *args, **kwargs):
        """
        Run 'poetry -n <args>'.

        :param func: Argv method used to execute (default: Argv.get_output)
        :param kwargs: extra keyword arguments for the chosen Argv method
        """
        func = kwargs.pop("func", Argv.get_output)
        kwargs.setdefault("stdin", DEVNULL)
        argv = Argv("poetry", "-n", *args)
        self.log.debug("running: %s", argv)
        return func(argv, **kwargs)

    def _config(self, *args, **kwargs):
        # shorthand for 'poetry config ...'
        return self.run("config", *args, **kwargs)

    @_guard_enabled
    def initialize(self):
        """Configure poetry to create its virtualenv inside the project."""
        self._config("settings.virtualenvs.in-project", "true")
        # self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
        # self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)

    def get_api(self, path):
        # type: (Path) -> PoetryAPI
        """Return a PoetryAPI bound to the project directory at ``path``."""
        return PoetryAPI(self, path)
|
||||
|
||||
|
||||
@attr.s
class PoetryAPI(object):
    """Poetry operations bound to a single project directory."""

    config = attr.ib(type=PoetryConfig)  # shared poetry runner/configuration
    path = attr.ib(type=Path, converter=Path)  # project directory

    # presence of any of these files marks the directory as a poetry project
    INDICATOR_FILES = "pyproject.toml", "poetry.lock"

    def install(self):
        # type: () -> bool
        """Run 'poetry install' in the project; return True if it ran."""
        if self.enabled:
            self.config.run("install", cwd=str(self.path), func=Argv.check_call)
            return True
        return False

    @property
    def enabled(self):
        # poetry must be the selected manager AND the directory must look like
        # a poetry project
        return self.config.enabled and (
            any((self.path / indicator).exists() for indicator in self.INDICATOR_FILES)
        )

    def freeze(self):
        """Return the installed package listing as reported by 'poetry show'."""
        return {"poetry": self.config.run("show", cwd=str(self.path)).splitlines()}

    def get_python_command(self, extra):
        """Return a command that runs python inside the poetry environment."""
        return Argv("poetry", "run", "python", *extra)
|
||||
595
trains_agent/helper/package/pytorch.py
Normal file
595
trains_agent/helper/package/pytorch.py
Normal file
@@ -0,0 +1,595 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import re
|
||||
import sys
|
||||
from furl import furl
|
||||
import urllib.parse
|
||||
from operator import itemgetter
|
||||
from html.parser import HTMLParser
|
||||
from typing import Text
|
||||
|
||||
import attr
|
||||
import requests
|
||||
from semantic_version import Version, Spec
|
||||
|
||||
import six
|
||||
from .requirements import SimpleSubstitution, FatalSpecsResolutionError
|
||||
|
||||
# platform name -> wheel filename platform tag used on download.pytorch.org
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}


def os_to_wheel_name(x):
    # Translate a platform name ("linux"/"windows") to its wheel platform tag.
    # NOTE(review): raises KeyError for any other platform (e.g. "macos") --
    # callers are expected to guard against that.
    return OS_TO_WHEEL_NAME[x]
|
||||
|
||||
|
||||
def fix_version(version):
    """
    Normalize a package version for semantic-version parsing: any fourth
    dotted component becomes a pre-release suffix, e.g.
    '1.0.0.post1' -> '1.0.0-post1'; plain versions pass through unchanged.
    """
    def _rejoin(match):
        release, suffix = match.groups()
        if suffix:
            return "{}-{}".format(release, suffix)
        return release

    return re.sub(r"(\d+(?:\.\d+){,2})(?:\.(.*))?", _rejoin, version)
|
||||
|
||||
|
||||
class LinksHTMLParser(HTMLParser):
    """Collect every non-blank text node from an HTML page, in document order
    (used to read the link texts off pytorch's wheel index pages)."""

    def __init__(self):
        super(LinksHTMLParser, self).__init__()
        # text content of parsed nodes, unmodified, in encounter order
        self.links = []

    def handle_data(self, data):
        # Skip empty and whitespace-only text; keep everything else as-is.
        if data and data.strip():
            self.links.append(data)
|
||||
|
||||
|
||||
@attr.s
class PytorchWheel(object):
    """Builds a download.pytorch.org wheel URL from os/cuda/python/torch versions."""

    os_name = attr.ib(type=str, converter=os_to_wheel_name)  # -> wheel platform tag
    # cuda 92 -> "cu92"; falsy (no cuda) -> "cpu"
    cuda_version = attr.ib(converter=lambda x: "cu{}".format(x) if x else "cpu")
    # "3.6" -> "36" (as used in cp36 tags)
    python = attr.ib(type=str, converter=lambda x: str(x).replace(".", ""))
    torch_version = attr.ib(type=str, converter=fix_version)

    url_template = (
        "http://download.pytorch.org/whl/"
        "{0.cuda_version}/torch-{0.torch_version}-cp{0.python}-cp{0.python}m{0.unicode}-{0.os_name}.whl"
    )

    def __attrs_post_init__(self):
        # python 2 wheels carry a "u" (wide-unicode) ABI suffix
        self.unicode = "u" if self.python.startswith("2") else ""

    def make_url(self):
        # type: () -> Text
        """Render the wheel download URL for this combination."""
        return self.url_template.format(self)
|
||||
|
||||
|
||||
class PytorchResolutionError(FatalSpecsResolutionError):
    """Raised when no pytorch wheel can be resolved for the requested
    os/cuda/python/version combination."""
    pass
|
||||
|
||||
|
||||
class SimplePytorchRequirement(SimpleSubstitution):
    """Requirement handler for torch packages: strips local-version suffixes and
    points pip at the matching torch_stable.html index for the detected cuda."""

    name = "torch"

    # all requirement names handled by this substitution
    packages = ("torch", "torchvision", "torchaudio")

    page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
    # known wheel index pages, keyed by cuda version as int (0 == cpu-only);
    # get_torch_page() may add newly discovered versions at runtime
    torch_page_lookup = {
        0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
        80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
        90: 'https://download.pytorch.org/whl/cu90/torch_stable.html',
        92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
        100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
        101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
    }

    def __init__(self, *args, **kwargs):
        super(SimplePytorchRequirement, self).__init__(*args, **kwargs)
        # set by replace(); matching_done() only acts when a torch package matched
        self._matched = False

    def match(self, req):
        # match any of our packages
        return req.name in self.packages

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Get rid of +cpu +cu?? etc. (best effort; reqs without specs are left alone)
        try:
            req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
        except:
            pass
        self._matched = True
        return Text(req)

    def matching_done(self, reqs, package_manager):
        # type: (Sequence[MarkerRequirement], object) -> ()
        """After matching: add the torch wheel index URL ('-f') to pip's flags."""
        if not self._matched:
            return
        # TODO: add conda channel support
        from .pip_api.system import SystemPip
        if package_manager and isinstance(package_manager, SystemPip):
            extra_url, _ = self.get_torch_page(self.cuda_version)
            package_manager.add_extra_install_flags(('-f', extra_url))

    @classmethod
    def get_torch_page(cls, cuda_version):
        """
        Return (wheel index URL, cuda key) for the requested cuda version,
        falling back to the closest lower known version, then to cpu-only.
        """
        try:
            cuda = int(cuda_version)
        except:
            cuda = 0
        # first check if key is valid
        if cuda in cls.torch_page_lookup:
            return cls.torch_page_lookup[cuda], cuda

        # then try a new cuda version page (probe and cache it if it exists)
        torch_url = cls.page_lookup_template.format(cuda)
        try:
            if requests.get(torch_url, timeout=10).ok:
                cls.torch_page_lookup[cuda] = torch_url
                return cls.torch_page_lookup[cuda], cuda
        except Exception:
            pass

        # fall back to the highest known cuda version not above the requested one
        keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
        for k in keys:
            if k <= cuda:
                return cls.torch_page_lookup[k], k
        # return default - zero
        return cls.torch_page_lookup[0], 0
|
||||
|
||||
|
||||
class PytorchRequirement(SimpleSubstitution):
|
||||
|
||||
name = "torch"
|
||||
packages = ("torch", "torchvision", "torchaudio")
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
os_name = kwargs.pop("os_override", None)
|
||||
super(PytorchRequirement, self).__init__(*args, **kwargs)
|
||||
self.log = self._session.get_logger(__name__)
|
||||
self.package_manager = self.config["agent.package_manager.type"].lower()
|
||||
self.os = os_name or self.get_platform()
|
||||
self.cuda = "cuda{}".format(self.cuda_version).lower()
|
||||
self.python_version_string = str(self.config["agent.default_python"])
|
||||
self.python_semantic_version = Version.coerce(
|
||||
self.python_version_string, partial=True
|
||||
)
|
||||
self.python = "python{}.{}".format(self.python_semantic_version.major, self.python_semantic_version.minor)
|
||||
|
||||
self.exceptions = [
|
||||
PytorchResolutionError(message)
|
||||
for message in (
|
||||
None,
|
||||
'cuda version "{}" is not supported'.format(self.cuda),
|
||||
'python version "{}" is not supported'.format(
|
||||
self.python_version_string
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
self.validate_python_version()
|
||||
except PytorchResolutionError as e:
|
||||
self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])
|
||||
|
||||
@property
|
||||
def is_conda(self):
|
||||
return self.package_manager == "conda"
|
||||
|
||||
@property
|
||||
def is_pip(self):
|
||||
return not self.is_conda
|
||||
|
||||
def validate_python_version(self):
|
||||
"""
|
||||
Make sure python version has both major and minor versions as required for choosing pytorch wheel
|
||||
"""
|
||||
if self.is_pip and not (
|
||||
self.python_semantic_version.major and self.python_semantic_version.minor
|
||||
):
|
||||
raise PytorchResolutionError(
|
||||
"invalid python version {!r} defined in configuration file, key 'agent.default_python': "
|
||||
"must have both major and minor parts of the version (for example: '3.7')".format(
|
||||
self.python_version_string
|
||||
)
|
||||
)
|
||||
|
||||
def match(self, req):
|
||||
return req.name in self.packages
|
||||
|
||||
@staticmethod
|
||||
def get_platform():
|
||||
if sys.platform == "linux":
|
||||
return "linux"
|
||||
if sys.platform == "win32" or sys.platform == "cygwin":
|
||||
return "windows"
|
||||
if sys.platform == "darwin":
|
||||
return "macos"
|
||||
raise RuntimeError("unrecognized OS")
|
||||
|
||||
def _get_link_from_torch_page(self, req, torch_url):
|
||||
links_parser = LinksHTMLParser()
|
||||
links_parser.feed(requests.get(torch_url, timeout=10).text)
|
||||
platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
|
||||
py_ver = "{0.major}{0.minor}".format(self.python_semantic_version)
|
||||
url = None
|
||||
# search for our package
|
||||
for l in links_parser.links:
|
||||
parts = l.split('/')[-1].split('-')
|
||||
if len(parts) < 5:
|
||||
continue
|
||||
if parts[0] != req.name:
|
||||
continue
|
||||
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
||||
if parts[1].split('%')[0].split('+')[0] != req.specs[0][1]:
|
||||
continue
|
||||
if not parts[2].endswith(py_ver):
|
||||
continue
|
||||
if platform_wheel not in parts[4]:
|
||||
continue
|
||||
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
|
||||
break
|
||||
|
||||
return url
|
||||
|
||||
def get_url_for_platform(self, req):
|
||||
assert self.package_manager == "pip"
|
||||
assert self.os != "mac"
|
||||
assert req.specs
|
||||
try:
|
||||
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
|
||||
except:
|
||||
pass
|
||||
op, version = req.specs[0]
|
||||
# assert op == "=="
|
||||
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
# try one more time, with a lower cuda version:
|
||||
if not url:
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
|
||||
if not url:
|
||||
url = PytorchWheel(
|
||||
torch_version=fix_version(version),
|
||||
python="{0.major}{0.minor}".format(self.python_semantic_version),
|
||||
os_name=self.os,
|
||||
cuda_version=self.cuda_version,
|
||||
).make_url()
|
||||
if url:
|
||||
# normalize url (sometimes we will get ../ which we should not...
|
||||
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
|
||||
|
||||
self.log.debug("checking url: %s", url)
|
||||
return url, requests.head(url, timeout=10).ok
|
||||
|
||||
@staticmethod
|
||||
def match_version(req, options):
|
||||
versioned_options = sorted(
|
||||
((Version(fix_version(key)), value) for key, value in options.items()),
|
||||
key=itemgetter(0),
|
||||
reverse=True,
|
||||
)
|
||||
req.specs = [(op, fix_version(version)) for op, version in req.specs]
|
||||
if req.specs:
|
||||
specs = Spec(req.format_specs())
|
||||
else:
|
||||
specs = None
|
||||
try:
|
||||
return next(
|
||||
replacement
|
||||
for version, replacement in versioned_options
|
||||
if not specs or version in specs
|
||||
)
|
||||
except StopIteration:
|
||||
raise PytorchResolutionError(
|
||||
'Could not find wheel for "{}", '
|
||||
"Available versions: {}".format(req, list(options))
|
||||
)
|
||||
|
||||
def replace_conda(self, req):
|
||||
spec = "".join(req.specs[0]) if req.specs else ""
|
||||
if not self.cuda_version:
|
||||
return "pytorch-cpu{spec}\ntorchvision-cpu".format(spec=spec)
|
||||
return "pytorch{spec}\ntorchvision\ncuda{self.cuda_version}".format(
|
||||
self=self, spec=spec
|
||||
)
|
||||
|
||||
def _table_lookup(self, req):
|
||||
"""
|
||||
Look for pytorch wheel matching `req` in table
|
||||
:param req: python requirement
|
||||
"""
|
||||
def check(base_, key_, exception_):
|
||||
result = base_.get(key_)
|
||||
if not result:
|
||||
if key_.startswith('cuda'):
|
||||
print('Could not locate, {}'.format(exception_))
|
||||
ver = sorted([float(a.replace('cuda', '').replace('none', '0')) for a in base_.keys()], reverse=True)[0]
|
||||
key_ = 'cuda'+str(int(ver))
|
||||
result = base_.get(key_)
|
||||
print('Reverting to \"{}\"'.format(key_))
|
||||
if not result:
|
||||
raise exception_
|
||||
return result
|
||||
raise exception_
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
if self.is_conda:
|
||||
return self.replace_conda(req)
|
||||
|
||||
base = self.MAP
|
||||
for key, exception in zip((self.os, self.cuda, self.python), self.exceptions):
|
||||
base = check(base, key, exception)
|
||||
|
||||
return self.match_version(req, base).replace(" ", "\n")
|
||||
|
||||
def replace(self, req):
    """Resolve ``req`` to a pytorch wheel requirement line.

    Delegates to :meth:`_replace`; any failure is logged (with
    traceback) and re-raised as a PytorchResolutionError.
    """
    try:
        result = self._replace(req)
    except Exception as error:
        message = "Exception when trying to resolve python wheel"
        self.log.debug(message, exc_info=True)
        raise PytorchResolutionError("{}: {}".format(message, error))
    return result
|
||||
|
||||
def _replace(self, req):
|
||||
self.validate_python_version()
|
||||
try:
|
||||
result, ok = self.get_url_for_platform(req)
|
||||
self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
return result
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
result = self._table_lookup(req)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
else:
|
||||
self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
return result
|
||||
|
||||
self.log.debug(
|
||||
"Could not find Pytorch wheel in table, trying manually constructing URL"
|
||||
)
|
||||
result = ok = None
|
||||
# try:
|
||||
# result, ok = self.get_url_for_platform(req)
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
if not ok:
|
||||
if result:
|
||||
self.log.debug("URL not found: {}".format(result))
|
||||
exc = PytorchResolutionError(
|
||||
"Was not able to find pytorch wheel URL: {}".format(exc)
|
||||
)
|
||||
# cancel exception chaining
|
||||
six.raise_from(exc, None)
|
||||
|
||||
self.log.debug('Replacing requirement "%s" with %r', req, result)
|
||||
return result
|
||||
|
||||
# Static lookup table: OS -> CUDA version -> python version -> torch
# version -> wheel URL (or pip requirement line). Leaves may instead be
# a PytorchResolutionError instance meaning "explicitly unsupported".
# Fix: the windows/cuda92/python3.7 entry used a comma instead of a
# colon, producing a *set* literal instead of a dict and breaking
# version lookup for that combination.
MAP = {
    "windows": {
        "cuda100": {
            "python3.7": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-win_amd64.whl"
            },
            "python3.6": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-win_amd64.whl"
            },
            "python3.5": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-win_amd64.whl"
            },
            "python2.7": PytorchResolutionError(
                "PyTorch does not support Python 2.7 on Windows"
            ),
        },
        "cuda92": {
            "python3.7": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp37-cp37m-win_amd64.whl",
            },
            "python3.6": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-win_amd64.whl"
            },
            "python3.5": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-win_amd64.whl"
            },
            "python2.7": PytorchResolutionError(
                "PyTorch does not support Python 2.7 on Windows"
            ),
        },
        "cuda91": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-win_amd64.whl"
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-win_amd64.whl"
            },
            "python2.7": PytorchResolutionError(
                "PyTorch does not support Python 2.7 on Windows"
            ),
        },
        "cuda90": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-win_amd64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-win_amd64.whl",
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-win_amd64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-win_amd64.whl",
            },
            "python2.7": PytorchResolutionError(
                "PyTorch does not support Python 2.7 on Windows"
            ),
        },
        "cuda80": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-win_amd64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-win_amd64.whl",
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cu80/torch-0.4.0-cp35-cp35m-win_amd64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-win_amd64.whl",
            },
            "python2.7": PytorchResolutionError(
                "PyTorch does not support Python 2.7 on Windows"
            ),
        },
        "cudanone": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-win_amd64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-win_amd64.whl",
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-win_amd64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-win_amd64.whl",
            },
            "python2.7": PytorchResolutionError(
                "PyTorch does not support Python 2.7 on Windows"
            ),
        },
    },
    "macos": {
        "cuda100": PytorchResolutionError(
            "MacOS Binaries dont support CUDA, install from source if CUDA is needed"
        ),
        "cuda92": PytorchResolutionError(
            "MacOS Binaries dont support CUDA, install from source if CUDA is needed"
        ),
        "cuda91": PytorchResolutionError(
            "MacOS Binaries dont support CUDA, install from source if CUDA is needed"
        ),
        "cuda90": PytorchResolutionError(
            "MacOS Binaries dont support CUDA, install from source if CUDA is needed"
        ),
        "cuda80": PytorchResolutionError(
            "MacOS Binaries dont support CUDA, install from source if CUDA is needed"
        ),
        "cudanone": {
            "python3.6": {"0.4.0": "torch"},
            "python3.5": {"0.4.0": "torch"},
            "python2.7": {"0.4.0": "torch"},
        },
    },
    "linux": {
        "cuda100": {
            "python3.7": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp37-cp37m-linux_x86_64.whl",
                "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp37-cp37m-linux_x86_64.whl",
                "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl",
                "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl",
            },
            "python3.6": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
                "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp36-cp36m-linux_x86_64.whl",
                "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl",
                "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl",
            },
            "python3.5": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
                "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp35-cp35m-linux_x86_64.whl",
                "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp35-cp35m-linux_x86_64.whl",
                "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp35-cp35m-manylinux1_x86_64.whl",
            },
            "python2.7": {
                "1.0.0": "http://download.pytorch.org/whl/cu100/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
                "1.0.1": "http://download.pytorch.org/whl/cu100/torch-1.0.1-cp27-cp27mu-linux_x86_64.whl",
                "1.1.0": "http://download.pytorch.org/whl/cu100/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl",
                "1.2.0": "http://download.pytorch.org/whl/cu100/torch-1.2.0-cp27-cp27mu-manylinux1_x86_64.whl",
            },
        },
        "cuda92": {
            "python3.7": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl",
                "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp37-cp37m-manylinux1_x86_64.whl"
            },
            "python3.6": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
                "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl"
            },
            "python3.5": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
                "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp35-cp35m-manylinux1_x86_64.whl"
            },
            "python2.7": {
                "0.4.1": "http://download.pytorch.org/whl/cu92/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
                "1.2.0": "https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp27-cp27mu-manylinux1_x86_64.whl"
            },
        },
        "cuda91": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-linux_x86_64.whl"
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl"
            },
            "python2.7": {
                "0.4.0": "http://download.pytorch.org/whl/cu91/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl"
            },
        },
        "cuda90": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
            },
            "python2.7": {
                "0.4.0": "http://download.pytorch.org/whl/cu90/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cu90/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
            },
        },
        "cuda80": {
            "python3.6": {
                "0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl",
                "0.3.1": "torch==0.3.1",
                "0.3.0.post4": "torch==0.3.0.post4",
                "0.1.2.post1": "torch==0.1.2.post1",
                "0.1.2": "torch==0.1.2",
                "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
            },
            "python3.5": {
                "0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp35-cp35m-linux_x86_64.whl",
                "0.3.1": "torch==0.3.1",
                "0.3.0.post4": "torch==0.3.0.post4",
                "0.1.2.post1": "torch==0.1.2.post1",
                "0.1.2": "torch==0.1.2",
                "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
            },
            "python2.7": {
                "0.4.1": "http://download.pytorch.org/whl/cu80/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl",
                "0.3.1": "torch==0.3.1",
                "0.3.0.post4": "torch==0.3.0.post4",
                "0.1.2.post1": "torch==0.1.2.post1",
                "0.1.2": "torch==0.1.2",
                "1.0.0": "http://download.pytorch.org/whl/cu80/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
            },
        },
        "cudanone": {
            "python3.6": {
                "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl",
            },
            "python3.5": {
                "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl",
            },
            "python2.7": {
                "0.4.0": "http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl",
                "1.0.0": "http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl",
            },
        },
    },
}
|
||||
340
trains_agent/helper/package/requirements.py
Normal file
340
trains_agent/helper/package/requirements.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
from itertools import chain, starmap
|
||||
from operator import itemgetter
|
||||
from os import path
|
||||
from typing import Text, List, Type, Optional, Tuple
|
||||
|
||||
import semantic_version
|
||||
from pathlib2 import Path
|
||||
from pyhocon import ConfigTree
|
||||
from requirements import parse
|
||||
# noinspection PyPackageRequirements
|
||||
from requirements.requirement import Requirement
|
||||
|
||||
import six
|
||||
from trains_agent.definitions import PIP_EXTRA_INDICES
|
||||
from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
|
||||
from trains_agent.helper.process import Argv, PathLike
|
||||
from trains_agent.session import Session, normalize_cuda_version
|
||||
from .translator import RequirementsTranslator
|
||||
|
||||
|
||||
class SpecsResolutionError(Exception):
    """Raised when requirement version specs cannot be reconciled."""
|
||||
|
||||
|
||||
class FatalSpecsResolutionError(Exception):
    """Unrecoverable requirement-resolution failure; aborts processing."""
|
||||
|
||||
|
||||
@six.python_2_unicode_compatible
class MarkerRequirement(object):
    """Wrapper around a parsed ``requirements`` Requirement object.

    Adds environment-marker extraction, round-trip string formatting and
    version-spec normalization. Unknown attributes are delegated to the
    wrapped Requirement via ``__getattr__``.
    """

    def __init__(self, req):  # type: (Requirement) -> None
        self.req = req

    @property
    def marker(self):
        # Environment marker text after ';' in the original line, if any
        match = re.search(r';\s*(.*)', self.req.line)
        if match:
            return match.group(1)
        return None

    def tostr(self, markers=True):
        """Format the requirement back to a pip requirement line.

        :param markers: when True, append the environment marker (if any)
        """
        if not self.uri:
            parts = [self.name]

            if self.extras:
                parts.append('[{0}]'.format(','.join(sorted(self.extras))))

            if self.specifier:
                parts.append(self.format_specs())

        else:
            parts = [self.uri]

        if markers and self.marker:
            parts.append('; {0}'.format(self.marker))

        return ''.join(parts)

    __str__ = tostr

    def __repr__(self):
        return '{self.__class__.__name__}[{self}]'.format(self=self)

    def format_specs(self):
        # e.g. [('>=', '1.0'), ('<', '2.0')] -> '>=1.0,<2.0'
        return ','.join(starmap(operator.add, self.specs))

    def __getattr__(self, item):
        # delegate everything else to the wrapped Requirement
        return getattr(self.req, item)

    @property
    def specs(self):  # type: () -> List[Tuple[Text, Text]]
        return self.req.specs

    @specs.setter
    def specs(self, value):  # type: (List[Tuple[Text, Text]]) -> None
        self.req.specs = value

    def fix_specs(self):
        """Normalize version specs: keep the single '==' spec if present,
        otherwise collapse bounds to the loosest '<='/'>=' pair.

        :raises SpecsResolutionError: on conflicting '==' specs
        """
        def solve_by(func, op_is, specs):
            return func([(op, version) for op, version in specs if op == op_is])

        def solve_equal(specs):
            # Fix: check only the '==' subset passed in (was self.specs,
            # which wrongly rejected requirements whose '>='/'<=' bounds
            # carried a different version than the single '==' spec).
            if len(set(version for _, version in specs)) > 1:
                raise SpecsResolutionError('more than one "==" spec: {}'.format(specs))
            return specs
        greater = solve_by(lambda specs: [max(specs, key=itemgetter(1))], '<=', self.specs)
        smaller = solve_by(lambda specs: [min(specs, key=itemgetter(1))], '>=', self.specs)
        equal = solve_by(solve_equal, '==', self.specs)
        if equal:
            self.specs = equal
        else:
            self.specs = greater + smaller
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
class RequirementSubstitution(object):
    # Abstract base for a single requirement-rewriting rule: subclasses
    # decide whether a parsed requirement applies to them (match) and
    # how to rewrite it (replace).

    # extra pip index URLs consulted when querying package versions
    _pip_extra_index_url = PIP_EXTRA_INDICES

    def __init__(self, session):
        # type: (Session) -> ()
        self._session = session
        self.config = session.config  # type: ConfigTree
        # local version suffix encoding the machine's CUDA/CuDNN
        # versions, e.g. ".post90.dev7" for CUDA 9.0 / CuDNN 7
        self.suffix = '.post{config[agent.cuda_version]}.dev{config[agent.cudnn_version]}'.format(config=self.config)
        self.package_manager = self.config['agent.package_manager.type']

    @abstractmethod
    def match(self, req):  # type: (MarkerRequirement) -> bool
        """
        Returns whether a requirement needs to be modified by this substitution.
        """
        pass

    @abstractmethod
    def replace(self, req):  # type: (MarkerRequirement) -> Text
        """
        Replace a requirement
        """
        pass

    def post_install(self):
        # Hook invoked after package installation; default is a no-op.
        pass

    @classmethod
    def get_pip_version(cls, package):
        """Query pip for the latest available version of ``package``.

        NOTE(review): relies on the ``pip search`` command, which PyPI
        has since disabled -- confirm against the pip version in use.
        """
        output = Argv(
            'pip',
            'search',
            package,
            *(chain.from_iterable(('-i', x) for x in cls._pip_extra_index_url))
        ).get_output()
        # ad-hoc pattern to duplicate the behavior of the old code
        return re.search(r'{} \((\d+\.\d+\.[^.]+)'.format(package), output).group(1)

    @property
    def cuda_version(self):
        # CUDA version string as stored in the agent configuration
        return self.config['agent.cuda_version']

    @property
    def cudnn_version(self):
        # CuDNN version string as stored in the agent configuration
        return self.config['agent.cudnn_version']
|
||||
|
||||
|
||||
class SimpleSubstitution(RequirementSubstitution):
    # Substitution matched by package name (or by the name appearing in
    # an explicit http(s) requirement URI); rewrites the requirement to
    # pin a version carrying the CUDA/CuDNN suffix.

    @property
    @abstractmethod
    def name(self):
        # Package name this substitution applies to (e.g. "torch")
        pass

    def match(self, req):  # type: (MarkerRequirement) -> bool
        # Match by exact package name, or when the package name appears
        # inside an explicit http/https requirement URI.
        return (self.name == req.name or (
            req.uri and
            re.match(r'https?://', req.uri) and
            self.name in req.uri
        ))

    def replace(self, req):  # type: (MarkerRequirement) -> Text
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        if req.uri:
            # Inject the CUDA/CuDNN suffix into the wheel filename,
            # right before the "-cpXX" python tag.
            return re.sub(
                r'({})(.*?)(-cp)'.format(self.name),
                r'\1\2{}\3'.format(self.suffix),
                req.uri,
                count=1)

        if req.specs:
            _, version_number = req.specs[0]
            # sanity-check that the version string parses
            # NOTE(review): ``assert`` is stripped under ``python -O`` --
            # confirm whether this validation may be skipped in production
            assert semantic_version.Version(version_number, partial=True)
        else:
            # no explicit version requested -- pin to the latest from pip
            version_number = self.get_pip_version(self.name)

        req.specs = [('==', version_number + self.suffix)]
        return Text(req)
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
class CudaSensitiveSubstitution(SimpleSubstitution):
    # SimpleSubstitution that only applies when both CUDA and CuDNN
    # versions are known -- CUDA-specific wheel rewrites make no sense
    # otherwise.

    def match(self, req):  # type: (MarkerRequirement) -> bool
        # NOTE: short-circuits on the raw config values, so the return
        # value may be a falsy config value (e.g. '' or None) rather
        # than False -- callers treat it as a boolean.
        return self.cuda_version and self.cudnn_version and \
            super(CudaSensitiveSubstitution, self).match(req)
|
||||
|
||||
|
||||
class CudaNotFound(Exception):
    """Raised internally when CUDA/CuDNN version detection fails."""
|
||||
|
||||
|
||||
class RequirementsManager(object):
    # Applies registered RequirementSubstitution handlers to a
    # requirements listing and translates explicit URLs through the
    # local pip download cache.

    def __init__(self, session, base_interpreter=None):
        # type: (Session, PathLike) -> ()
        self._session = session
        # deep copy: CUDA detection below mutates the config
        self.config = deepcopy(session.config)  # type: ConfigTree
        self.handlers = []  # type: List[RequirementSubstitution]
        agent = self.config['agent']
        # substitutions are disabled entirely in cpu_only mode
        self.active = not agent.get('cpu_only', False)
        self.found_cuda = False
        if self.active:
            try:
                agent['cuda_version'], agent['cudnn_version'] = self.get_cuda_version(self.config)
                self.found_cuda = True
            except Exception:
                # if we have a cuda version, it is good enough (we dont have to have cudnn version)
                if agent.get('cuda_version'):
                    self.found_cuda = True
        # cache subdirectory is keyed by CUDA flavor, e.g. "cu90" / "cpu"
        pip_cache_dir = Path(self.config["agent.pip_download_cache.path"]).expanduser() / (
            'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
        self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
                                                 cache_dir=pip_cache_dir.as_posix())

    def register(self, cls):  # type: (Type[RequirementSubstitution]) -> None
        # Instantiate and register a substitution handler class.
        self.handlers.append(cls(self._session))

    def _replace_one(self, req):  # type: (MarkerRequirement) -> Optional[Text]
        # Apply the first matching handler to ``req``; None means
        # "keep the original line".
        match = re.search(r';\s*(.*)', Text(req))
        if match:
            req.markers = match.group(1).split(',')
        if not self.active:
            return None
        for handler in self.handlers:
            if handler.match(req):
                return handler.replace(req)
        return None

    def replace(self, requirements):  # type: (Text) -> Text
        """Rewrite a requirements listing through all registered handlers.

        ``requirements`` may be a single text blob or an iterable of
        lines. Lines that no handler rewrites (or whose rewrite fails
        non-fatally) are kept as-is; non-conda results are additionally
        translated through the pip download cache.
        """
        parsed_requirements = tuple(
            map(
                MarkerRequirement,
                filter(
                    None,
                    parse(requirements)
                    if isinstance(requirements, six.text_type)
                    else (next(parse(line), None) for line in requirements)
                )
            )
        )
        if not parsed_requirements:
            # return the original requirements just in case
            return requirements

        def replace_one(i, req):
            # type: (int, MarkerRequirement) -> Optional[Text]
            # Best-effort per-line rewrite; only fatal errors propagate.
            try:
                return self._replace_one(req)
            except FatalSpecsResolutionError:
                raise
            except Exception:
                warning('could not find installed CUDA/CuDNN version for {}, '
                        'using original requirements line: {}'.format(req, i))
                return None

        new_requirements = tuple(replace_one(i, req) for i, req in enumerate(parsed_requirements))
        conda = is_conda(self.config)
        # fall back to the original (re-serialized) line where no
        # replacement was produced; conda lines drop the markers
        result = map(
            lambda x, y: (x if x is not None else y.tostr(markers=not conda)),
            new_requirements,
            parsed_requirements
        )
        if not conda:
            result = map(self.translator.translate, result)
        return join_lines(result)

    def post_install(self):
        # Run all handlers' post-install hooks; failures are reported
        # but never abort the remaining hooks.
        for h in self.handlers:
            try:
                h.post_install()
            except Exception as ex:
                print('RequirementsManager handler {} raised exception: {}'.format(h, ex))

    @staticmethod
    def get_cuda_version(config):  # type: (ConfigTree) -> (Text, Text)
        """Detect CUDA and CuDNN versions for this machine.

        Precedence: configuration values (already refreshed from
        os.environ), then ``nvcc --version``, then ``nvidia-smi``, then
        parsing ``cudnn.h`` near the nvcc executable. Returns normalized
        version strings; unknown components normalize from 0.
        """
        # we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
        cuda_version = config['agent.cuda_version']
        cudnn_version = config['agent.cudnn_version']
        if cuda_version and cudnn_version:
            return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)

        if not cuda_version:
            try:
                try:
                    output = Argv('nvcc', '--version').get_output()
                except OSError:
                    raise CudaNotFound('nvcc not found')
                # e.g. "release 9.0" -> "90"
                match = re.search(r'release (.{3})', output).group(1)
                cuda_version = Text(int(float(match) * 10))
            except:
                pass

        if not cuda_version:
            try:
                try:
                    output = Argv('nvidia-smi',).get_output()
                except OSError:
                    # NOTE(review): this branch runs nvidia-smi, yet the
                    # message says 'nvcc not found' -- misleading text
                    raise CudaNotFound('nvcc not found')
                match = re.search(r'CUDA Version: ([0-9]+).([0-9]+)', output)
                match = match.group(1)+'.'+match.group(2)
                cuda_version = Text(int(float(match) * 10))
            except:
                pass

        if not cudnn_version:
            try:
                # look for cudnn.h two directories above the nvcc binary
                cuda_lib = which('nvcc')
                # NOTE(review): is_windows_platform is referenced, not
                # called -- if it is a function this branch is always
                # taken; confirm against helper.base
                if is_windows_platform:
                    cudnn_h = path.sep.join(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h'])
                else:
                    cudnn_h = path.join(path.sep, *(cuda_lib.split(path.sep)[:-2] + ['include', 'cudnn.h']))

                cudnn_major, cudnn_minor = None, None
                try:
                    include_file = open(cudnn_h)
                except OSError:
                    raise CudaNotFound('Could not read cudnn.h')
                with include_file:
                    # scan the header for the version #define lines
                    for line in include_file:
                        if 'CUDNN_MAJOR' in line:
                            cudnn_major = line.split()[-1]
                        if 'CUDNN_MINOR' in line:
                            cudnn_minor = line.split()[-1]
                        if cudnn_major and cudnn_minor:
                            break
                cudnn_version = cudnn_major + (cudnn_minor or '0')
            except:
                pass

        return (normalize_cuda_version(cuda_version or 0),
                normalize_cuda_version(cudnn_version or 0))
|
||||
|
||||
63
trains_agent/helper/package/translator.py
Normal file
63
trains_agent/helper/package/translator.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import Text
|
||||
|
||||
from furl import furl
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains_agent.config import Config
|
||||
from .pip_api.system import SystemPip
|
||||
|
||||
|
||||
class RequirementsTranslator(object):

    """
    Translate explicit URLs to local URLs after downloading them to cache
    """

    # URL schemes eligible for download-and-cache
    SUPPORTED_SCHEMES = ["http", "https", "ftp"]

    def __init__(self, session, interpreter=None, cache_dir=None):
        self._session = session
        config = session.config
        self.cache_dir = cache_dir or Path(config["agent.pip_download_cache.path"]).expanduser().as_posix()
        # caching can be disabled entirely through the configuration
        self.enabled = config["agent.pip_download_cache.enabled"]
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        self.config = Config()
        # pip interface used only to download packages into the cache
        self.pip = SystemPip(interpreter=interpreter)

    def download(self, url):
        # Download a single package URL into the local cache directory.
        self.pip.download_package(url, cache_dir=self.cache_dir)

    @classmethod
    def is_supported_link(cls, line):
        # type: (Text) -> bool
        """
        Return whether requirement is a link that should be downloaded to cache
        """
        url = furl(line)
        return (
            url.scheme
            and url.scheme.lower() in cls.SUPPORTED_SCHEMES
            and line.lstrip().lower().startswith(url.scheme.lower())
        )

    def translate(self, line):
        """
        If requirement is supported, download it to cache and return the download path
        """
        if not (self.enabled and self.is_supported_link(line)):
            return line
        command = self.config.command
        command.log('Downloading "{}" to pip cache'.format(line))
        url = furl(line)
        try:
            # the last URL path segment is taken as the wheel file name
            wheel_name = url.path.segments[-1]
        except IndexError:
            command.error('Could not parse wheel name of "{}"'.format(line))
            return line
        try:
            self.download(line)
            downloaded = Path(self.cache_dir, wheel_name).expanduser().as_uri()
        except Exception:
            # deliberate best-effort: fall back to the original URL
            command.error('Could not download wheel name of "{}"'.format(line))
            return line
        return downloaded
|
||||
115
trains_agent/helper/package/venv_update_api.py
Normal file
115
trains_agent/helper/package/venv_update_api.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import Optional, Text
|
||||
|
||||
import requests
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
from trains_agent.definitions import CONFIG_DIR
|
||||
from trains_agent.helper.process import Argv, DEVNULL
|
||||
from .pip_api.venv import VirtualenvPip
|
||||
|
||||
|
||||
class VenvUpdateAPI(VirtualenvPip):
    # pip API variant that manages the virtualenv via the venv-update
    # script (downloaded on demand) instead of plain pip.

    # file recording the URL the local venv-update script came from
    URL_FILE_PATH = Path(CONFIG_DIR, "venv-update-url.txt")
    # local copy of the venv-update script
    SCRIPT_PATH = Path(CONFIG_DIR, "venv-update")

    def __init__(self, url, *args, **kwargs):
        super(VenvUpdateAPI, self).__init__(*args, **kwargs)
        # URL to fetch the venv-update script from
        self.url = url
        self._script_path = None
        # venv-update needs different arguments on the first invocation
        self._first_install = True

    @property
    def downloaded_venv_url(self):
        # type: () -> Optional[Text]
        # URL of the previously downloaded script, or None if never
        # downloaded on this machine
        try:
            return self.URL_FILE_PATH.read_text()
        except OSError:
            return None

    @downloaded_venv_url.setter
    def downloaded_venv_url(self, value):
        self.URL_FILE_PATH.write_text(value)

    def _check_script_validity(self, path):
        """
        Make sure script in ``path`` is a valid python script
        :param path:
        :return:
        """
        # run "<venv python> <script> --version"; exit code 0 == valid
        result = Argv(self.bin, path, "--version").call(
            stdout=DEVNULL, stderr=DEVNULL, stdin=DEVNULL
        )
        return result == 0

    @property
    def script_path(self):
        # type: () -> Text
        # Lazily (re-)download the venv-update script when it is
        # missing, stale (downloaded from a different URL) or fails the
        # validity check.
        if not self._script_path:
            self._script_path = self.SCRIPT_PATH
            if not (
                self._script_path.exists()
                and self.downloaded_venv_url
                and self.downloaded_venv_url == self.url
                and self._check_script_validity(self._script_path)
            ):
                with self._script_path.open("wb") as f:
                    for data in requests.get(self.url, stream=True):
                        f.write(data)
                self.downloaded_venv_url = self.url
        return self._script_path

    def install_from_file(self, path):
        # Install the requirements listed in file ``path`` via venv-update.
        first_install = (
            Argv(
                self.python,
                six.text_type(self.script_path),
                "venv=",
                "-p",
                self.python,
                self.path,
            )
            + self.create_flags()
            + ("install=", "-r", path)
            + self.install_flags()
        )
        # subsequent installs go through pip-faster and skip pruning
        later_install = first_install + (
            "pip-command=",
            "pip-faster",
            "install",
            "--upgrade",  # no --prune
        )
        self._choose_install(first_install, later_install)

    def install_packages(self, *packages):
        # Install the given packages via venv-update.
        first_install = (
            Argv(
                self.python,
                six.text_type(self.script_path),
                "venv=",
                self.path,
                "install=",
            )
            + packages
        )
        later_install = first_install + (
            "pip-command=",
            "pip-faster",
            "install",
            "--upgrade",  # no --prune
        )
        self._choose_install(first_install, later_install)

    def _choose_install(self, first, rest):
        # Run ``first`` on the very first install of this instance,
        # ``rest`` on every subsequent one.
        if self._first_install:
            command = first
            self._first_install = False
        else:
            command = rest
        command.check_call(stdin=DEVNULL)

    def upgrade_pip(self):
        """
        pip and venv-update versions are coupled, venv-update installs the latest compatible pip
        """
        pass
|
||||
365
trains_agent/helper/process.py
Normal file
365
trains_agent/helper/process.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from distutils.spawn import find_executable
|
||||
from itertools import chain, repeat, islice
|
||||
from os.path import devnull
|
||||
from typing import Union, Text, Sequence, Any, TypeVar, Callable
|
||||
|
||||
import psutil
|
||||
from furl import furl
|
||||
from future.builtins import super
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
from trains_agent.definitions import PROGRAM_NAME, CONFIG_FILE
|
||||
from trains_agent.helper.base import bash_c, is_windows_platform, select_for_platform, chain_map
|
||||
|
||||
PathLike = Union[Text, Path]
|
||||
|
||||
|
||||
def get_bash_output(cmd, strip=False, stderr=subprocess.STDOUT, stdin=False):
    """Run *cmd* through the platform shell and return its decoded output.

    Returns None when the command exits with a non-zero status. The
    output is always stripped once; ``strip`` re-strips it for backward
    compatibility with older callers.
    """
    shell_argv = bash_c().split() + [cmd]
    stdin_arg = subprocess.PIPE if stdin else None
    try:
        raw = subprocess.check_output(shell_argv, stderr=stderr, stdin=stdin_arg)
    except subprocess.CalledProcessError:
        return None
    output = raw.decode().strip()
    if strip and output:
        return output.strip()
    return output
|
||||
|
||||
|
||||
def kill_all_child_processes(pid=None):
    """Kill the child-process tree of ``pid``.

    When ``pid`` is given, the process itself is killed as well; when
    omitted, only the *current* process's descendants are killed (the
    current process is left running).
    """
    # get current process if pid not provided
    include_parent = True
    if not pid:
        pid = os.getpid()
        include_parent = False
    print("\nLeaving process id {}".format(pid))
    try:
        parent = psutil.Process(pid)
    except psutil.Error:
        # could not find parent process id
        return
    # kill all descendants first, then (optionally) the process itself
    for child in parent.children(recursive=True):
        child.kill()
    if include_parent:
        parent.kill()
|
||||
|
||||
|
||||
def check_if_command_exists(cmd):
    """Return True when executable ``cmd`` is found on the system PATH.

    Uses shutil.which rather than distutils.spawn.find_executable:
    distutils is deprecated (PEP 632) and removed in Python 3.12, and
    shutil.which performs the equivalent PATH lookup.
    """
    import shutil
    return bool(shutil.which(cmd))
|
||||
|
||||
|
||||
def get_program_invocation():
    """Return the argv prefix that re-invokes this program via ``python -m``."""
    module_name = PROGRAM_NAME.replace('-', '_')
    return [sys.executable, "-u", "-m", module_name]
|
||||
|
||||
|
||||
# Generic return-type variable used to type Executable.call_subprocess wrappers.
Retval = TypeVar("Retval")
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class Executable(object):
    # Abstract invocable command; concrete subclasses define how the
    # underlying subprocess function is applied (see Argv below).

    @abc.abstractmethod
    def call_subprocess(self, func, censor_password=False, *args, **kwargs):
        # type: (Callable[..., Retval]) -> Retval
        # Invoke ``func`` (a subprocess.* callable) on this command.
        pass

    def get_output(self, *args, **kwargs):
        # Run the command, returning its decoded and right-stripped stdout.
        return (
            self.call_subprocess(subprocess.check_output, *args, **kwargs)
            .decode("utf8")
            .rstrip()
        )

    def check_call(self, *args, **kwargs):
        # Run the command; raises CalledProcessError on non-zero exit.
        return self.call_subprocess(subprocess.check_call, *args, **kwargs)

    @staticmethod
    @contextmanager
    def normalize_exception(censor_password=False):
        # Normalize CalledProcessError before re-raising: optionally
        # strip passwords from the reported command line, and decode
        # any captured byte output to text.
        try:
            yield
        except subprocess.CalledProcessError as e:
            if censor_password:
                e.cmd = [furl(word).remove(password=True).tostr() for word in e.cmd]

            if e.output and not isinstance(e.output, six.text_type):
                e.output = e.output.decode()
            raise

    @abc.abstractmethod
    def pretty(self):
        # Human-readable representation of the command.
        pass
|
||||
|
||||
|
||||
class Argv(Executable):

    # separator used when serializing the argument list to a shell string
    ARGV_SEPARATOR = " "

    def __init__(self, *argv, **kwargs):
        # type: (*PathLike, Any) -> ()
        """
        Object representing a series of strings used to invoke a process.
        """
        self.argv = argv
        self._log = kwargs.pop("log", None)
        if not self._log:
            self._log = logging.getLogger(__name__)
            self._log.propagate = False

    def serialize(self):
        """
        Returns a string of the shell command
        """
        # NOTE(review): ``quote`` is not among this module's visible
        # imports -- presumably shlex.quote (or a six.moves equivalent);
        # confirm against the file's import section
        return self.ARGV_SEPARATOR.join(map(quote, self))

    def call_subprocess(self, func, censor_password=False, *args, **kwargs):
        # Apply ``func`` to the argument list, normalizing raised errors.
        self._log.debug("running: %s: %s", func.__name__, list(self))
        with self.normalize_exception(censor_password):
            return func(list(self), *args, **kwargs)

    def call(self, *args, **kwargs):
        # subprocess.call wrapper: returns the process exit code.
        return self.call_subprocess(subprocess.call, *args, **kwargs)

    def get_argv(self):
        # Raw argument tuple exactly as passed to the constructor.
        return self.argv

    def __repr__(self):
        return "<Argv{}>".format(self.argv)

    def __str__(self):
        return "Executing: {}".format(self.argv)

    def __iter__(self):
        # Iterate arguments as text (PathLike entries are stringified).
        return (six.text_type(word) for word in self.argv)

    def __getitem__(self, item):
        return self.argv[item]

    def __add__(self, other):
        # Argv + iterable -> new Argv with the arguments appended.
        try:
            iter(other)
        except TypeError:
            return NotImplemented
        return type(self)(*(self.argv + tuple(other)), log=self._log)

    def __radd__(self, other):
        # iterable + Argv -> new Argv with the arguments prepended.
        try:
            iter(other)
        except TypeError:
            return NotImplemented
        return type(self)(*(tuple(other) + self.argv), log=self._log)

    # pretty-printing reuses the serialized shell form
    pretty = serialize

    @staticmethod
    def conditional_flag(condition, flag, *flags):
        # type: (Any, PathLike, PathLike) -> Sequence[PathLike]
        """
        Translate a boolean to a flag command like arguments.
        :param condition: condition to translate to flag
        :param flag: flag to use if condition true (at least one)
        :param flags: additional flags to use if condition is true
        """
        return (flag,) + flags if condition else ()
|
||||
|
||||
|
||||
class CommandSequence(Executable):
    """A sequence of shell commands joined with '&&' when serialized."""

    JOIN_COMMAND_OPERATOR = "&&"

    def __init__(self, *commands, **kwargs):
        """
        Object representing a sequence of shell commands.
        :param commands: Command elements. Each CommandSequence will be treated as a single command-line argument.
        :type commands: Each command: [str] | Argv
        """
        self._log = kwargs.pop("log", None)
        if not self._log:
            self._log = logging.getLogger(__name__)
            self._log.propagate = False
        self.commands = []
        for c in commands:
            # flatten nested sequences; deep-copy so mutations don't leak
            if isinstance(c, CommandSequence):
                self.commands.extend(deepcopy(c.commands))
            elif isinstance(c, Argv):
                self.commands.append(deepcopy(c))
            else:
                self.commands.append(Argv(*c, log=self._log))

    def get_argv(self, shell=False):
        """
        Get array of argv's.
        :param bool shell: if True, returns the argv of a process that will invoke a shell running the command sequence
        """
        if shell:
            return tuple(bash_c().split()) + (self.serialize(),)

        def safe_get_argv(obj):
            # accept both Argv objects and plain iterables of words
            try:
                func = obj.get_argv
            except AttributeError:
                result = obj
            else:
                result = func()
            return tuple(map(str, result))

        return tuple(map(safe_get_argv, self.commands))

    def serialize(self):
        """Join all commands into one '&&'-separated shell string."""
        def intersperse(delimiter, seq):
            return islice(chain.from_iterable(zip(repeat(delimiter), seq)), 1, None)

        def normalize(command):
            # NOTE(review): on Windows this yields a list per command, which
            # ' '.join below cannot join (TypeError) — confirm whether the
            # Windows path is ever exercised through serialize().
            return list(command) if is_windows_platform() else command.serialize()

        return ' '.join(list(intersperse(self.JOIN_COMMAND_OPERATOR, map(normalize, self.commands))))

    def call_subprocess(self, func, censor_password=False, *args, **kwargs):
        """Run the serialized sequence through a shell (bash on Linux)."""
        with self.normalize_exception(censor_password):
            return func(
                self.serialize(),
                *args,
                **chain_map(
                    dict(
                        executable=select_for_platform(linux="bash", windows=None),
                        shell=True,
                    ),
                    kwargs,
                )
            )

    def __repr__(self):
        tab = " " * 4
        return "<{}(\n{}{},\n)>".format(
            type(self).__name__, tab, (",\n" + tab).join(map(repr, self.commands))
        )

    def __iter__(self):
        return iter(self.commands)

    def __getitem__(self, item):
        return self.commands[item]

    def __setitem__(self, key, value):
        self.commands[key] = value

    def __add__(self, other):
        try:
            iter(other)
        except TypeError:
            return NotImplemented
        # BUGFIX: self.commands is a list, and ``list + tuple`` raises
        # TypeError; concatenate as lists instead (behavior otherwise same).
        return type(self)(*(self.commands + list(other)))

    def pretty(self):
        serialized = self.serialize()
        # NOTE(review): on Windows serialize() returns a string, so this joins
        # its characters with spaces — looks like a latent quirk; confirm.
        if is_windows_platform():
            return " ".join(serialized)
        return serialized
|
||||
|
||||
|
||||
class WorkerParams(object):
    """
    Command-line parameters shared by worker commands (execute/daemon).
    """

    def __init__(
        self,
        log_level="INFO",
        config_file=CONFIG_FILE,
        optimization=0,
        debug=False,
        trace=False,
    ):
        self.trace = trace
        self.log_level = log_level
        # optimization level maps to python's -O / -OO flags
        self.optimization = optimization
        self.config_file = config_file
        self.debug = debug

    def get_worker_flags(self):
        """
        Serialize a WorkerParams instance to a tuple of command-line flags
        :param WorkerParams self: parameters of worker
        :return: a tuple of global flags and "workers execute/daemon" flags
        """
        global_args = ("--config-file", str(self.config_file))
        if self.debug:
            global_args += ("--debug",)
        if self.trace:
            global_args += ("--trace",)
        worker_args = tuple()
        if self.optimization:
            # BUGFIX: get_optimization_flag() returns a string; adding it
            # directly to a tuple raised TypeError. Wrap it in a 1-tuple.
            worker_args += (self.get_optimization_flag(),)
        return global_args, worker_args

    def get_optimization_flag(self):
        """Return the "-O"/"-OO" style flag for the configured level."""
        return "-{}".format("O" * self.optimization)

    def get_argv_for_command(self, command):
        """
        Get argv for a particular worker command.
        """
        global_args, worker_args = self.get_worker_flags()
        command_line = (
            tuple(get_program_invocation())
            + global_args
            + (command, )
            + worker_args
        )
        return Argv(*command_line)
|
||||
|
||||
|
||||
class DaemonParams(WorkerParams):
    """WorkerParams specialized for the daemon command (foreground/queues)."""

    def __init__(self, foreground=False, queues=(), *args, **kwargs):
        super(DaemonParams, self).__init__(*args, **kwargs)
        self.foreground = foreground
        self.queues = tuple(queues)

    def get_worker_flags(self):
        """Extend the base worker flags with --foreground and --queue options."""
        global_args, worker_args = super(DaemonParams, self).get_worker_flags()
        extra = []
        if self.foreground:
            extra.append("--foreground")
        if self.queues:
            extra.append("--queue")
            extra.extend(self.queues)
        return global_args, worker_args + tuple(extra)
|
||||
|
||||
|
||||
# Shared sink for child-process stdio; kept open for the process lifetime.
# NOTE(review): subprocess.DEVNULL would avoid the long-lived handle — confirm
# whether py2 compatibility still matters before switching.
DEVNULL = open(devnull, "w+")
# Shell builtin used to source environment scripts per platform.
SOURCE_COMMAND = select_for_platform(linux="source", windows="call")
|
||||
|
||||
|
||||
class ExitStatus(object):
    """Symbolic process exit codes used across the agent."""

    success = 0      # normal termination
    failure = 1      # generic error
    interrupted = 2  # stopped before completion
|
||||
|
||||
|
||||
# Exit code meaning the subprocess succeeded.
COMMAND_SUCCESS = 0


# Matches any character that makes a word unsafe to pass to a POSIX shell
# unquoted (same character class as shlex.quote).
_find_unsafe = re.compile(r"[^\w@%+=:,./-]", getattr(re, "ASCII", 0)).search


def quote(s):
    """
    Backport of shlex.quote():
    Return a shell-escaped version of the string *s*.
    """
    if not s:
        return "''"
    if _find_unsafe(s) is None:
        return s

    # use single quotes, and put single quotes into double quotes
    # the string $'b is then quoted as '$'"'"'b'
    escaped = s.replace("'", "'\"'\"'")
    return "'{}'".format(escaped)
|
||||
555
trains_agent/helper/repo.py
Normal file
555
trains_agent/helper/repo.py
Normal file
@@ -0,0 +1,555 @@
|
||||
import abc
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from distutils.spawn import find_executable
|
||||
from hashlib import md5
|
||||
from os import environ, getenv
|
||||
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple
|
||||
|
||||
import attr
|
||||
from furl import furl
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
from trains_agent.helper.console import ensure_text, ensure_binary
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.base import (
|
||||
select_for_platform,
|
||||
rm_tree,
|
||||
ExecutionInfo,
|
||||
normalize_path,
|
||||
create_file_if_not_exists,
|
||||
)
|
||||
from trains_agent.helper.process import DEVNULL, Argv, PathLike, COMMAND_SUCCESS
|
||||
from trains_agent.session import Session
|
||||
|
||||
|
||||
class VcsFactory(object):
    """
    Creates VCS instances
    """

    GIT_SUFFIX = ".git"

    @classmethod
    def create(cls, session, execution_info, location):
        # type: (Session, ExecutionInfo, PathLike) -> VCS
        """
        Create a VCS instance for config and url
        :param session: program session
        :param execution_info: task ExecutionInfo
        :param location: (desired) clone location
        """
        url = execution_info.repository
        # a ".git" suffix marks a git repository; anything else is mercurial
        vcs_cls = Git if url.endswith(cls.GIT_SUFFIX) else Hg
        # revision precedence: explicit commit > tag > remote branch (with a
        # fall back to the VCS's default main branch)
        revision = (
            execution_info.version_num
            or execution_info.tag
            or vcs_cls.remote_branch_name(execution_info.branch or vcs_cls.main_branch)
        )
        return vcs_cls(session, url, location, revision)
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
@attr.s
class RepoInfo(object):
    """
    Cloned repository information
    :param type: VCS type
    :param url: repository url
    :param branch: revision branch
    :param commit: revision number
    :param root: clone location path
    """

    # VCS type name — the executable name of the VCS used ("git" or "hg")
    type = attr.ib(type=str)
    url = attr.ib(type=str)
    branch = attr.ib(type=str)
    commit = attr.ib(type=str)
    root = attr.ib(type=str)
|
||||
|
||||
|
||||
# Generic return-type variable used in VCS._call_subprocess annotations.
RType = TypeVar("RType")
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class VCS(object):

    """
    Provides overloaded utilities for handling repositories of different types
    """

    # additional environment variables for VCS commands
    COMMAND_ENV = {}

    # matches unified-diff "added file" headers; used to pre-create new files
    PATCH_ADDED_FILE_RE = re.compile(r"^\+\+\+ b/(?P<path>.*)")

    def __init__(self, session, url, location, revision):
        # type: (Session, Text, PathLike, Text) -> ()
        """
        Create a VCS instance for config and url
        :param session: program session
        :param url: repository url
        :param location: (desired) clone location
        :param revision: desired clone revision
        """
        self.session = session
        # BUGFIX: the logger was assigned twice; the second assignment
        # (get_logger(__name__)) discarded this more specific per-class logger.
        self.log = self.session.get_logger(
            "{}.{}".format(__name__, type(self).__name__)
        )
        self.url = url
        self.location = Text(location)
        self.revision = revision

    @property
    def url_with_auth(self):
        """
        Return URL with configured user/password
        """
        return self.add_auth(self.session.config, self.url)

    @abc.abstractproperty
    def executable_name(self):
        """
        Name of command executable
        """
        pass

    @abc.abstractproperty
    def main_branch(self):
        """
        Name of default/main branch
        """
        pass

    @abc.abstractproperty
    def checkout_flags(self):
        # type: () -> Sequence[Text]
        """
        Command-line flags for checkout command
        """
        pass

    @abc.abstractproperty
    def patch_base(self):
        # type: () -> Sequence[Text]
        """
        Command and flags for applying patches
        """
        pass

    def patch(self, location, patch_content):
        # type: (PathLike, Text) -> bool
        """
        Apply patch to repository at `location`.
        Returns True on success, False if the VCS reported a failure.
        """
        self.log.info("applying diff to %s", location)

        # pre-create files the diff adds, so the VCS patch command can apply
        for match in filter(
            None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
        ):
            create_file_if_not_exists(normalize_path(location, match.group("path")))

        return_code, errors = self.call_with_stdin(
            patch_content, *self.patch_base, cwd=location
        )
        if return_code:
            self.log.error("Failed applying diff")
            lines = errors.splitlines()
            if any(l for l in lines if "no such file or directory" in l.lower()):
                self.log.warning(
                    "NOTE: files were not found when applying diff, perhaps you forgot to push your changes?"
                )
            return False
        else:
            self.log.info("successfully applied uncommitted changes")
            return True

    # Command-line flags for clone command
    clone_flags = ()

    @abc.abstractmethod
    def executable_not_found_error_help(self):
        # type: () -> Text
        """
        Instructions for when executable is not found
        """
        pass

    @staticmethod
    def remote_branch_name(branch):
        # type: (Text) -> Text
        """
        Creates name of remote branch from name of local/ambiguous branch.
        Returns same name by default.
        """
        return branch

    # parse scp-like git ssh URLs, e.g: git@host:user/project.git
    SSH_URL_GIT_SYNTAX = re.compile(
        r"""
        ^
        (?:(?P<user>{regular}*?)@)?
        (?P<host>{regular}*?)
        :
        (?P<path>{regular}.*)?
        $
        """.format(
            regular=r"[^/@:#]"
        ),
        re.VERBOSE,
    )

    @classmethod
    def resolve_ssh_url(cls, url):
        # type: (Text) -> Text
        """
        Replace SSH URL with HTTPS URL when applicable
        """

        def get_username(user_, password=None):
            """
            Remove special SSH users hg/git
            """
            return (
                None
                if user_ and user_.lower() in ["hg", "git"] and not password
                else user_
            )

        match = cls.SSH_URL_GIT_SYNTAX.match(url)
        if match:
            user, host, path = match.groups()
            return (
                furl()
                .set(scheme="https", username=get_username(user), host=host, path=path)
                .url
            )
        parsed_url = furl(url)
        if parsed_url.scheme == "ssh":
            return parsed_url.set(
                scheme="https",
                username=get_username(
                    parsed_url.username, password=parsed_url.password
                ),
            ).url
        return url

    def _set_ssh_url(self):
        """
        Replace instance URL with SSH substitution result and report to log.
        According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work.
        """
        if not self.session.config.agent.translate_ssh:
            return

        # only translate ssh->https when no ssh agent is available but
        # user/pass credentials are configured
        ssh_agent_variable = "SSH_AUTH_SOCK"
        if not getenv(ssh_agent_variable) and (self.session.config.get('agent.git_user', None) and
                                               self.session.config.get('agent.git_pass', None)):
            new_url = self.resolve_ssh_url(self.url)
            if new_url != self.url:
                print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
                    self.url, new_url))
                self.url = new_url

    def clone(self, branch=None):
        # type: (Text) -> None
        """
        Clone repository to destination and checking out `branch`.
        If not in debug mode, filter VCS password from output.
        """
        self._set_ssh_url()
        clone_command = ("clone", self.url_with_auth, self.location) + self.clone_flags
        if branch:
            clone_command += ("-b", branch)
        if self.session.debug_mode:
            self.call(*clone_command)
            return

        def normalize_output(result):
            """
            Returns result string without user's password.
            NOTE: ``self.get_stderr``'s result might or might not have the same type as ``e.output`` in case of error.
            """
            string_type = (
                ensure_text
                if isinstance(result, six.text_type)
                else ensure_binary
            )
            return result.replace(
                string_type(self.url),
                string_type(furl(self.url).remove(password=True).tostr()),
            )

        def print_output(output):
            print(ensure_text(output))

        try:
            print_output(normalize_output(self.get_stderr(*clone_command)))
        except subprocess.CalledProcessError as e:
            # In Python 3, subprocess.CalledProcessError has a `stderr` attribute,
            # but since stderr is redirect to `subprocess.PIPE` it will appear in the usual `output` attribute
            if e.output:
                e.output = normalize_output(e.output)
                print_output(e.output)
            raise

    def checkout(self):
        # type: () -> None
        """
        Checkout repository at specified revision
        """
        self.call("checkout", self.revision, *self.checkout_flags, cwd=self.location)

    @abc.abstractmethod
    def pull(self):
        # type: () -> None
        """
        Pull remote changes for revision
        """
        pass

    def call(self, *argv, **kwargs):
        """
        Execute argv without stdout/stdin.
        Remove stdin so git/hg can't ask for passwords.
        ``kwargs`` can override all arguments passed to subprocess.
        """
        return self._call_subprocess(subprocess.check_call, argv, **kwargs)

    def call_with_stdin(self, input_, *argv, **kwargs):
        # type: (...) -> Tuple[int, str]
        """
        Run command with `input_` as stdin.
        Returns (return code, stderr text).
        """
        input_ = input_.encode("latin1")
        if not input_.endswith(b"\n"):
            input_ += b"\n"
        process = self._call_subprocess(
            subprocess.Popen,
            argv,
            **dict(
                kwargs,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        )
        _, stderr = process.communicate(input_)
        if stderr:
            self.log.warning("%s: %s", self._get_vcs_command(argv), stderr)
        return process.returncode, Text(stderr)

    def get_stderr(self, *argv, **kwargs):
        """
        Execute argv without stdout/stdin in <cwd> and get stderr output.
        Remove stdin so git/hg can't ask for passwords.
        ``kwargs`` can override all arguments passed to subprocess.
        """
        process = self._call_subprocess(
            subprocess.Popen, argv, **dict(kwargs, stderr=subprocess.PIPE, stdout=None)
        )
        _, stderr = process.communicate()
        code = process.poll()
        if code == COMMAND_SUCCESS:
            return stderr
        # re-raise as a CalledProcessError with the password censored
        with Argv.normalize_exception(censor_password=True):
            raise subprocess.CalledProcessError(
                returncode=code, cmd=argv, output=stderr
            )

    def _call_subprocess(self, func, argv, **kwargs):
        # type: (Callable[..., RType], Iterable[Text], dict) -> RType
        """Run a VCS command through *func* with safe defaults (no stdio,
        censored passwords, COMMAND_ENV merged into the environment)."""
        cwd = kwargs.pop("cwd", None)
        cwd = cwd and str(cwd)
        kwargs = dict(
            dict(
                censor_password=True,
                cwd=cwd,
                stdin=DEVNULL,
                stdout=DEVNULL,
                env=dict(self.COMMAND_ENV, **environ),
            ),
            **kwargs
        )
        command = self._get_vcs_command(argv)
        self.log.debug("Running: %s", list(command))
        return command.call_subprocess(func, **kwargs)

    def _get_vcs_command(self, argv):
        # type: (Iterable[PathLike]) -> Argv
        """Prefix *argv* with this VCS's executable name."""
        return Argv(self.executable_name, *argv)

    @classmethod
    def add_auth(cls, config, url):
        """
        Add username and password to URL if missing from URL and present in config.
        Does not modify ssh URLs.
        """
        parsed_url = furl(url)
        if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"):
            return parsed_url.url
        config_user = config.get("agent.{}_user".format(cls.executable_name), None)
        config_pass = config.get("agent.{}_pass".format(cls.executable_name), None)
        if (
            (not (parsed_url.username and parsed_url.password))
            and config_user
            and config_pass
        ):
            parsed_url.set(username=config_user, password=config_pass)
        return parsed_url.url

    @abc.abstractproperty
    def info_commands(self):
        # type: () -> Mapping[Text, Argv]
        """
        Mapping from `RepoInfo` attribute name (except `type`) to command which acquires it
        """
        pass

    def get_repository_copy_info(self, path):
        """
        Get `RepoInfo` instance from copy of clone in `path`
        """
        path = Text(path)
        commands_result = {
            name: command.get_output(cwd=path)
            for name, command in self.info_commands.items()
        }
        return RepoInfo(type=self.executable_name, **commands_result)
|
||||
|
||||
|
||||
class Git(VCS):
    """Git implementation of the VCS interface."""

    executable_name = "git"
    main_branch = "master"
    # --quiet suppresses progress output; --recursive also clones submodules
    clone_flags = ("--quiet", "--recursive")
    # --force discards local changes on checkout
    checkout_flags = ("--force",)
    COMMAND_ENV = {
        # do not prompt for password
        "GIT_TERMINAL_PROMPT": "0",
        # do not prompt for ssh key passphrase
        "GIT_SSH_COMMAND": "ssh -oBatchMode=yes",
    }

    @staticmethod
    def remote_branch_name(branch):
        """Map a local/ambiguous branch name to its origin remote-tracking name."""
        return "origin/{}".format(branch)

    def executable_not_found_error_help(self):
        """Return platform-specific instructions for installing git."""
        return 'Cannot find "{}" executable. {}'.format(
            self.executable_name,
            select_for_platform(
                linux="You can install it by running: sudo apt-get install {}".format(
                    self.executable_name
                ),
                windows="You can download it here: {}".format(
                    "https://gitforwindows.org/"
                ),
            ),
        )

    def pull(self):
        """Fetch remote refs from origin (does not touch the working tree)."""
        self.call("fetch", "origin", cwd=self.location)

    # Commands used to collect the RepoInfo fields from a checked-out copy.
    # NOTE(review): "git remote get-url" needs a reasonably recent git —
    # confirm the minimum supported git version for the agent.
    info_commands = dict(
        url=Argv("git", "remote", "get-url", "origin"),
        branch=Argv("git", "rev-parse", "--abbrev-ref", "HEAD"),
        commit=Argv("git", "rev-parse", "HEAD"),
        root=Argv("git", "rev-parse", "--show-toplevel"),
    )

    # "git apply" reads the diff from stdin (see VCS.patch)
    patch_base = ("apply",)
|
||||
|
||||
|
||||
class Hg(VCS):
    """Mercurial implementation of the VCS interface."""

    executable_name = "hg"
    main_branch = "default"
    # --clean discards uncommitted changes on checkout/update
    checkout_flags = ("--clean",)
    # "hg import --no-commit" applies a patch without creating a changeset
    patch_base = ("import", "--no-commit")

    def executable_not_found_error_help(self):
        """Return platform-specific instructions for installing mercurial."""
        return 'Cannot find "{}" executable. {}'.format(
            self.executable_name,
            select_for_platform(
                linux="You can install it by running: sudo apt-get install {}".format(
                    self.executable_name
                ),
                windows="You can download it here: {}".format(
                    "https://www.mercurial-scm.org/wiki/Download"
                ),
            ),
        )

    def pull(self):
        """Pull from the remote; limited to the configured revision if set."""
        self.call(
            "pull",
            self.url_with_auth,
            cwd=self.location,
            *(("-r", self.revision) if self.revision else ())
        )

    # Commands used to collect the RepoInfo fields from a checked-out copy.
    info_commands = dict(
        url=Argv("hg", "paths", "--verbose"),
        branch=Argv("hg", "--debug", "id", "-b"),
        commit=Argv("hg", "--debug", "id", "-i"),
        root=Argv("hg", "root"),
    )
|
||||
|
||||
|
||||
def clone_repository_cached(session, execution, destination):
    # type: (Session, ExecutionInfo, Path) -> Tuple[VCS, RepoInfo]
    """
    Clone a remote repository.
    :param execution: execution info
    :param destination: directory to clone to (in which a directory for the repository will be created)
    :param session: program session
    :return: repository information
    :raises: CommandFailedError if git/hg is not installed
    """
    repo_url = execution.repository  # type: str
    parsed_url = furl(repo_url)
    # copy() so the original furl (and repo_url) keeps its password
    no_password_url = parsed_url.copy().remove(password=True).url

    clone_folder_name = Path(str(furl(repo_url).path)).name  # type: str
    clone_folder = Path(destination) / clone_folder_name
    # cache key combines the repo folder name with an md5 of the full URL,
    # so different remotes with the same folder name don't collide
    cached_repo_path = (
        Path(session.config["agent.vcs_cache.path"]).expanduser()
        / "{}.{}".format(clone_folder_name, md5(ensure_binary(repo_url)).hexdigest())
        / clone_folder_name
    )  # type: Path

    vcs = VcsFactory.create(
        session, execution_info=execution, location=cached_repo_path
    )
    if not find_executable(vcs.executable_name):
        raise CommandFailedError(vcs.executable_not_found_error_help())

    # reuse the cached clone when enabled and present, otherwise clone fresh
    # into the cache location (removing any stale partial clone first)
    if session.config["agent.vcs_cache.enabled"] and cached_repo_path.exists():
        print('Using cached repository in "{}"'.format(cached_repo_path))
    else:
        print("cloning: {}".format(no_password_url))
        rm_tree(cached_repo_path)
        vcs.clone(branch=execution.branch)

    # always refresh the cache and move it to the requested revision
    vcs.pull()
    vcs.checkout()

    # the task works on a private copy of the cached clone
    rm_tree(destination)
    shutil.copytree(Text(cached_repo_path), Text(clone_folder))
    if not clone_folder.is_dir():
        raise CommandFailedError(
            "copying of repository failed: from {} to {}".format(
                cached_repo_path, clone_folder
            )
        )

    repo_info = vcs.get_repository_copy_info(clone_folder)

    # make sure we have no user/pass in the returned repository structure
    repo_info = attr.evolve(repo_info, url=no_password_url)

    return vcs, repo_info
|
||||
296
trains_agent/helper/resource_monitor.py
Normal file
296
trains_agent/helper/resource_monitor.py
Normal file
@@ -0,0 +1,296 @@
|
||||
from __future__ import unicode_literals, division
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from itertools import starmap
|
||||
from threading import Thread, Event
|
||||
from time import time
|
||||
from typing import Text, Sequence
|
||||
|
||||
import attr
|
||||
import psutil
|
||||
from pathlib2 import Path
|
||||
from trains_agent.session import Session
|
||||
|
||||
try:
|
||||
from .gpu import gpustat
|
||||
except ImportError:
|
||||
gpustat = None
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BytesSizes(object):
    """Converters from byte counts to larger binary units (KiB/MiB/GiB)."""

    @staticmethod
    def kilobytes(x):
        # type: (float) -> float
        """Bytes to kilobytes."""
        return x / 2 ** 10

    @staticmethod
    def megabytes(x):
        # type: (float) -> float
        """Bytes to megabytes."""
        return x / 2 ** 20

    @staticmethod
    def gigabytes(x):
        # type: (float) -> float
        """Bytes to gigabytes."""
        return x / 2 ** 30
|
||||
|
||||
|
||||
class ResourceMonitor(object):
|
||||
@attr.s
|
||||
class StatusReport(object):
|
||||
task = attr.ib(default=None, type=str)
|
||||
queue = attr.ib(default=None, type=str)
|
||||
queues = attr.ib(default=None, type=Sequence[str])
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
key: value
|
||||
for key, value in attr.asdict(self).items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session, # type: Session
|
||||
worker_id, # type: ResourceMonitor.StatusReport,
|
||||
sample_frequency_per_sec=2.0,
|
||||
report_frequency_sec=30.0,
|
||||
first_report_sec=None,
|
||||
):
|
||||
self.session = session
|
||||
self.queue = deque(maxlen=1)
|
||||
self.queue.appendleft(self.StatusReport())
|
||||
self._worker_id = worker_id
|
||||
self._sample_frequency = sample_frequency_per_sec
|
||||
self._report_frequency = report_frequency_sec
|
||||
self._first_report_sec = first_report_sec or report_frequency_sec
|
||||
self._num_readouts = 0
|
||||
self._readouts = {}
|
||||
self._previous_readouts = {}
|
||||
self._previous_readouts_ts = time()
|
||||
self._thread = None
|
||||
self._exit_event = Event()
|
||||
self._gpustat_fail = 0
|
||||
self._gpustat = gpustat
|
||||
if not self._gpustat:
|
||||
log.warning('Trains-Agent Resource Monitor: GPU monitoring is not available')
|
||||
else:
|
||||
self._active_gpus = None
|
||||
try:
|
||||
active_gpus = os.environ.get('NVIDIA_VISIBLE_DEVICES', '') or \
|
||||
os.environ.get('CUDA_VISIBLE_DEVICES', '')
|
||||
if active_gpus:
|
||||
self._active_gpus = [int(g.strip()) for g in active_gpus.split(',')]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def set_report(self, report):
|
||||
# type: (ResourceMonitor.StatusReport) -> ()
|
||||
if report is not None:
|
||||
self.queue.appendleft(report)
|
||||
|
||||
def get_report(self):
|
||||
# type: () -> ResourceMonitor.StatusReport
|
||||
return self.queue[0]
|
||||
|
||||
def start(self):
|
||||
self._exit_event.clear()
|
||||
self._thread = Thread(target=self._daemon)
|
||||
self._thread.daemon = True
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
self._exit_event.set()
|
||||
self.send_report()
|
||||
|
||||
def send_report(self, stats=None):
|
||||
report = dict(
|
||||
machine_stats=stats,
|
||||
timestamp=(int(time()) * 1000),
|
||||
worker=self._worker_id,
|
||||
**self.get_report().to_dict()
|
||||
)
|
||||
log.debug("sending report: %s", report)
|
||||
|
||||
try:
|
||||
self.session.get(service="workers", action="status_report", **report)
|
||||
except Exception:
|
||||
log.warning("Failed sending report: %s", report)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _daemon(self):
|
||||
seconds_since_started = 0
|
||||
reported = 0
|
||||
while True:
|
||||
last_report = time()
|
||||
current_report_frequency = (
|
||||
self._report_frequency if reported != 0 else self._first_report_sec
|
||||
)
|
||||
while (time() - last_report) < current_report_frequency:
|
||||
# wait for self._sample_frequency seconds, if event set quit
|
||||
if self._exit_event.wait(1 / self._sample_frequency):
|
||||
return
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._update_readouts()
|
||||
except Exception as ex:
|
||||
log.warning("failed getting machine stats: %s", report_error(ex))
|
||||
self._failure()
|
||||
|
||||
seconds_since_started += int(round(time() - last_report))
|
||||
# check if we do not report any metric (so it means the last iteration will not be changed)
|
||||
|
||||
# if we do not have last_iteration, we just use seconds as iteration
|
||||
|
||||
# start reporting only when we figured out, if this is seconds based, or iterations based
|
||||
average_readouts = self._get_average_readouts()
|
||||
stats = {
|
||||
# 3 points after the dot
|
||||
key: round(value, 3) if isinstance(value, float) else [round(v, 3) for v in value]
|
||||
for key, value in average_readouts.items()
|
||||
}
|
||||
|
||||
# send actual report
|
||||
if self.send_report(stats):
|
||||
# clear readouts if this is update was sent
|
||||
self._clear_readouts()
|
||||
|
||||
# count reported iterations
|
||||
reported += 1
|
||||
|
||||
def _update_readouts(self):
|
||||
readouts = self._machine_stats()
|
||||
elapsed = time() - self._previous_readouts_ts
|
||||
self._previous_readouts_ts = time()
|
||||
|
||||
def fix(k, v):
|
||||
if k.endswith("_mbs"):
|
||||
v = (v - self._previous_readouts.get(k, v)) / elapsed
|
||||
|
||||
if v is None:
|
||||
v = 0
|
||||
return k, self._readouts.get(k, 0) + v
|
||||
|
||||
self._readouts.update(starmap(fix, readouts.items()))
|
||||
self._num_readouts += 1
|
||||
self._previous_readouts = readouts
|
||||
|
||||
def _get_num_readouts(self):
|
||||
return self._num_readouts
|
||||
|
||||
def _get_average_readouts(self):
|
||||
def create_general_key(old_key):
|
||||
"""
|
||||
Create key for backend payload
|
||||
:param old_key: old stats key
|
||||
:type old_key: str
|
||||
:return: new key for sending stats
|
||||
:rtype: str
|
||||
"""
|
||||
key_parts = old_key.rpartition("_")
|
||||
return "{}_*".format(key_parts[0] if old_key.startswith("gpu") else old_key)
|
||||
|
||||
ret = {}
|
||||
# make sure the gpu/cpu stats are always ordered in the accumulated values list (general_key)
|
||||
ordered_keys = sorted(self._readouts.keys())
|
||||
for k in ordered_keys:
|
||||
v = self._readouts[k]
|
||||
stat_key = self.BACKEND_STAT_MAP.get(k)
|
||||
if stat_key:
|
||||
ret[stat_key] = v / self._num_readouts
|
||||
else:
|
||||
general_key = create_general_key(k)
|
||||
general_key = self.BACKEND_STAT_MAP.get(general_key)
|
||||
if general_key:
|
||||
ret.setdefault(general_key, []).append(v / self._num_readouts)
|
||||
else:
|
||||
pass # log.debug("Cannot find key {}".format(k))
|
||||
return ret
|
||||
|
||||
def _clear_readouts(self):
|
||||
self._readouts = {}
|
||||
self._num_readouts = 0
|
||||
|
||||
def _machine_stats(self):
    """
    :return: machine stats dictionary, all values expressed in megabytes
        (except percentages, temperatures and utilization values)
    """
    # average CPU utilization over all cores
    cpu_usage = psutil.cpu_percent(percpu=True)
    stats = {"cpu_usage": sum(cpu_usage) / len(cpu_usage)}

    virtual_memory = psutil.virtual_memory()
    stats["memory_used"] = BytesSizes.megabytes(virtual_memory.used)
    stats["memory_free"] = BytesSizes.megabytes(virtual_memory.available)
    # free-space percentage on the volume holding the home directory
    disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
    stats["disk_free_percent"] = 100 - disk_use_percentage
    # sensors_temperatures() is not available on all platforms
    sensor_stat = (
        psutil.sensors_temperatures()
        if hasattr(psutil, "sensors_temperatures")
        else {}
    )
    if "coretemp" in sensor_stat and len(sensor_stat["coretemp"]):
        stats["cpu_temperature"] = max([t.current for t in sensor_stat["coretemp"]])

    # update cached measurements
    # NOTE: network/io values are cumulative counters; _update_readouts()
    # converts the "_mbs" keys into per-second rates using the previous sample
    net_stats = psutil.net_io_counters()
    stats["network_tx_mbs"] = BytesSizes.megabytes(net_stats.bytes_sent)
    stats["network_rx_mbs"] = BytesSizes.megabytes(net_stats.bytes_recv)
    io_stats = psutil.disk_io_counters()
    stats["io_read_mbs"] = BytesSizes.megabytes(io_stats.read_bytes)
    stats["io_write_mbs"] = BytesSizes.megabytes(io_stats.write_bytes)

    # check if we can access the gpu statistics
    if self._gpustat:
        try:
            gpu_stat = self._gpustat.new_query()
            for i, g in enumerate(gpu_stat.gpus):
                # only monitor the active gpu's, if none were selected, monitor everything
                if self._active_gpus and i not in self._active_gpus:
                    continue
                stats["gpu_temperature_{:d}".format(i)] = g["temperature.gpu"]
                stats["gpu_utilization_{:d}".format(i)] = g["utilization.gpu"]
                stats["gpu_mem_usage_{:d}".format(i)] = (
                    100.0 * g["memory.used"] / g["memory.total"]
                )
                # already in MBs
                stats["gpu_mem_free_{:d}".format(i)] = (
                    g["memory.total"] - g["memory.used"]
                )
                stats["gpu_mem_used_%d" % i] = g["memory.used"]
        except Exception as ex:
            # something happened and we can't use gpu stats,
            # count the failure; repeated failures switch GPU monitoring off
            log.warning("failed getting machine stats: %s", report_error(ex))
            self._failure()

    return stats
|
||||
|
||||
def _failure(self):
    """Record one GPU-stats failure; disable GPU monitoring after 3 failures."""
    self._gpustat_fail += 1
    if self._gpustat_fail < 3:
        return
    log.error(
        "GPU monitoring failed getting GPU reading, switching off GPU monitoring"
    )
    self._gpustat = None
|
||||
|
||||
# Mapping of local readout keys to the backend's stat field names.
# Keys ending in "_*" match per-device readouts (e.g. "gpu_utilization_0");
# their averaged values are collected into one list per backend key.
BACKEND_STAT_MAP = {"cpu_usage_*": "cpu_usage",
                    "cpu_temperature_*": "cpu_temperature",
                    "disk_free_percent": "disk_free_home",
                    "io_read_mbs": "disk_read",
                    "io_write_mbs": "disk_write",
                    "network_tx_mbs": "network_tx",
                    "network_rx_mbs": "network_rx",
                    "memory_free": "memory_free",
                    "memory_used": "memory_used",
                    "gpu_temperature_*": "gpu_temperature",
                    "gpu_mem_used_*": "gpu_memory_used",
                    "gpu_mem_free_*": "gpu_memory_free",
                    "gpu_utilization_*": "gpu_usage"}
|
||||
|
||||
|
||||
def report_error(ex):
    """Format an exception as "ExceptionType: message" for log output."""
    return "%s: %s" % (type(ex).__name__, ex)
|
||||
119
trains_agent/helper/singleton.py
Normal file
119
trains_agent/helper/singleton.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
import psutil
|
||||
from time import sleep
|
||||
from glob import glob
|
||||
from tempfile import gettempdir, NamedTemporaryFile
|
||||
|
||||
from trains_agent.helper.base import warning
|
||||
|
||||
|
||||
class Singleton(object):
    """
    Process-level singleton bookkeeping for agent workers.

    Each running agent keeps a pid file in the system temp folder
    (``trainsagent_<pid>_*.tmp``) containing its unique worker id and the
    instance "slot" it occupies. Registration is serialized between processes
    with a best-effort global lock file.
    """

    prefix = 'trainsagent'
    sep = '_'
    ext = '.tmp'
    worker_id = None
    worker_name_sep = ':'
    instance_slot = None
    # NamedTemporaryFile kept referenced so our registration file stays alive
    _pid_file = None
    _lock_file_name = sep + prefix + sep + 'global.lock'
    # seconds to wait for the global lock before breaking it
    _lock_timeout = 10

    @classmethod
    def register_instance(cls, unique_worker_id=None, worker_name=None):
        """
        # Exit the process if another instance of us is using the same worker_id

        :param unique_worker_id: if already exists, return negative
        :param worker_name: slot number will be added to worker name, based on the available instance slot
        :return: (str worker_id, int slot_number) Return None value on instance already running
        """
        # try to lock file - serialize registration across processes
        lock_file = os.path.join(gettempdir(), cls._lock_file_name)
        timeout = 0
        while os.path.exists(lock_file):
            if timeout > cls._lock_timeout:
                # the lock holder probably died - break the lock
                warning('lock file timed out {}sec - clearing lock'.format(cls._lock_timeout))
                try:
                    os.remove(lock_file)
                except Exception:
                    pass
                break

            sleep(1)
            timeout += 1

        with open(lock_file, 'wb') as f:
            # write our pid for diagnostics; bytes(int) would create a
            # zero-filled buffer on Python 3, so encode the string form
            f.write(str(os.getpid()).encode())
            f.flush()
        try:
            ret = cls._register_instance(unique_worker_id=unique_worker_id, worker_name=worker_name)
        except Exception:
            ret = None, None

        try:
            os.remove(lock_file)
        except Exception:
            pass

        return ret

    @classmethod
    def _register_instance(cls, unique_worker_id=None, worker_name=None):
        # already registered in this process - reuse the same registration
        if cls.worker_id:
            return cls.worker_id, cls.instance_slot
        # make sure we have a unique name
        instance_num = 0
        temp_folder = gettempdir()
        files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
        slots = {}
        for file in files:
            # parse the pid out of the file name; split the basename so that
            # separators inside the temp folder path do not shift the fields
            parts = os.path.basename(file).split(cls.sep)
            try:
                pid = int(parts[1])
            except Exception:
                # something is wrong, use non existing pid and delete the file
                pid = -1
            # count active instances and delete dead files
            if not psutil.pid_exists(pid):
                # delete the file
                try:
                    os.remove(os.path.join(file))
                except Exception:
                    pass
                continue

            instance_num += 1
            try:
                with open(file, 'r') as f:
                    uid, slot = str(f.read()).split('\n')
                    slot = int(slot)
            except Exception:
                continue

            if uid == unique_worker_id:
                # another live process already owns this worker id
                return None, None

            slots[slot] = uid

        # get a new slot: the minimal slot number not currently occupied
        if not slots:
            cls.instance_slot = 0
        else:
            for i in range(max(slots.keys()) + 2):
                if i not in slots:
                    cls.instance_slot = i
                    break

        # build worker id based on slot
        if not unique_worker_id:
            unique_worker_id = worker_name + cls.worker_name_sep + str(cls.instance_slot)

        # create our registration file; keeping the reference prevents deletion
        cls._pid_file = NamedTemporaryFile(dir=gettempdir(), prefix=cls.prefix + cls.sep + str(os.getpid()) + cls.sep,
                                           suffix=cls.ext)
        cls._pid_file.write(('{}\n{}'.format(unique_worker_id, cls.instance_slot)).encode())
        cls._pid_file.flush()
        cls.worker_id = unique_worker_id

        return cls.worker_id, cls.instance_slot
|
||||
144
trains_agent/helper/trace.py
Normal file
144
trains_agent/helper/trace.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import unicode_literals, print_function, absolute_import
|
||||
|
||||
import linecache
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import trace
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Text, Sequence, Union
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
|
||||
try:
|
||||
from functools import lru_cache
|
||||
except ImportError:
|
||||
from functools32 import lru_cache
|
||||
|
||||
|
||||
def inclusive_parents(path):
    """
    Yield path itself, then each of its parents up to the root.
    """
    yield path
    for parent in path.parents:
        yield parent
|
||||
|
||||
|
||||
def get_module_path(module):
    """
    Resolve a module (object or dotted name) to its filesystem path.

    :param module: Module object or name
    :return: module path (the package directory for packages, the file otherwise)
    """
    if isinstance(module, six.string_types):
        module = sys.modules[module]
    module_file = Path(module.__file__)
    if module_file.stem == '__init__':
        # a package: return its directory rather than __init__.py
        return module_file.parent
    return module_file
|
||||
|
||||
|
||||
# A module may be referenced either by object or by its (dotted) name.
Module = Union[ModuleType, Text]
|
||||
|
||||
|
||||
class PackageTraceIgnore(object):

    """
    Object that includes package modules in trace and excludes sub modules and all other code.
    """

    def __init__(self, package, ignore_submodules):
        # type: (Module, Sequence[Module]) -> None
        """
        Modules given by name will be searched for in sys.modules, enabling use of "__name__".
        :param package: Package to include modules of
        :param ignore_submodules: sub modules of package to ignore
        """
        self.ignore_submodules = tuple(map(get_module_path, ignore_submodules))
        self.package = package
        self.package_path = get_module_path(package)

    # NOTE(review): lru_cache on an instance method also keys on (and keeps a
    # reference to) "self"; acceptable if these objects live for the whole
    # process - confirm that assumption holds.
    @lru_cache(None)
    def names(self, file_name, module_name=None):
        # type: (Text, Text) -> bool
        """
        Return whether a file should be ignored based on it's path and module name.
        Ignore files which are not part of self.package.
        trace.Ignore's documentation states that module_name is unreliable for packages,
        therefore, it is not used here.

        :param file_name: source file path
        :param module_name: module name
        :return: whether file should be ignored
        """
        file_path = Path(file_name).resolve()
        include = self.include(file_path)
        return not include

    def include(self, base):
        # type: (Path) -> bool
        # Walk from "base" up through its parents: a path under an ignored
        # submodule is excluded; a path under the package root is included;
        # anything else (outside the package) is excluded.
        for path in inclusive_parents(base):
            if not path.exists():
                continue
            if any(path.samefile(sub) for sub in self.ignore_submodules):
                return False
            if path.samefile(self.package_path):
                return True
        return False
|
||||
|
||||
|
||||
class PackageTrace(trace.Trace, object):

    """
    Trace object for tracing only lines from a specific package.
    Some functions are copied and modified for lack of modularity of ``trace.Trace``.
    """

    def __init__(self, package, out_file, ignore_submodules=(), *args, **kwargs):
        # type: (Module, ..., Sequence[Module]) -> None
        """
        :param package: package whose modules should be traced
        :param out_file: file object trace output is printed to
        :param ignore_submodules: sub modules of package to exclude from tracing
        """
        super(PackageTrace, self).__init__(*args, **kwargs)
        self.ignore = PackageTraceIgnore(package, ignore_submodules)
        self.__out_file = out_file

    def __out(self, *args, **kwargs):
        # all trace output is routed to the file supplied at construction
        print(*args, file=self.__out_file, **kwargs)

    def globaltrace_lt(self, frame, why, arg):
        """
        ## Copied from trace module ##
        Handler for call events.
        If the code block being entered is to be ignored, returns `None',
        else returns self.localtrace.
        """
        if why == 'call':
            code = frame.f_code
            filename = frame.f_globals.get('__file__', None)
            if filename:
                # XXX modname() doesn't work right for packages, so
                # the ignore support won't work right for packages
                ignore_it = self.ignore.names(filename)
                if not ignore_it:
                    if self.trace:
                        filename = Path(filename)
                        # module name is derived relative to the traced package root
                        modulename = '.'.join(
                            filename.relative_to(self.ignore.package_path).parts[:-1] + (filename.stem,)
                        )
                        self.__out(' --- modulename: %s, funcname: %s' % (modulename, code.co_name))
                    return self.localtrace
            else:
                return None

    def localtrace_trace(self, frame, why, arg):
        """
        ## Copied from trace module ##
        """
        if why == "line":
            # record the file name and line number of every trace
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno

            if self.start_time:
                self.__out('%.2f' % (time.time() - self.start_time), end='')
            bname = os.path.basename(filename)
            self.__out('%s(%d): %s' % (bname, lineno, linecache.getline(filename, lineno)), end='')
        return self.localtrace

    # line counting is not kept here, so the counting variant is identical
    localtrace_trace_and_count = localtrace_trace
|
||||
31
trains_agent/interface/__init__.py
Normal file
31
trains_agent/interface/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from functools import partial
|
||||
from importlib import import_module
|
||||
import argparse
|
||||
|
||||
from trains_agent.definitions import PROGRAM_NAME
|
||||
from .base import Parser, base_arguments, add_service, OnlyPluralChoicesHelpFormatter
|
||||
|
||||
SERVICES = [
|
||||
'worker',
|
||||
]
|
||||
|
||||
|
||||
def get_parser():
    """Build the top-level trains-agent argument parser with all worker sub-commands."""
    top_parser = Parser(
        prog=PROGRAM_NAME,
        add_help=False,
        formatter_class=partial(
            OnlyPluralChoicesHelpFormatter,
            max_help_position=120,
            width=120,
        ),
    )
    base_arguments(top_parser)
    # imported here rather than at module level, presumably to avoid an
    # import cycle at load time
    from .worker import COMMANDS
    subparsers = top_parser.add_subparsers(dest='command')
    for command_name, spec in COMMANDS.items():
        command_parser = subparsers.add_parser(name=command_name, help=spec['help'])
        for arg_name, arg_params in spec.get('args', {}).items():
            command_parser.add_argument(arg_name, **arg_params)

    return top_parser
|
||||
424
trains_agent/interface/base.py
Normal file
424
trains_agent/interface/base.py
Normal file
@@ -0,0 +1,424 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import argparse
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains_agent import definitions
|
||||
from trains_agent.session import Session
|
||||
|
||||
HEADER = 'TRAINS-AGENT Deep Learning DevOps'
|
||||
|
||||
|
||||
class Parser(argparse.ArgumentParser):
    # ArgumentParser variant used throughout trains-agent: reads arguments
    # from files (fromfile prefix), supports a default sub-parser inserted
    # when none is given, and optionally prints full help on "too few
    # arguments".

    __default_subparser = None

    def __init__(self, usage_on_error=True, *args, **kwargs):
        super(Parser, self).__init__(fromfile_prefix_chars=definitions.FROM_FILE_PREFIX_CHARS, *args, **kwargs)
        self._usage_on_error = usage_on_error

    @property
    def choices(self):
        # Mapping of sub-command name -> sub-parser; empty dict if this
        # parser has no subparsers action.
        try:
            subparser = next(
                action for action in self._actions
                if isinstance(action, argparse._SubParsersAction))
        except StopIteration:
            return {}
        return subparser.choices

    def error(self, message):
        # Show full help (not just usage) when the error is a missing
        # sub-command; otherwise defer to the default behavior.
        if self._usage_on_error and message == argparse._('too few arguments'):
            self.print_help()
            print()
            self.exit(2, argparse._('%s: error: %s\n') % (self.prog, message))
        super(Parser, self).error(message)

    def __getitem__(self, name):
        # parser['worker'] -> the 'worker' sub-parser
        return self.choices[name]

    def remove_top_level_results(self, parse_results):
        """
        Remove useless, artifact values
        :param parse_results: resulting namespace of parse_args, converted to dict ( vars(args) )
        """
        for action in self._actions:
            if action.dest != 'version':
                parse_results.pop(action.dest, None)
        for key in ('func', 'command', 'subcommand', 'action'):
            parse_results.pop(key, None)

    def set_default_subparser(self, name):
        # Sub-command to insert when the command line names none.
        self.__default_subparser = name

    def get_default_subparser(self):
        return self.choices[self.__default_subparser]

    def _parse_known_args(self, arg_strings, *args, **kwargs):
        # NOTE(review): relies on argparse internals (_subparsers._actions,
        # _SubParsersAction, _name_parser_map) - verify when the Python
        # version changes.
        in_args = set(arg_strings)
        d_sp = self.__default_subparser
        if d_sp is not None and not {'-h', '--help'}.intersection(in_args):
            for x in self._subparsers._actions:
                subparser_found = (
                    isinstance(x, argparse._SubParsersAction) and
                    in_args.intersection(x._name_parser_map.keys())
                )
                if subparser_found:
                    break
            else:
                # insert default in first position, this implies no
                # global options without a sub_parsers specified
                arg_strings = [d_sp] + arg_strings
        return super(Parser, self)._parse_known_args(
            arg_strings, *args, **kwargs
        )
|
||||
|
||||
|
||||
class AliasedPseudoAction(argparse.Action):
    """
    Placeholder action used only to render a sub-command (with its aliases)
    in help output.
    """
    def __init__(self, name, aliases, help):
        extra = [alias for alias in aliases if alias != name]
        display = '{} ({})'.format(name, ','.join(extra)) if extra else name
        super(AliasedPseudoAction, self).__init__(option_strings=[], dest=display, help=help)
|
||||
|
||||
|
||||
class AliasedSubParsersAction(argparse._SubParsersAction):
    """
    Action for adding aliases for sub-commands
    """

    def add_parser(self, name, **kwargs):
        """
        Add a sub-parser, additionally registering any 'aliases' to resolve
        to the same parser and reflecting them in the help listing.

        :param name: canonical sub-command name
        :return: the created sub-parser
        """
        aliases = kwargs.pop('aliases', [])
        parser = super(AliasedSubParsersAction, self).add_parser(name, **kwargs)

        # Make the aliases work
        for alias in aliases:
            self._name_parser_map[alias] = parser
        # Make the help text reflect them, first removing old help entry.
        help = kwargs.pop('help', None)
        if help:
            self._choices_actions.pop()
            pseudo_action = AliasedPseudoAction(name, aliases, help)
            self._choices_actions.append(pseudo_action)

        return parser
|
||||
|
||||
|
||||
class OnlyPluralChoicesHelpFormatter(argparse.HelpFormatter):
    """
    Help formatter that hides a choice from the metavar when its plural form
    ("<choice>s") is also one of the choices.
    """

    @staticmethod
    def _metavar_formatter(action, default_metavar):
        if action.metavar is not None:
            result = action.metavar
        elif action.choices is not None:
            names = [str(choice) for choice in action.choices]
            # drop any choice whose plural is also present
            kept = [name for name in names if name + 's' not in names]
            result = '{%s}' % ','.join(kept)
        else:
            result = default_metavar

        def format(tuple_size):
            return result if isinstance(result, tuple) else (result,) * tuple_size

        return format
|
||||
|
||||
|
||||
def hyphenate(s):
    """Convert a snake_case identifier to its hyphenated CLI form."""
    return '-'.join(s.split('_'))
|
||||
|
||||
|
||||
def add_args(parser, args):
    """
    Add arguments to parser from args mapping
    :param parser: parser to add arguments to
    :type parser: argparse.ArgumentParser
    :param args: mapping of name -> other arguments to ArgumentParser.add_argument
    :type args: dict
    """
    for name, params in args.items():
        # 'aliases' (if present) become additional option strings
        extra_names = params.pop('aliases', tuple())
        parser.add_argument(name, *extra_names, **params)
|
||||
|
||||
|
||||
def add_mutually_exclusive_groups(parser, groups):
    """
    Add mutually exclusive groups to parser from list
    :param parser: parser to add groups to
    :param groups: list of dictionaries, each containing:
        1. 'args': parameter to add_args
        2. arguments to ArgumentParser.add_mutually_exclusive_group
    """
    for spec in groups:
        group_args = spec.pop('args', {})
        add_args(parser.add_mutually_exclusive_group(**spec), group_args)
|
||||
|
||||
|
||||
def add_service(subparsers, name, commands, command_name_dest='command', formatter_class=argparse.RawDescriptionHelpFormatter, **kwargs):
    """
    Add service commands to parser from arguments dictionary
    :param subparsers: subparsers object of ArgumentParser
    :param name: name of service
    :param commands: mapping of names to dictionaries, each of them containing:
        1. 'args' - mapping of name -> other arguments to ArgumentParser.add_argument
        2. 'help' - command description
        3. 'mutually_exclusive_groups' - see add_mutually_exclusive_groups
    :param command_name_dest: name of attribute in which to store selected sub-command
    :param formatter_class: help formatter class
    :param kwargs: any other arguments to add_parser method of subparser object
    :return: service subparser
    """
    # deep-copied because command dicts are mutated (pop) below
    commands = deepcopy(commands)
    service_parser = subparsers.add_parser(
        name,
        # aliases=(name.strip('s'),),
        formatter_class=formatter_class,
        **kwargs
    )
    service_parser.register('action', 'parsers', AliasedSubParsersAction)
    service_parser.set_defaults(**{command_name_dest: name})
    service_subparsers = service_parser.add_subparsers(
        title='{} commands'.format(name.capitalize()),
        parser_class=partial(Parser, usage_on_error=False),
        dest='action')

    # This is a fix for a bug in python3's argparse: running "trains-agent some_service" fails
    service_subparsers.required = True

    # NOTE(review): the loop variable shadows the "name" parameter; harmless
    # here since the parameter is not used afterwards, but worth renaming.
    for name, subparser in commands.pop('subparsers', {}).items():
        add_service(service_subparsers, name, command_name_dest='subcommand', **subparser)

    for command_name, command in commands.items():
        command_type = command.pop('type', None)
        mutually_exclusive_groups = command.pop('mutually_exclusive_groups', [])
        # by default the handler name is the command name itself
        func = command.pop('func', command_name)
        args = command.pop('args', {})
        command_parser = service_subparsers.add_parser(hyphenate(command_name), **command)
        if command_type:
            command_type.make(command_parser)
        command_parser.set_defaults(func=func)
        add_mutually_exclusive_groups(command_parser, mutually_exclusive_groups)
        add_args(command_parser, args)

    return service_parser
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class CommandType(object):
    """Deferred parser configuration: stores arguments now, applies them via make()."""

    def __init__(self, *args, **kwargs):
        # keep the configuration arguments until make() is called
        self._args = args
        self._kwargs = kwargs

    def make(self, parser):
        """Apply this command type's stored configuration to *parser*."""
        return self._make(parser, *self._args, **self._kwargs)

    @abc.abstractmethod
    def _make(self, *args, **kwargs):
        """Subclass hook: actually configure the parser."""
        pass
|
||||
|
||||
|
||||
class ListCommand(CommandType):
    # Adds the standard "list" output options to a parser: table/csv
    # selection, sorting direction, and (optionally) pagination and a
    # tree view.

    @staticmethod
    def _make(parser, default_value, pagination=False, tree=False):
        """
        :param parser: parser to configure
        :param default_value: default --table column specification
        :param pagination: whether to add --page/--page-size/--no-pagination
        :param tree: whether to add the --tree output option
        """
        if tree:
            table_group = parser.add_mutually_exclusive_group()
            tree_action = table_group.add_argument(
                '--tree', help='Tree view output', action='store_true', default=False)
            csv_group = parser.add_mutually_exclusive_group()
            # hack: tree cannot be used with either csv or table, which can be used
            # which each other
            csv_group._group_actions.append(tree_action)

        else:
            csv_group = table_group = parser

        table_group.add_argument(
            '--table',
            help='Select table columns ("#" separated, default: %(default)s)',
            default=default_value)
        csv_group.add_argument(
            '--csv',
            help='Generate CSV output to specified path',
            default=None)
        parser.add_argument(
            '--no-headers',
            action='store_false',
            dest='headers',
            help='Do not print table/csv headers')
        parser.add_argument(
            '--sort',
            help='Fields to sort by (same format as --table)')
        # --ascending/--descending share "sort_reverse"; default None means
        # "no explicit direction requested"
        parser.add_argument(
            '--ascending',
            default=None,
            dest='sort_reverse',
            action='store_false',
            help='Sort in ascending order (default)')
        parser.add_argument(
            '--descending',
            default=None,
            dest='sort_reverse',
            action='store_true',
            help='Sort in descending order')

        if not pagination:
            return

        parser.add_argument(
            '--page',
            help='Page number to show (default: %(default)s)',
            default=0,
            type=bound_number_type(minimum=0))
        parser.add_argument(
            '--page-size',
            help='Size of page (default: %(default)s)',
            default=50,
            type=bound_number_type(minimum=1))
        parser.add_argument(
            '--no-pagination',
            action='store_false',
            dest='pagination',
            help='Disable pagination (return all results)')
|
||||
|
||||
|
||||
class _HelpAction(argparse._HelpAction):
    """'-h' action: print the program header, then the short help, then exit."""

    def __call__(self, parser, namespace, values, option_string=None):
        print(HEADER + '\n')
        parser.print_help()
        print('')
        parser.exit()
|
||||
|
||||
|
||||
class _DetailedHelpAction(argparse._HelpAction):
    # '--help' action: print the header, the top-level help, and then the
    # expanded option listing of every sub-command.

    def __call__(self, parser, namespace, values, option_string=None):
        # print header
        print(HEADER + '\n')

        parser.print_help()
        print('\n')
        # retrieve subparsers from parser
        subparsers_actions = [
            action for action in parser._actions
            if isinstance(action, argparse._SubParsersAction)
        ]
        # iterate and print help for each suparser
        for subparsers_action in subparsers_actions:
            # get all subparsers and print help
            for choice, subparsercmd in subparsers_action.choices.items():
                # split help into lines so we can skip the header
                text = subparsercmd.format_help().split('\n')
                # find first line of command (skip usage and header)
                # NOTE(review): if no line matches, "i" keeps its final loop
                # value and the last line is printed - confirm this is intended
                for i, t in enumerate(text):
                    if t.startswith(choice.title()):
                        break
                # print help command prefix
                print(text[i])

                # print help per sub-commands, we actually assume only one
                subact = [
                    action for action in subparsercmd._actions
                    if isinstance(action, argparse._SubParsersAction)
                ]
                # per action print all parameters
                for j, t in enumerate(text[i + 2:]):
                    print(t)
                    k = t.split()
                    if not k:
                        continue
                    try:
                        subc = subact[0].choices[k[0]]
                    except KeyError:
                        continue
                    # hack so we can control formatting in one place
                    # otherwise we need to update all
                    # the parsers when we create them
                    subc.formatter_class = lambda prog: argparse.HelpFormatter(prog, width=120)
                    subchelp = subc.format_help().split('\n')
                    # skip until we reach "optional arguments:"
                    for si, st in enumerate(subchelp):
                        if st.startswith('optional arguments:'):
                            break
                    # print help with single tab indent
                    for st in subchelp[si + 1:]:
                        print('\t %s' % st)
        parser.exit()
|
||||
|
||||
|
||||
def base_arguments(top_parser):
    # Register the global command-line options on the top-level parser:
    # -h / --help / --version / --config-file / --debug / --trace.
    top_parser.register('action', 'parsers', AliasedSubParsersAction)
    top_parser.add_argument('-h', action=_HelpAction, help='Displays summary of all commands')
    top_parser.add_argument(
        '--help',
        action=_DetailedHelpAction,
        help='Detailed help of command line interface')
    top_parser.add_argument(
        '--version',
        action='version',
        version='TRAINS-AGENT version %s' % Session.version,
        help='TRAINS-AGENT version number')
    top_parser.add_argument(
        '--config-file',
        help='Use a different configuration file (default: "{}")'.format(definitions.CONFIG_FILE))
    top_parser.add_argument('--debug', '-d', action='store_true', help='print debug information')
    top_parser.add_argument(
        '--trace', '-t',
        action='store_true',
        help='Trace execution to a temporary file.',
    )
|
||||
|
||||
|
||||
def bound_number_type(minimum=None, maximum=None):
    """
    bound_number_type

    Creates a bounded integer "type" (validator function)
    for use with argparse.ArgumentParser.add_argument.
    At least one of ``minimum`` and ``maximum`` must be passed.

    :param minimum: minimum allowed value
    :param maximum: maximum allowed value
    :raises ValueError: if neither bound is provided
    """
    if minimum is maximum is None:
        raise ValueError('either "minimum" or "maximum" must be provided')

    def bound_int(arg):
        # argparse passes the raw command-line string
        num = int(arg)
        if minimum is not None and num < minimum:
            raise argparse.ArgumentTypeError('minimum value is {}'.format(minimum))
        if maximum is not None and num > maximum:
            # fixed: this message previously reported the minimum bound
            raise argparse.ArgumentTypeError('maximum value is {}'.format(maximum))
        return num

    return bound_int
|
||||
|
||||
|
||||
def real_path_type(string):
    """argparse type: expand '~' and validate that the path exists."""
    candidate = Path(string).expanduser()
    if candidate.exists():
        return candidate
    raise argparse.ArgumentTypeError('"{}": No such file or directory'.format(candidate))
|
||||
|
||||
|
||||
class ObjectID(object):
    """A backend object reference: an entity name plus the service that owns it."""

    def __init__(self, name, service=None):
        self.name, self.service = name, service
|
||||
|
||||
|
||||
def foreign_object_id(service):
    # argparse "type" factory: parsed values become ObjectID references
    # owned by the given backend service (e.g. 'tasks', 'queues').
    return partial(ObjectID, service=service)
|
||||
111
trains_agent/interface/worker.py
Normal file
111
trains_agent/interface/worker.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import argparse
|
||||
from textwrap import dedent
|
||||
|
||||
from trains_agent.helper.base import warning, is_windows_platform
|
||||
from trains_agent.interface.base import foreign_object_id
|
||||
|
||||
|
||||
class DeprecatedFlag(argparse.Action):
    """argparse action for deprecated flags: warns and otherwise ignores the value."""

    def __call__(self, parser, namespace, values, option_string=None):
        warning('argument "{}" is deprecated'.format(option_string))
|
||||
|
||||
|
||||
# Arguments shared by every worker sub-command (execute/build/daemon);
# mapping of option string -> kwargs for ArgumentParser.add_argument.
WORKER_ARGS = {
    '-O': {
        'help': 'Compile optimized pyc code (see python documentation). Repeat for more optimization.',
        'action': 'count',
        'default': 0,
        'dest': 'optimization',
    },
    '--git-user': {
        'help': 'git username for repository access',
    },
    '--git-pass': {
        'help': 'git password for repository access',
    },
    '--log-level': {
        'help': 'SDK log level',
        'choices': ['DEBUG', 'INFO', 'WARN', 'WARNING', 'ERROR', 'CRITICAL'],
        'type': lambda x: x.upper(),
        'default': 'INFO',
    },
}
|
||||
|
||||
# Arguments specific to the "daemon" command, merged with WORKER_ARGS.
DAEMON_ARGS = dict({
    '--foreground': {
        'help': 'Pipe full log to stdout/stderr, should not be used if running in background',
        'action': 'store_true',
    },
    '--docker': {
        'help': 'Run execution task inside a docker (v19.03 and above). Optional args <image> <arguments> or '
                'specify default docker image in agent.default_docker.image / agent.default_docker.arguments'
                'set NVIDIA_VISIBLE_DEVICES to limit gpu visibility for docker',
        'nargs': '*',
        'default': False,
    },
    '--queue': {
        'help': 'Queue ID(s)/Name(s) to pull tasks from (\'default\' queue)',
        'nargs': '+',
        'default': tuple(),
        'dest': 'queues',
        'type': foreign_object_id('queues'),
    },
}, **WORKER_ARGS)
|
||||
|
||||
# Top-level worker sub-commands; each entry provides the 'help' text and an
# optional 'args' mapping consumed by get_parser() / add_args().
COMMANDS = {
    'execute': {
        'help': 'Build & Execute a selected experiment',
        'args': dict({
            '--id': {
                'help': 'Task ID to run',
                'required': True,
                'dest': 'task_id',
                'type': foreign_object_id('tasks'),
            },
            '--log-file': {
                'help': 'Output task execution (stdout/stderr) into text file',
            },
            '--disable-monitoring': {
                'help': 'Disable logging & monitoring (stdout is still visible)',
                'action': 'store_true',
            },
            '--full-monitoring': {
                'help': 'Full environment setup log & task logging & monitoring (stdout is still visible)',
                'action': 'store_true',
            },
        }, **WORKER_ARGS),
    },
    'build': {
        'help': 'Build selected experiment environment '
                '(including pip packages, cloned code and git diff)\n'
                'Used mostly for debugging purposes',
        'args': dict({
            '--id': {
                'help': 'Task ID to build',
                'required': True,
                'dest': 'task_id',
                'type': foreign_object_id('tasks'),
            },
            '--target-folder': {
                'help': 'Where to build the task\'s virtual environment and source code',
            },
            '--python-version': {
                'help': 'Virtual environment python version to use',
            },
        }, **WORKER_ARGS),
    },
    'list': {
        'help': 'List all worker machines and status',
    },
    'daemon': {
        'help': 'Start Trains-Agent daemon worker',
        'args': DAEMON_ARGS,
    },
    'config': {
        'help': 'Check daemon configuration and print it',
    },
    'init': {
        'help': 'Trains-Agent configuration wizard',
    }
}
|
||||
314
trains_agent/session.py
Normal file
314
trains_agent/session.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable
|
||||
|
||||
import attr
|
||||
from pathlib2 import Path
|
||||
from pyhocon import ConfigFactory, HOCONConverter, ConfigTree
|
||||
|
||||
from trains_agent.backend_api.session import Session as _Session, Request
|
||||
from trains_agent.backend_api.session.client import APIClient
|
||||
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR, LOCAL_CONFIG_FILES
|
||||
from trains_agent.definitions import ENVIRONMENT_CONFIG
|
||||
from trains_agent.errors import APIError
|
||||
from trains_agent.helper.base import HOCONEncoder
|
||||
from trains_agent.helper.process import Argv
|
||||
from .version import __version__
|
||||
|
||||
# Marker string for the poetry package manager (per the shipped configuration,
# poetry is selected when pip is configured and the repository has a poetry.lock).
POETRY = "poetry"
|
||||
|
||||
|
||||
@attr.s
class ConfigValue(object):

    """
    Accessor for a single key inside a ConfigTree: supports reading,
    writing, and in-place transformation of the value.
    """

    config = attr.ib(type=ConfigTree)
    key = attr.ib(type=str)

    def get(self, default=None):
        """
        Return the current value of the key, or ``default`` when unset
        """
        return self.config.get(self.key, default=default)

    def set(self, value):
        """
        Store ``value`` under the key
        """
        self.config.put(self.key, value)

    def modify(self, fn):
        # type: (Callable[[Any], Any]) -> ()
        """
        Replace the current value with ``fn(current_value)``
        """
        current = self.get()
        self.set(fn(current))
|
||||
|
||||
|
||||
def tree(*args):
    """
    Build a ConfigTree from the given positional (key, value) pairs
    """
    pairs = args
    return ConfigTree(pairs)
|
||||
|
||||
|
||||
class Session(_Session):
    """
    Agent-side API session.

    On construction it resolves the configuration file, applies environment
    variable overrides, normalizes CUDA device-visibility variables, fills in
    defaults (python version, worker name) and prepares local cache folders.
    Note: __init__ mutates os.environ as a side effect.
    """

    # agent version reported alongside requests
    version = __version__

    def __init__(self, *args, **kwargs):
        # make sure we set the environment variable so the parent session opens the correct file
        if kwargs.get('config_file'):
            config_file = Path(os.path.expandvars(kwargs.get('config_file'))).expanduser().absolute().as_posix()
            kwargs['config_file'] = config_file
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file
            if not Path(config_file).is_file():
                raise ValueError("Could not open configuration file: {}".format(config_file))
        super(Session, self).__init__(*args, **kwargs)
        self.log = self.get_logger(__name__)
        self.trace = kwargs.get('trace', False)
        # fall back to the override env var, then the first default location
        self._config_file = kwargs.get('config_file') or \
            os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR) or LOCAL_CONFIG_FILES[0]
        # "2.4" is presumably the minimum server API version the agent needs -- TODO confirm
        self.api_client = APIClient(session=self, api_version="2.4")
        # HACK make sure we have python version to execute,
        # if nothing was specific, use the one that runs us
        def_python = ConfigValue(self.config, "agent.default_python")
        if not def_python.get():
            def_python.set("{version.major}.{version.minor}".format(version=sys.version_info))

        # HACK: backwards compatibility
        os.environ['ALG_CONFIG_FILE'] = self._config_file
        os.environ['SM_CONFIG_FILE'] = self._config_file
        if not self.config.get('api.host', None) and self.config.get('api.api_server', None):
            self.config['api']['host'] = self.config.get('api.api_server')

        # initialize nvidia visibility variable
        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        # keep CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES mirrored:
        # whichever one is set alone is copied onto the other
        if os.environ.get('NVIDIA_VISIBLE_DEVICES') and not os.environ.get('CUDA_VISIBLE_DEVICES'):
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
        elif os.environ.get('CUDA_VISIBLE_DEVICES') and not os.environ.get('NVIDIA_VISIBLE_DEVICES'):
            os.environ['NVIDIA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES')

        # override with environment variables
        # cuda_version & cudnn_version are overridden with os.environ here, and normalized in the next section
        for config_key, env_config in ENVIRONMENT_CONFIG.items():
            value = env_config.get()
            if not value:
                continue
            env_key = ConfigValue(self.config, config_key)
            env_key.set(value)

        # initialize cuda versions
        try:
            # local import: presumably avoids a circular import at module load -- TODO confirm
            from trains_agent.helper.package.requirements import RequirementsManager
            agent = self.config['agent']
            agent['cuda_version'], agent['cudnn_version'] = \
                RequirementsManager.get_cuda_version(self.config)
        except Exception:
            # best-effort: leave cuda/cudnn versions unset on any detection failure
            pass

        # initialize worker name
        worker_name = ConfigValue(self.config, "agent.worker_name")
        if not worker_name.get():
            worker_name.set(platform.node())

        self.create_cache_folders()

    @staticmethod
    def get_logger(name):
        # type: (str) -> TrainsAgentLogger
        """Return a TrainsAgentLogger wrapping the stdlib logger for ``name``."""
        logger = logging.getLogger(name)
        logger.propagate = True
        return TrainsAgentLogger(logger)

    @property
    def debug_mode(self):
        # True when the "agent.debug" config flag is set
        return self.config.get("agent.debug", False)

    @property
    def config_file(self):
        # path of the configuration file this session was loaded from
        return self._config_file

    def create_cache_folders(self, slot_index=0):
        """
        create and update the cache folders
        notice we support multiple instances sharing the same cache on some folders
        and on some we use "instance slot" numbers in order to differentiate between the different instances running
        notice slot_index=0 is the default, meaning no suffix is added to the singleton_folders

        Note: do not call this function twice with non zero slot_index
        it will add a suffix to the folders on each call

        :param slot_index: integer
        """

        # create target folders:
        folder_keys = ('agent.venvs_dir', 'agent.vcs_cache.path',
                       'agent.pip_download_cache.path',
                       'agent.docker_pip_cache', 'agent.docker_apt_cache')
        # folders that must be unique per running instance (suffixed by slot)
        singleton_folders = ('agent.venvs_dir', 'agent.vcs_cache.path',)

        for key in folder_keys:
            folder_key = ConfigValue(self.config, key)
            if not folder_key.get():
                continue

            if slot_index and key in singleton_folders:
                f = folder_key.get()
                if f.endswith(os.path.sep):
                    f = f[:-1]
                # e.g. "~/venvs" -> "~/venvs.2" for slot 2
                folder_key.set(f + '.{}'.format(slot_index))

            # update the configuration for full path
            folder = Path(os.path.expandvars(folder_key.get())).expanduser().absolute()
            folder_key.set(folder.as_posix())
            try:
                folder.mkdir(parents=True, exist_ok=True)
            except:
                # NOTE(review): bare except silently ignores any mkdir failure
                # (even KeyboardInterrupt); consider narrowing to OSError
                pass

    def print_configuration(self, remove_secret_keys=("secret", "pass", "token", "account_key")):
        """
        Print the effective configuration (HOCON "properties" format) to stdout,
        with any key containing one of ``remove_secret_keys`` as a substring removed.
        """
        # remove all the secrets from the print
        def recursive_remove_secrets(dictionary, secret_keys=()):
            # NOTE(review): substring match, so e.g. a key named "passthrough"
            # would also be dropped from the printout
            for k in list(dictionary):
                for s in secret_keys:
                    if s in k:
                        dictionary.pop(k)
                        break
                if isinstance(dictionary.get(k, None), dict):
                    recursive_remove_secrets(dictionary[k], secret_keys=secret_keys)
                elif isinstance(dictionary.get(k, None), (list, tuple)):
                    for item in dictionary[k]:
                        if isinstance(item, dict):
                            recursive_remove_secrets(item, secret_keys=secret_keys)

        # deepcopy so the scrubbing never touches the live configuration
        config = deepcopy(self.config.to_dict())
        # remove the env variable, it's not important
        config.pop('env', None)
        if remove_secret_keys:
            recursive_remove_secrets(config, secret_keys=remove_secret_keys)
        config = ConfigFactory.from_dict(config)
        self.log.debug("Run by interpreter: %s", sys.executable)
        print(
            "Current configuration (trains_agent v{}, location: {}):\n"
            "----------------------\n{}\n".format(
                self.version, self._config_file, HOCONConverter.convert(config, "properties")
            )
        )

    def send_api(self, request):
        # type: (Request) -> Any
        """Send an API request object; raise APIError on failure or empty response."""
        result = self.send(request)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result.response

    def get(self, service, action, version=None, headers=None,
            data=None, json=None, async_enable=False, **kwargs):
        """Issue a manual GET request; extra kwargs become the json payload when ``json`` is None."""
        return self._manual_request(service=service, action=action,
                                    version=version, method="get", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def post(self, service, action, version=None, headers=None,
             data=None, json=None, async_enable=False, **kwargs):
        """Issue a manual POST request; extra kwargs become the json payload when ``json`` is None."""
        return self._manual_request(service=service, action=action,
                                    version=version, method="post", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def _manual_request(self, service, action, version=None, method="get", headers=None,
                        data=None, json=None, async_enable=False, **kwargs):
        """Send a raw service/action request and return the parsed "data" payload.

        Raises APIError when the response is not valid JSON or its meta
        result_code is not 200.
        """

        res = self.send_request(service=service, action=action,
                                version=version, method=method, headers=headers,
                                data=data, async_enable=async_enable,
                                json=json or kwargs)

        try:
            res_json = res.json()
            return_code = res_json["meta"]["result_code"]
        except (ValueError, KeyError, TypeError):
            raise APIError(res)

        # check return code
        if return_code != 200:
            raise APIError(res)

        return res_json["data"]

    def to_json(self):
        """Serialize the full configuration to a pretty-printed JSON string."""
        return json.dumps(
            self.config.as_plain_ordered_dict(), cls=HOCONEncoder, indent=4
        )

    def command(self, *args):
        """Build an Argv command wrapper that logs through this session's logger."""
        return Argv(*args, log=self.get_logger(Argv.__module__))
|
||||
|
||||
|
||||
@attr.s
class TrainsAgentLogger(object):
    """
    Thin wrapper around logging.Logger using composition, because
    inheriting from Logger is difficult.
    """

    logger = attr.ib(type=logging.Logger)

    def _log_with_error(self, level, *args, **kwargs):
        """
        Log at ``level``; when debug logging is enabled and the caller did
        not specify otherwise, attach the current exception info
        """
        if "exc_info" not in kwargs:
            kwargs["exc_info"] = self.logger.isEnabledFor(logging.DEBUG)
        return self.logger.log(level, *args, **kwargs)

    def warning(self, *args, **kwargs):
        return self._log_with_error(logging.WARNING, *args, **kwargs)

    def error(self, *args, **kwargs):
        return self._log_with_error(logging.ERROR, *args, **kwargs)

    def __getattr__(self, item):
        # anything not overridden is delegated to the wrapped logger
        return getattr(self.logger, item)

    def __call__(self, *args, **kwargs):
        """
        Compatibility with old ``Command.log()`` method
        """
        return self.logger.info(*args, **kwargs)
|
||||
|
||||
|
||||
def normalize_cuda_version(value):
    # type: (Any) -> str
    """
    Take variably formatted cuda version string/number and return it in the same format:
    string decimal representation of 10 * major + minor.

    Values without a "." (or that fail to parse) are returned as-is.

    >>> normalize_cuda_version(100)
    '100'
    >>> normalize_cuda_version("100")
    '100'
    >>> normalize_cuda_version(10)
    '10'
    >>> normalize_cuda_version(10.0)
    '100'
    >>> normalize_cuda_version("10.0")
    '100'
    >>> normalize_cuda_version("10.0.130")
    '100'
    >>> normalize_cuda_version("10.1")
    '101'
    """
    value = str(value)
    if "." in value:
        try:
            # round() rather than int(): plain truncation mis-handles binary
            # float representation, e.g. int(float("10.1") * 10) == 100, not 101
            value = str(int(round(float(".".join(value.split(".")[:2])) * 10)))
        except (ValueError, TypeError):
            # not a numeric version string; keep the original text
            pass
    return value
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user