Compare commits

...

86 Commits

Author SHA1 Message Date
allegroai
090327234a Version bump to v0.16.3 2020-12-22 20:18:30 +02:00
allegroai
3620c3a12d Update PyJWT requirement (v2.0.0 breaks interface) as well as other requirements constraints 2020-12-22 20:18:14 +02:00
allegroai
9a3f950ac6 Fix conform queue name to k8s standard 2020-12-13 16:21:29 +02:00
allegroai
0b36cb0f85 Change k8s pod naming scheme to include queue name 2020-12-10 14:19:19 +02:00
allegroai
dd42423482 Version bump to v0.16.2 2020-12-10 13:02:19 +02:00
allegroai
69eb25db1f Fix running trains-agent from conda environment - conda.sh not found in first conda PATH match 2020-12-10 09:53:18 +02:00
allegroai
a41ea52f87 Add multiple packages support 2020-12-10 09:52:00 +02:00
allegroai
259113c989 Add PackageCollectorRequirement to allow multiple entries of the same package 2020-12-06 12:16:56 +02:00
allegroai
1afa3a3914 Add torchcsprng and torchtext to PyTorch resolving. Improve debug prints on auto cuda version resolving. 2020-12-06 12:15:12 +02:00
allegroai
448e23825c Fix requirements dict with null entry in pip should be considered None and we should install from requirements.txt 2020-12-06 12:14:22 +02:00
allegroai
b0c0f41f62 Allow zero context diffs (useful when blind patching repository) 2020-12-06 12:13:28 +02:00
allegroai
d2c5fb6512 Add K8s glue example --gateway-address settings properties/k8s-gateway-address on all Tasks 2020-12-06 12:12:42 +02:00
allegroai
b89cf4ec23 version bump 2020-11-29 23:17:50 +02:00
allegroai
74b646af9e Add pass TRAINS_DOCKER_IMAGE into docker for interactive sessions 2020-11-29 23:16:40 +02:00
allegroai
0cf485f7a9 Improve k8s nvidia container integration 2020-11-26 01:15:49 +02:00
allegroai
ea63e4f66e Add --ssh-server-port to k8s glue service 2020-11-26 01:15:20 +02:00
allegroai
58eb5fbd5f Fix torch CUDA 11.1 support 2020-11-26 01:14:36 +02:00
allegroai
a8c543ef7b Fix nvidia pytorch dockers support 2020-11-25 16:45:09 +02:00
allegroai
64e198a57a Fix nvidia docker support on some linux distros (SUSE) 2020-11-25 16:44:37 +02:00
allegroai
de332b9e6b Document '--stop' usage 2020-11-19 12:36:58 +02:00
allegroai
60eeff292d version bump 2020-11-11 17:11:51 +02:00
allegroai
52f30b306a Fix git diff with empty line at the end of the git diff will cause corrupt diff apply message 2020-11-11 17:11:28 +02:00
allegroai
6df0f81ca0 Fix uid is None causes ValueError in str.startswith(). Fix str.split (should be on the filename itself, not the path). 2020-11-11 16:32:47 +02:00
allegroai
40b3c1502d Add extra_bash_init_script to k8s glue. Default config is the raw config file (not created at runtime) 2020-11-11 16:31:25 +02:00
allegroai
a61265effe Improve trying to find conda executable 2020-11-11 16:29:50 +02:00
allegroai
92efea6b76 Add agent.package_manager.force_repo_requirements_txt. If True, "Installed Packages" on Task are ignored, and only repo requirements.txt is used 2020-11-11 16:29:00 +02:00
allegroai
216b3e2179 Allow to specifying cudatoolkit version in "installed packages" when using Conda as package manager (trains issue #229) 2020-10-30 10:06:02 +02:00
allegroai
293a92f486 Improve k8s glue add --template-yaml 2020-10-23 01:28:22 +03:00
allegroai
6bad2b5352 Fix support non-ascii git diff 2020-10-23 01:27:59 +03:00
allegroai
a09a638b9c Improve k8s glue layer 2020-10-22 18:09:56 +03:00
allegroai
24f57270ed version bump 2020-10-22 18:09:23 +03:00
allegroai
1b7964ce98 Add k8s select external trains.conf file for the pod itself 2020-10-21 19:04:38 +03:00
allegroai
5a510882b8 Ignore environment SSH_AUTH_SOCK. Only check if git_user/pass are configured, if they are not, leave the links as they are 2020-10-21 19:02:29 +03:00
allegroai
601ed03198 Add support for k8s pod custom user properties 2020-10-20 23:48:02 +03:00
allegroai
90fe4570b9 Show k8s pod number in task's User Properties configuration section 2020-10-20 23:27:04 +03:00
allegroai
92fc8e838f Add K8s glue support for limited number of services exposing ports 2020-10-20 14:17:30 +03:00
allegroai
89a3020c5e Fix ubuntu/debian support by making sure not to ask for input (fix tzdata install) 2020-10-15 23:32:17 +03:00
allegroai
fc3e47b67e Add suppress_carriage_return to documentation
Add docker_preprocess_bash_script to allow preprocessing bash to be added
Fix multiple python versions installed in the same docker by finding the highest installed python inside the docker
Fix conda_env_as_base_docker not set to False in docker mode
2020-10-15 23:31:01 +03:00
allegroai
b2a80ca314 Fix Trains examples references 2020-10-15 23:28:53 +03:00
allegroai
14655f19a0 Fix conda PYTHONPATH (point only to code, not to venv) 2020-10-15 23:26:58 +03:00
allegroai
47092c47db Fix apply git diff from submodule only 2020-10-15 23:26:52 +03:00
allegroai
8e6fce8d63 Add conda support for read-only pre-built environment (pass conda folder as docker_cmd on Task).
Fix conda restore prebuild tar.gz file, fix conda prefix by call conda-unpack from unzipped conda env.
2020-10-15 23:25:57 +03:00
allegroai
3c514e3418 Make sure TRAINS_AGENT_K8S_HOST_MOUNT is used only once per mount 2020-10-15 23:24:51 +03:00
allegroai
8a425b100b Fix k8s glue script to trains-agent default docker script 2020-10-15 23:24:21 +03:00
allegroai
eb942cfedd Add agent.package_manager.conda_env_as_base_docker allowing "docker_cmd" to contain link to a full pre-packaged conda environment (conda-pack outputs a tar.gz). Use TRAINS_CONDA_ENV_PACKAGE to specify conda tar.gz file. 2020-10-15 23:23:46 +03:00
Allegro AI
0a7fc06108 Merge pull request #31 from eliorc/master
Fix broken links in README.md
2020-10-14 16:13:40 +03:00
Elior Cohen
0ae35afa76 📝 Broken links in README.md 2020-10-14 10:43:33 +03:00
allegroai
a2156e73bf Fix conda pip freeze to be consistent with trains 0.16.3 2020-10-11 11:25:35 +03:00
allegroai
9fe77f3c28 Fix conda environment support for trains 0.16.3 full env. Add agent.package_manager.conda_full_env_update to allow conda to update back the requirements (default is false, to preserve previous behavior) 2020-10-11 11:24:52 +03:00
allegroai
6f078afafd Add Requirement.clone() 2020-10-11 11:21:49 +03:00
allegroai
15f4aa613e Suppress "\r" when reading a current chunk of a file. Add agent.suppress_carriage_return (default True) to support previous behavior. 2020-10-11 11:21:08 +03:00
allegroai
7cd9fa6c41 Version bump to v0.16.1 2020-10-05 18:27:07 +03:00
allegroai
234d5fac2c When using force ssh protocol, only enforce on git_host if provided, otherwise apply everywhere 2020-10-05 18:26:21 +03:00
allegroai
6cbfb96ff8 Rename git_domain to git_host 2020-10-05 11:25:03 +03:00
allegroai
6e54e55c31 Add agent.force_git_ssh_port to control https to ssh link conversion for non standard ssh port 2020-10-04 19:42:44 +03:00
allegroai
3ff85b7b85 Replace back package version on conda and pip 2020-10-04 19:41:26 +03:00
allegroai
5640489f57 Replace torch version on pre-installed local file 2020-10-04 19:40:39 +03:00
allegroai
8135a6facf Add agent.git_domain setting for limiting git credential usage for a specific domain (env var TRAINS_AGENT_GIT_DOMAIN is also supported) 2020-10-04 19:40:04 +03:00
allegroai
b6ae4f211d Fix "package @ " should processed by us (pip will not test pre-installed version of the package compared with the link) 2020-10-04 19:38:33 +03:00
allegroai
a56f032ec4 Fix torch support to not change back the same link 2020-10-04 19:37:12 +03:00
allegroai
075736de20 Translate downloaded URL back to original link when new pip version is installed (otherwise we end up with file:///... links) 2020-10-04 19:36:14 +03:00
allegroai
d8543c892e When new pip version is installed, no need to install git packages twice (pip freeze will detect the correct git link version) 2020-10-04 19:35:26 +03:00
allegroai
ca0870b048 Allow parsing of "package @ scheme://link" lines in requirements 2020-10-04 19:34:32 +03:00
allegroai
c7a739fafa Add support for detecting new pip version (20+) supporting @ in requirements 2020-10-04 19:33:52 +03:00
allegroai
7170296162 Remove warning on '.' (same as an empty working directory) 2020-10-04 19:32:48 +03:00
allegroai
3bed0ef33c Add protection against bad file name parsing in git diff apply 2020-10-04 19:31:48 +03:00
allegroai
d419fa1e4f Update torch version after using system pre-installed version 2020-10-04 19:29:47 +03:00
allegroai
31a56c71bd Add preliminary agent uptime/downtime support 2020-09-29 19:34:51 +03:00
allegroai
28f47419b0 Fix incorrect check for spaces in current execution folder (only check in cache folders) 2020-09-15 20:26:02 +03:00
allegroai
6a24da2849 Add post_packages post_optional_packages to control packages installed after all the rest (e.g. horovod)
Rename CythonReq to PriorityPackageRequirement and HorovodReq to PostRequirement
2020-09-15 20:20:55 +03:00
allegroai
782668fd21 Add sdk.metrics.plot_max_num_digits to reduce plot storage size 2020-09-05 16:37:17 +03:00
allegroai
aaf8d802e7 Update documentation 2020-09-05 16:37:17 +03:00
allegroai
ca89a1e322 Fix pre-installed packages are ignored when installing a git package wheel. Reinstalling a git+http link is enough to make sure all requirements are met/installed (trains issue #196) 2020-09-05 16:37:17 +03:00
allegroai
121dec2a62 Version bump to v0.16.0 2020-08-10 17:28:00 +03:00
allegroai
4aacf9005e Fix GPU Windows monitoring support (Trains Issue #177) 2020-08-10 08:07:51 +03:00
allegroai
6b333202e9 Sync generated conf file with latest Trains 2020-08-08 14:44:45 +03:00
allegroai
ce6831368f Fix GPU monitoring on Windows machines 2020-08-08 14:43:25 +03:00
allegroai
e4111c830b Fix GIT user/pass in requirements and support for '-e git+http' lines 2020-07-30 14:30:23 +03:00
allegroai
52c1772b04 Add requirement_parser into trains-agent instead as a dependency. Fix requirement_parser to support 'package @ git+http' lines 2020-07-30 14:29:37 +03:00
allegroai
699d13bbb3 Fix task status change to queued should also never happen during Task runtime 2020-07-14 23:42:11 +03:00
allegroai
2c8d7d3d9a Fix --debug to set all specified loggers to DEBUG
Add set_urllib_log_level, in debug set urllib log level to DEBUG
2020-07-11 01:45:46 +03:00
allegroai
b13cc1e8e7 Add error message when Trains API Server is not accessible on startup 2020-07-11 01:44:45 +03:00
allegroai
17d2bf2a3e Change daemon --stop without any specific flag to terminate the agents by worker id lexicographic order 2020-07-11 01:43:54 +03:00
allegroai
94997f9c88 Add daemon --order-fairness for round-robin queue pulling
Add daemon --stop to terminate running agent (assume all the rest of the arguments are the same)
Clean up all log files on termination unless executed with --debug
2020-07-11 01:42:56 +03:00
allegroai
c6d998c4df Add terminate process and rmtree utilities 2020-07-11 01:40:50 +03:00
allegroai
f8ea445339 Fix docker to use UTF-8 encoding, so prints won't break it 2020-07-11 01:40:14 +03:00
44 changed files with 2352 additions and 406 deletions

View File

@@ -227,6 +227,14 @@ The **Trains Agent** will first try to pull jobs from the `important_jobs` queue
Adding queues, managing job order within a queue and moving jobs between queues, is available using the Web UI, see example on our [open server](https://demoapp.trains.allegro.ai/workers-and-queues/queues)
#### Stopping the Trains Agent
To stop a **Trains Agent** running in the background, run the same command line used to start the agent with `--stop` appended.
For example, to stop the first of the above shown same machine, single gpu agents:
```bash
trains-agent daemon --detached --gpus 0 --queue default --docker nvidia/cuda --stop
```
## How do I create an experiment on the Trains Server? <a name="from-scratch"></a>
* Integrate [Trains](https://github.com/allegroai/trains) with your code
* Execute the code on your machine (Manually / PyCharm / Jupyter Notebook)
@@ -272,18 +280,18 @@ trains-agent daemon --services-mode --detached --queue services --create-queue -
## AutoML and Orchestration Pipelines <a name="automl-pipes"></a>
The Trains Agent can also be used to implement AutoML orchestration and Experiment Pipelines in conjunction with the Trains package.
Sample AutoML & Orchestration examples can be found in the Trains [example/automl](https://github.com/allegroai/trains/tree/master/examples/automl) folder.
Sample AutoML & Orchestration examples can be found in the Trains [example/automation](https://github.com/allegroai/trains/tree/master/examples/automation) folder.
AutoML examples
- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/automl/automl_base_template_keras_simple.py)
- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/optimization/hyper-parameter-optimization/base_template_keras_simple.py)
- In order to create an experiment-template in the system, this code must be executed once manually
- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automl/automl_random_search_example.py)
- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automation/manual_random_param_search_example.py)
- This example will create multiple copies of the Keras experiment-template, with different hyper-parameter combinations
Experiment Pipeline examples
- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/task_piping_example.py)
- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/task_piping_example.py)
- This example will "process data", and once done, will launch a copy of the 'second step' experiment-template
- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/toy_base_task.py)
- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/toy_base_task.py)
- In order to create an experiment-template in the system, this code must be executed once manually
## License

View File

@@ -5,8 +5,17 @@ WORKDIR /usr/agent
COPY . /usr/agent
ENV LC_ALL=en_US.UTF-8
ENV LANG=en_US.UTF-8
ENV LANGUAGE=en_US.UTF-8
ENV PYTHONIOENCODING=UTF-8
RUN apt-get update
RUN apt-get dist-upgrade -y
RUN apt-get install -y locales
RUN locale-gen en_US.UTF-8
RUN apt-get install -y curl python3-pip git
RUN curl -sSL https://get.docker.com/ | sh
RUN python3 -m pip install -U pip

View File

@@ -11,4 +11,4 @@ TRAINS_API_HOST=${TRAINS_API_HOST:-"http://$TRAINS_HOST_IP:8008"}
echo $TRAINS_FILES_HOST $TRAINS_WEB_HOST $TRAINS_API_HOST 1>&2
python3 -m pip install -q -U "trains-agent${TRAINS_AGENT_UPDATE_VERSION}"
trains-agent daemon --services-mode --queue services --create-queue --docker $TRAINS_AGENT_DEFAULT_BASE_DOCKER --cpu-only $TRAINS_AGENT_EXTRA_ARGS
trains-agent daemon --services-mode --queue services --create-queue --docker "$TRAINS_AGENT_DEFAULT_BASE_DOCKER" --cpu-only $TRAINS_AGENT_EXTRA_ARGS

View File

@@ -17,9 +17,14 @@ agent {
# leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
git_user=""
git_pass=""
# Limit credentials to a single domain, for example: github.com,
# all other domains will use public access (no user/pass). Default: always send user/pass for any VCS domain
git_host=""
# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# Force a specific SSH port when converting http to ssh links (the domain is kept the same)
# force_git_ssh_port: ""
# unique name of this worker, if None, created based on hostname:process_id
# Overridden with os environment: TRAINS_WORKER_NAME
@@ -57,6 +62,24 @@ agent {
# additional conda channels to use when installing with conda package manager
conda_channels: ["pytorch", "conda-forge", ]
# conda_full_env_update: false
# conda_env_as_base_docker: false
# set the priority packages to be installed before the rest of the required packages
# priority_packages: ["cython", "numpy", "setuptools", ]
# set the optional priority packages to be installed before the rest of the required packages,
# In case a package installation fails, the package will be ignored,
# and the virtual environment process will continue
# priority_optional_packages: ["pygobject", ]
# set the post packages to be installed after all the rest of the required packages
# post_packages: ["horovod", ]
# set the optional post packages to be installed after all the rest of the required packages,
# In case a package installation fails, the package will be ignored,
# and the virtual environment process will continue
# post_optional_packages: []
# set to True to support torch nightly build installation,
# notice: torch nightly builds are ephemeral and are deleted from time to time
@@ -146,6 +169,9 @@ sdk {
# X images are stored in the upload destination for each matplotlib plot title.
matplotlib_untitled_history_size: 100
# Limit the number of digits after the dot in plot reporting (reducing plot report size)
# plot_max_num_digits: 5
# Settings for generated debug images
images {
format: JPEG

View File

@@ -0,0 +1,75 @@
"""
This example assumes you have preconfigured services with selectors in the form of
"ai.allegro.agent.serial=pod-<number>" and a targetPort of 10022.
The K8sIntegration component will label each pod accordingly.
"""
from argparse import ArgumentParser
from trains_agent.glue.k8s import K8sIntegration
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--queue", type=str, help="Queue to pull tasks from"
)
parser.add_argument(
"--ports-mode", action='store_true', default=False,
help="Ports-Mode will add a label to the pod which can be used as service, in order to expose ports"
)
parser.add_argument(
"--num-of-services", type=int, default=20,
help="Specify the number of k8s services to be used. Use only with ports-mode."
)
parser.add_argument(
"--base-port", type=int,
help="Used in conjunction with ports-mode, specifies the base port exposed by the services. "
"For pod #X, the port will be <base-port>+X"
)
parser.add_argument(
"--gateway-address", type=str, default=None,
help="Used in conjunction with ports-mode, specify the external address of the k8s ingress / ELB"
)
parser.add_argument(
"--pod-trains-conf", type=str,
help="Configuration file to be used by the pod itself (if not provided, current configuration is used)"
)
parser.add_argument(
"--overrides-yaml", type=str,
help="YAML file containing pod overrides to be used when launching a new pod"
)
parser.add_argument(
"--template-yaml", type=str,
help="YAML file containing pod template. If provided pod will be scheduled with kubectl apply "
"and overrides are ignored, otherwise it will be scheduled with kubectl run"
)
parser.add_argument(
"--ssh-server-port", type=int, default=0,
help="If non-zero, every pod will also start an SSH server on the selected port (default: zero, not active)"
)
return parser.parse_args()
def main():
args = parse_args()
user_props_cb = None
if args.ports_mode and args.base_port:
def k8s_user_props_cb(pod_number=0):
user_prop = {"k8s-pod-port": args.base_port + pod_number}
if args.gateway_address:
user_prop["k8s-gateway-address"] = args.gateway_address
return user_prop
user_props_cb = k8s_user_props_cb
k8s = K8sIntegration(
ports_mode=args.ports_mode, num_of_services=args.num_of_services, user_props_cb=user_props_cb,
overrides_yaml=args.overrides_yaml, trains_conf_file=args.pod_trains_conf, template_yaml=args.template_yaml,
extra_bash_init_script=K8sIntegration.get_ssh_server_bash(
ssh_port_number=args.ssh_server_port) if args.ssh_server_port else None
)
k8s.k8s_daemon(args.queue)
if __name__ == "__main__":
main()

View File

@@ -1,21 +1,20 @@
attrs>=18.0
enum34>=0.9 ; python_version < '3.6'
furl>=2.0.0
future>=0.16.0
humanfriendly>=2.1
jsonschema>=2.6.0
pathlib2>=2.3.0
psutil>=3.4.2
pyhocon>=0.3.38
pyparsing>=2.0.3
python-dateutil>=2.4.2
pyjwt>=1.6.4
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.20.0
requirements_parser>=0.2.0
six>=1.11.0
tqdm>=4.19.5
typing>=3.6.4
urllib3>=1.21.1
attrs>=18.0,<20.4.0
enum34>=0.9,<1.2.0 ; python_version < '3.6'
furl>=2.0.0,<2.2.0
future>=0.16.0,<0.19.0
humanfriendly>=2.1,<9.2
jsonschema>=2.6.0,<3.3.0
pathlib2>=2.3.0,<2.4.0
psutil>=3.4.2,<5.9.0
pyhocon>=0.3.38,<0.4.0
pyparsing>=2.0.3,<2.5.0
python-dateutil>=2.4.2,<2.9.0
pyjwt>=1.6.4,<2.0.0
PyYAML>=3.12,<5.4.0
requests-file>=1.4.2,<1.6.0
requests>=2.20.0,<2.26.0
six>=1.11.0,<1.16.0
tqdm>=4.19.5,<4.55.0
typing>=3.6.4,<3.8.0
urllib3>=1.21.1,<1.27.0
virtualenv>=16,<20

View File

@@ -13,9 +13,12 @@
# leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
# git_user: ""
# git_pass: ""
# git_host: ""
# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# Force a specific SSH port when converting http to ssh links (the domain is kept the same)
# force_git_ssh_port: 0
# Set the python version to use when creating the virtual environment and launching the experiment
# Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6"
@@ -44,6 +47,26 @@
# additional conda channels to use when installing with conda package manager
conda_channels: ["defaults", "conda-forge", "pytorch", ]
# If set to true, Task's "installed packages" are ignored,
# and the repository's "requirements.txt" is used instead
# force_repo_requirements_txt: false
# set the priority packages to be installed before the rest of the required packages
# priority_packages: ["cython", "numpy", "setuptools", ]
# set the optional priority packages to be installed before the rest of the required packages,
# In case a package installation fails, the package will be ignored,
# and the virtual environment process will continue
# priority_optional_packages: ["pygobject", ]
# set the post packages to be installed after all the rest of the required packages
# post_packages: ["horovod", ]
# set the optional post packages to be installed after all the rest of the required packages,
# In case a package installation fails, the package will be ignored,
# and the virtual environment process will continue
# post_optional_packages: []
# set to True to support torch nightly build installation,
# notice: torch nightly builds are ephemeral and are deleted from time to time
torch_nightly: false,
@@ -86,6 +109,22 @@
# optional shell script to run in docker when started before the experiment is started
# extra_docker_shell_script: ["apt-get install -y bindfs", ]
# optional uptime configuration, make sure to use only one of 'uptime/downtime' and not both.
# If uptime is specified, agent will actively poll (and execute) tasks in the time-spans defined here.
# Outside of the specified time-spans, the agent will be idle.
# Defined using a list of items of the format: "<hours> <days>".
# hours - use values 0-23, single values would count as start hour and end at midnight.
# days - use days in abbreviated format (SUN-SAT)
# use '-' for ranges and ',' to separate singular values.
# for example, to enable the workers every Sunday and Tuesday between 17:00-20:00 set uptime to:
# uptime: ["17-20 SUN,TUE"]
# optional downtime configuration, can be used only when uptime is not used.
# If downtime is specified, agent will be idle in the time-spans defined here.
# Outside of the specified time-spans, the agent will actively poll (and execute) tasks.
# Use the same format as described above for uptime
# downtime: []
# set to true in order to force "docker pull" before running an experiment using a docker image.
# This makes sure the docker image is updated.
docker_force_pull: false
@@ -109,6 +148,16 @@
# "(which {python_single_digit} && {python_single_digit} -m pip --version) || apt-get install -y {python_single_digit}-pip",
# ]
# set the preprocessing bash script to execute at the startup of any docker.
# all lines will be executed regardless of their exit code.
# docker_preprocess_bash_script = [
# "echo \"starting docker\"",
#]
# If False replace \r with \n and display full console output
# default is True, report a single \r line in a sequence of consecutive lines, per 5 seconds.
# suppress_carriage_return: true
# cuda versions used for solving pytorch wheel packages
# should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
# cuda_version: 10.1

View File

@@ -31,12 +31,18 @@
# X images are stored in the upload destination for each matplotlib plot title.
matplotlib_untitled_history_size: 100
# Limit the number of digits after the dot in plot reporting (reducing plot report size)
# plot_max_num_digits: 5
# Settings for generated debug images
images {
format: JPEG
quality: 87
subsampling: 0
}
# Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to true, each series should have its own graph)
tensorboard_single_series_per_graph: false
}
network {
@@ -117,11 +123,11 @@
log {
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
null_log_propagate: False
null_log_propagate: false
task_log_buffer_capacity: 66
# disable urllib info and lower levels
disable_urllib3_info: True
disable_urllib3_info: true
}
development {
@@ -131,14 +137,30 @@
task_reuse_time_window_in_hours: 72.0
# Run VCS repository detection asynchronously
vcs_repo_detect_async: True
vcs_repo_detect_async: true
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff_on_train: True
store_uncommitted_code_diff: true
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True
support_stopping: true
# Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
default_output_uri: ""
# Default auto generated requirements optimize for smaller requirements
# If True, analyze the entire repository regardless of the entry point.
# If False, first analyze the entry point script, if it does not contain other to local files,
# do not analyze the entire repository.
force_analyze_entire_repo: false
# If set to true, *trains* update message will not be printed to the console
# this value can be overwritten with os environment variable TRAINS_SUPPRESS_UPDATE_MESSAGE=1
suppress_update_message: false
# If this flag is true (default is false), instead of analyzing the code with Pigar, analyze with `pip freeze`
detect_with_pip_freeze: false
# Development mode worker
worker {
@@ -149,7 +171,11 @@
ping_period_sec: 30
# Log all stdout & stderr
log_stdout: True
log_stdout: true
# compatibility feature, report memory usage for the entire machine
# default (false), report only on the running process and its sub-processes
report_global_mem_used: false
}
}
}
}

View File

@@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", type=bool, default=True)
ENV_CONDA_ENV_PACKAGE = EnvEntry("TRAINS_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")

View File

@@ -40,6 +40,7 @@ class Session(TokenManager):
_session_requests = 0
_session_initial_timeout = (3.0, 10.)
_session_timeout = (10.0, 30.)
_session_initial_connect_retry = 4
_write_session_data_size = 15000
_write_session_timeout = (30.0, 30.)
@@ -96,7 +97,7 @@ class Session(TokenManager):
else:
self.config = load()
if initialize_logging:
self.config.initialize_logging()
self.config.initialize_logging(debug=kwargs.get('debug', False))
token_expiration_threshold_sec = self.config.get(
"auth.token_expiration_threshold_sec", 60
@@ -134,7 +135,6 @@ class Session(TokenManager):
"api.http.retries", ConfigTree()
).as_plain_ordered_dict()
http_retries_config["status_forcelist"] = self._retry_codes
self.__http_session = get_http_session_with_retry(**http_retries_config)
self.__worker = worker or gethostname()
@@ -144,7 +144,14 @@ class Session(TokenManager):
self.client = client or "api-{}".format(__version__)
# limit the reconnect retries, so we get an error if we are starting the session
http_no_retries_config = dict(**http_retries_config)
http_no_retries_config['connect'] = self._session_initial_connect_retry
self.__http_session = get_http_session_with_retry(**http_no_retries_config)
# try to connect with the server
self.refresh_token()
# create the default session with many retries
self.__http_session = get_http_session_with_retry(**http_retries_config)
# update api version from server response
try:
@@ -427,16 +434,15 @@ class Session(TokenManager):
@classmethod
def get_api_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
return None
return ENV_HOST.get(default=(config.get("api.api_server", None) or
config.get("api.host", None) or cls.default_host))
@classmethod
def get_app_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
return None
# get from config/environment
web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
@@ -463,8 +469,8 @@ class Session(TokenManager):
@classmethod
def get_files_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
return None
# get from config/environment
files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
if files_host:
@@ -546,6 +552,9 @@ class Session(TokenManager):
else:
raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
"exception: {}".format(res, ex))
except requests.ConnectionError as ex:
raise ValueError('Connection Error: it seems *api_server* is misconfigured. '
'Is this the TRAINS API server {} ?'.format('/'.join(ex.request.url.split('/')[:3])))
except Exception as ex:
raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))

View File

@@ -190,7 +190,7 @@ class Config(object):
def reload(self):
self.replace(self._reload())
def initialize_logging(self):
def initialize_logging(self, debug=False):
logging_config = self._config.get("logging", None)
if not logging_config:
return False
@@ -217,6 +217,8 @@ class Config(object):
)
for logger in loggers:
handlers = logger.get("handlers", None)
if debug:
logger['level'] = 'DEBUG'
if not handlers:
continue
logger["handlers"] = [h for h in handlers if h not in deleted]

View File

@@ -46,6 +46,15 @@ class Environment(object):
local = 'local'
class UptimeConf(object):
min_api_version = "2.10"
queue_tag_on = "force_workers:on"
queue_tag_off = "force_workers:off"
worker_key = "force"
worker_value_off = ["off"]
worker_value_on = ["on"]
CONFIG_FILE_EXTENSION = '.conf'

View File

@@ -142,6 +142,7 @@ def main():
with open(str(conf_file), 'wt') as f:
header = '# TRAINS-AGENT configuration file\n' \
'api {\n' \
' # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
' api_server: %s\n' \
' web_server: %s\n' \
' files_server: %s\n' \

View File

@@ -12,12 +12,12 @@ import sys
import shutil
import traceback
from collections import defaultdict
from copy import copy, deepcopy
from copy import deepcopy
from datetime import datetime
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from tempfile import mkdtemp, gettempdir
from tempfile import mkdtemp, NamedTemporaryFile
from time import sleep, time
from typing import Text, Optional, Any, Tuple
@@ -30,12 +30,11 @@ from pathlib2 import Path
from pyhocon import ConfigTree, ConfigFactory
from six.moves.urllib.parse import quote
from trains_agent.backend_config.defs import UptimeConf
from trains_agent.helper.check_update import start_check_update_daemon
from trains_agent.commands.base import resolve_names, ServiceCommandSection
from trains_agent.definitions import (
WORKER_ALREADY_REGISTERED,
ENVIRONMENT_SDK_PARAMS,
INVALID_WORKER_ID,
PROGRAM_NAME,
DEFAULT_VENV_UPDATE_URL,
ENV_TASK_EXECUTE_AS_USER,
@@ -66,12 +65,12 @@ from trains_agent.helper.base import (
get_python_path,
is_linux_platform,
rm_file,
add_python_path)
add_python_path, safe_remove_tree, )
from trains_agent.helper.console import ensure_text, print_text, decode_binary_lines
from trains_agent.helper.os.daemonize import daemonize_process
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.package.conda_api import CondaAPI
from trains_agent.helper.package.horovod_req import HorovodRequirement
from trains_agent.helper.package.post_req import PostRequirement
from trains_agent.helper.package.external_req import ExternalRequirements
from trains_agent.helper.package.pip_api.system import SystemPip
from trains_agent.helper.package.pip_api.venv import VirtualenvPip
@@ -89,21 +88,21 @@ from trains_agent.helper.process import (
get_bash_output,
shutdown_docker_process,
get_docker_id,
commit_docker
commit_docker, terminate_process,
)
from trains_agent.helper.package.cython_req import CythonRequirement
from trains_agent.helper.package.priority_req import PriorityPackageRequirement, PackageCollectorRequirement
from trains_agent.helper.repo import clone_repository_cached, RepoInfo, VCS
from trains_agent.helper.resource_monitor import ResourceMonitor
from trains_agent.helper.runtime_verification import check_runtime, print_uptime_properties
from trains_agent.session import Session
from trains_agent.helper.singleton import Singleton
from .events import Events
log = logging.getLogger(__name__)
DOCKER_ROOT_CONF_FILE = "/root/trains.conf"
DOCKER_DEFAULT_CONF_FILE = "/root/default_trains.conf"
@attr.s
class LiteralScriptManager(object):
"""
@@ -122,7 +121,24 @@ class LiteralScriptManager(object):
if not script:
return False
diff = script.diff
return diff and not diff.strip().lower().startswith("diff ")
if not diff:
return False
# test git diff prefix
if diff.lstrip().lower().startswith("diff "):
return False
# test git submodule prefix
# noinspection PyBroadException
try:
if diff.lstrip().lower().startswith("submodule ") and \
diff.splitlines()[1].lstrip().lower().startswith("diff "):
return False
except Exception:
pass
# none of the above
return True
@staticmethod
def write(task, directory, entry_point=None):
@@ -151,10 +167,11 @@ class LiteralScriptManager(object):
Create notebook file in appropriate location
:return: directory and script path
"""
log = logging.getLogger(__name__)
if repo_info and repo_info.root:
location = Path(repo_info.root, execution.working_dir)
else:
if execution.working_dir:
if execution.working_dir and execution.working_dir.strip() != '.':
log.warning(
"found task with `script.working_dir` (`%s`) but without `script.repository`, ignoring",
execution.working_dir,
@@ -212,6 +229,7 @@ class TaskStopSignal(object):
statuses.stopped,
statuses.failed,
statuses.published,
statuses.queued,
]
default = TaskStopReason.no_stop
stopping_message = "stopping"
@@ -263,7 +281,7 @@ class TaskStopSignal(object):
)
return TaskStopReason.stopped
if status in self.unexpected_statuses: ## and "worker" not in message:
if status in self.unexpected_statuses: # ## and "worker" not in message:
self.command.log("unexpected status change, task will terminate")
return TaskStopReason.status_changed
@@ -301,9 +319,11 @@ class Worker(ServiceCommandSection):
_requirement_substitutions = (
PytorchRequirement,
CythonRequirement,
HorovodRequirement,
PriorityPackageRequirement,
PostRequirement,
ExternalRequirements,
partial(PackageCollectorRequirement, collect_package=['trains']),
partial(PackageCollectorRequirement, collect_package=['clearml']),
)
# poll queues every _polling_interval seconds
@@ -319,6 +339,7 @@ class Worker(ServiceCommandSection):
_run_as_user_home = '/trains_agent_home'
_docker_fixed_user_cache = '/trains_agent_cache'
_temp_cleanup_list = []
@property
def service(self):
@@ -332,6 +353,8 @@ class Worker(ServiceCommandSection):
@staticmethod
def register_signal_handler():
def handler(*_):
for f in Worker._temp_cleanup_list + [Singleton.get_pid_file()]:
safe_remove_tree(f)
raise Sigterm()
signal.signal(signal.SIGTERM, handler)
@@ -372,7 +395,7 @@ class Worker(ServiceCommandSection):
self.temp_config_path = None
self.queues = ()
self.venv_folder = None # type: Optional[Text]
self.package_api = None # type: PackageManager
self.package_api = None # type: Optional[PackageManager]
self.global_package_api = None
self.is_venv_update = self._session.config.agent.venv_update.enabled
@@ -388,6 +411,15 @@ class Worker(ServiceCommandSection):
self._standalone_mode = None
self._services_mode = None
self._force_current_version = None
self._redirected_stdout_file_no = None
self._uptime_config = self._session.config.get("agent.uptime", None)
self._downtime_config = self._session.config.get("agent.downtime", None)
self._suppress_cr = self._session.config.get("agent.suppress_carriage_return", True)
# True - supported
# None - not initialized
# str - not supported, version string indicates last server version
self._runtime_props_support = None
@classmethod
def _verify_command_states(cls, kwargs):
@@ -433,7 +465,7 @@ class Worker(ServiceCommandSection):
pass
def run_one_task(self, queue, task_id, worker_args, docker=None):
# type: (Text, Text, WorkerParams) -> ()
# type: (Text, Text, WorkerParams, Optional[Text]) -> ()
"""
Run one task pulled from queue.
:param queue: ID of queue that task was pulled from
@@ -497,7 +529,7 @@ class Worker(ServiceCommandSection):
full_docker_cmd = self.docker_image_func(docker_image=docker_image, docker_arguments=docker_arguments)
try:
self._session.send_api(
tasks_api.EditRequest(task_id, force=True, execution=dict(
tasks_api.EditRequest(task_id, force=True, execution=dict( # noqa
docker_cmd=' '.join([docker_image] + docker_arguments) if docker_arguments else docker_image)))
except Exception:
pass
@@ -509,6 +541,10 @@ class Worker(ServiceCommandSection):
'--full-monitoring' if self._services_mode else '--disable-monitoring',
'--standalone-mode' if self._standalone_mode else '',
task_id)
# send the actual used command line to the backend
self.send_logs(task_id=task_id, lines=['Executing: {}\n'.format(full_docker_cmd)], level="INFO")
cmd = Argv(*full_docker_cmd)
print('Running Docker:\n{}\n'.format(str(cmd)))
else:
@@ -569,12 +605,12 @@ class Worker(ServiceCommandSection):
else:
self.handle_task_termination(task_id, status, stop_signal_status)
# remove temp files after we sent everything to the backend
safe_remove_file(temp_stdout_name)
safe_remove_file(temp_stderr_name)
if self.docker_image_func:
shutdown_docker_process(docker_cmd_contains='--id {}\'\"'.format(task_id))
safe_remove_file(temp_stdout_name)
safe_remove_file(temp_stderr_name)
def run_tasks_loop(self, queues, worker_params):
def run_tasks_loop(self, queues, worker_params, priority_order=True):
"""
:summary: Pull and run tasks from queues.
:description: 1. Go through ``queues`` by order.
@@ -584,16 +620,28 @@ class Worker(ServiceCommandSection):
:type queues: list of ``Text``
:param worker_params: Worker command line arguments
:type worker_params: ``trains_agent.helper.process.WorkerParams``
:param priority_order: If True pull order in priority manner. always from the first
If False, pull from each queue once in a round robin manner
:type priority_order: bool
"""
if not self._daemon_foreground:
print('Starting infinite task polling loop...')
_last_machine_update_ts = 0
while True:
while True:
queue_tags = None
runtime_props = None
# iterate over queues (priority style, queues[0] is highest)
for queue in queues:
if queue_tags is None or runtime_props is None:
queue_tags, runtime_props = self.get_worker_properties(queues)
if not self.should_be_currently_active(queue_tags[queue], runtime_props):
continue
# get next task in queue
try:
response = self._session.send_api(
@@ -614,6 +662,16 @@ class Worker(ServiceCommandSection):
print("No tasks in queue {}".format(queue))
continue
# clear output log if we start a new Task
if not worker_params.debug and self._redirected_stdout_file_no is not None and \
self._redirected_stdout_file_no > 2:
# noinspection PyBroadException
try:
os.lseek(self._redirected_stdout_file_no, 0, 0)
os.ftruncate(self._redirected_stdout_file_no, 0)
except:
pass
self.send_logs(
task_id=task_id,
lines=["task {} pulled from {} by worker {}\n".format(task_id, queue, self.worker_id)],
@@ -622,7 +680,14 @@ class Worker(ServiceCommandSection):
self.report_monitor(ResourceMonitor.StatusReport(queues=queues, queue=queue, task=task_id))
self.run_one_task(queue, task_id, worker_params)
self.report_monitor(ResourceMonitor.StatusReport(queues=self.queues))
break
queue_tags = None
runtime_props = None
# if we are using priority start pulling from the first always,
# if we are doing round robin, pull from the next one
if priority_order:
break
else:
# sleep and retry polling
if self._daemon_foreground or worker_params.debug:
@@ -632,6 +697,77 @@ class Worker(ServiceCommandSection):
if self._session.config["agent.reload_config"]:
self.reload_config()
def get_worker_properties(self, queue_ids):
queue_tags = {
q.id: {'name': q.name, 'tags': q.tags}
for q in self._session.send_api(
queues_api.GetAllRequest(id=queue_ids, only_fields=["id", "tags"])
).queues
}
runtime_props = self.get_runtime_properties()
return queue_tags, runtime_props
def get_runtime_properties(self):
if self._runtime_props_support is not True:
# either not supported or never tested
if self._runtime_props_support == self._session.api_version:
# tested against latest api_version, not supported
return []
if not self._session.check_min_api_version(UptimeConf.min_api_version):
# not supported due to insufficient api_version
self._runtime_props_support = self._session.api_version
return []
try:
res = self.get("get_runtime_properties", worker=self.worker_id)["runtime_properties"]
# definitely supported
self._runtime_props_support = True
return res
except APIError:
self._runtime_props_support = self._session.api_version
return []
def should_be_currently_active(self, current_queue, runtime_properties):
"""
Checks if a worker is active according to queue tags, worker's runtime properties and uptime schedule.
"""
if UptimeConf.queue_tag_off in current_queue['tags']:
self.log.debug("Queue {} is tagged '{}', worker will not pull tasks".format(
current_queue['name'], UptimeConf.queue_tag_off)
)
return False
if UptimeConf.queue_tag_on in current_queue['tags']:
self.log.debug("Queue {} is tagged '{}', worker will pull tasks".format(
current_queue['name'], UptimeConf.queue_tag_on)
)
return True
force_flag = next(
(prop for prop in runtime_properties if prop["key"] == UptimeConf.worker_key), None
)
if force_flag:
if force_flag["value"].lower() in UptimeConf.worker_value_off:
self.log.debug("worker has the following runtime property: '{}'. worker will not pull tasks".format(
force_flag)
)
return False
elif force_flag["value"].lower() in UptimeConf.worker_value_on:
self.log.debug("worker has the following runtime property: '{}'. worker will pull tasks".format(
force_flag)
)
return True
else:
print(
"Warning: invalid runtime_property '{}: {}' supported values are: '{}/{}', ignoring".format(
force_flag["key"], force_flag["value"], UptimeConf.worker_value_on, UptimeConf.worker_value_off
)
)
if self._uptime_config:
self.log.debug("following uptime configurations")
return check_runtime(self._uptime_config)
if self._downtime_config:
self.log.debug("following downtime configurations")
return check_runtime(self._downtime_config, is_uptime=False)
return True
def reload_config(self):
try:
reloaded = self._session.reload()
@@ -649,7 +785,7 @@ class Worker(ServiceCommandSection):
def check(self, **_):
try:
check_directory_path(str(Path(".").resolve()))
check_directory_path(str(Path(".").resolve()), check_whitespace_in_path=False)
except OSError as e:
if e.errno == errno.ENOENT:
raise CommandFailedError("current working directory does not exist")
@@ -674,7 +810,7 @@ class Worker(ServiceCommandSection):
self._session.print_configuration()
def daemon(self, queues, log_level, foreground=False, docker=False, detached=False, **kwargs):
def daemon(self, queues, log_level, foreground=False, docker=False, detached=False, order_fairness=False, **kwargs):
# if we do not need to create queues, make sure they are valid
# match previous behaviour when we validated queue names before everything else
queues = self._resolve_queue_names(queues, create_if_missing=kwargs.get('create_queue', False))
@@ -685,6 +821,34 @@ class Worker(ServiceCommandSection):
if self._services_mode:
kwargs = self._verify_command_states(kwargs)
docker = docker or kwargs.get('docker')
self._uptime_config = kwargs.get('uptime', None) or self._uptime_config
self._downtime_config = kwargs.get('downtime', None) or self._downtime_config
if self._uptime_config and self._downtime_config:
self.log.error(
"Both uptime and downtime were specified when only one of them could be used. Both will be ignored."
)
self._uptime_config = None
self._downtime_config = None
# We are not running a daemon we are killing one.
# find the pid send termination signal and leave
if kwargs.get('stop', False):
return 1 if not self._kill_daemon() else 0
queues_info = [
q.to_dict()
for q in self._session.send_api(
queues_api.GetAllRequest(id=queues)
).queues
]
if kwargs.get('status', False):
runtime_properties = self.get_runtime_properties()
if self._downtime_config:
print_uptime_properties(self._downtime_config, queues_info, runtime_properties, is_uptime=False)
else:
print_uptime_properties(self._uptime_config, queues_info, runtime_properties)
return 1
# make sure we only have a single instance,
# also make sure we set worker_id properly and cache folders
@@ -697,12 +861,6 @@ class Worker(ServiceCommandSection):
self.log.debug("starting resource monitor thread")
print("Worker \"{}\" - ".format(self.worker_id), end='')
queues_info = [
self._session.send_api(
queues_api.GetByIdRequest(queue)
).queue.to_dict()
for queue in queues
]
columns = ("id", "name", "tags")
print("Listening to queues:")
print_table(queues_info, columns=columns, titles=columns)
@@ -711,9 +869,8 @@ class Worker(ServiceCommandSection):
self._register(queues)
# create temp config file with current configuration
self.temp_config_path = safe_mkstemp(
suffix=".cfg", prefix=".trains_agent.", text=True, name_only=True
)
self.temp_config_path = NamedTemporaryFile(
suffix=".cfg", prefix=".trains_agent.", mode='w+t').name
# print docker image
if docker is not False and docker is not None:
@@ -753,6 +910,9 @@ class Worker(ServiceCommandSection):
)
)
if not self._session.debug_mode:
self._temp_cleanup_list.append(name)
if not detached:
# redirect std out/err to new file
sys.stdout = sys.stderr = out_file
@@ -760,6 +920,7 @@ class Worker(ServiceCommandSection):
# in detached mode
# fully detach stdin.stdout/stderr and leave main process, running in the background
daemonize_process(out_file.fileno())
self._redirected_stdout_file_no = out_file.fileno()
# make sure we update the singleton lock file to the new pid
Singleton.update_pid_file()
# reprint headers to std file (we are now inside the daemon process)
@@ -779,6 +940,7 @@ class Worker(ServiceCommandSection):
debug=self._session.debug_mode,
trace=self._session.trace,
),
priority_order=not order_fairness,
)
except Exception:
tb = six.text_type(traceback.format_exc())
@@ -842,22 +1004,27 @@ class Worker(ServiceCommandSection):
stop_signal=None, # type: Optional[TaskStopSignal]
**kwargs # type: Any
):
# type: (...) -> Tuple[Optional[int], TaskStopReason]
def _print_file(file_path, prev_line_count):
# type: (...) -> Tuple[Optional[int], Optional[TaskStopReason]]
def _print_file(file_path, prev_pos=0):
with open(file_path, "rb") as f:
f.seek(prev_pos)
binary_text = f.read()
if not binary_text:
return []
pos = f.tell()
# skip the previously printed lines,
blines = binary_text.split(b'\n')[prev_line_count:]
blines = binary_text.split(b'\n') if binary_text else []
if not blines:
return blines
return decode_binary_lines(blines if blines[-1] else blines[:-1])
return blines, pos
return (
decode_binary_lines(blines if blines[-1] else blines[:-1],
replace_cr=not self._suppress_cr,
overwrite_cr=self._suppress_cr),
pos
)
stdout = open(stdout_path, "wt")
stderr = open(stderr_path, "wt") if stderr_path else stdout
stdout_line_count, stdout_last_lines = 0, []
stderr_line_count, stderr_last_lines = 0, []
stdout_line_count, stdout_pos_count, stdout_last_lines = 0, 0, []
stderr_line_count, stderr_pos_count, stderr_last_lines = 0, 0, []
service_mode_internal_agent_started = None
stopping = False
status = None
@@ -896,7 +1063,7 @@ class Worker(ServiceCommandSection):
stderr.flush()
# get diff from previous poll
printed_lines = _print_file(stdout_path, stdout_line_count)
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
if self._services_mode and not stopping and not status:
# if the internal agent started, we stop logging, it will take over logging.
# if the internal agent started running the task itself, it will return status==0,
@@ -906,13 +1073,10 @@ class Worker(ServiceCommandSection):
if status is not None:
stop_reason = 'Service started'
stdout_line_count += self.send_logs(
task_id, printed_lines
)
stdout_line_count += self.send_logs(task_id, printed_lines)
if stderr_path:
stderr_line_count += self.send_logs(
task_id, _print_file(stderr_path, stderr_line_count)
)
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += self.send_logs(task_id, printed_lines)
except subprocess.CalledProcessError as ex:
# non zero return code
@@ -923,9 +1087,11 @@ class Worker(ServiceCommandSection):
raise
except Exception:
# we should not get here, but better safe than sorry
stdout_line_count += self.send_logs(task_id, _print_file(stdout_path, stdout_line_count))
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
stdout_line_count += self.send_logs(task_id, printed_lines)
if stderr_path:
stderr_line_count += self.send_logs(task_id, _print_file(stderr_path, stderr_line_count))
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += self.send_logs(task_id, printed_lines)
stop_reason = 'Exception occurred'
status = -1
@@ -939,13 +1105,11 @@ class Worker(ServiceCommandSection):
stderr.close()
# Send last lines
stdout_line_count += self.send_logs(
task_id, _print_file(stdout_path, stdout_line_count)
)
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
stdout_line_count += self.send_logs(task_id, printed_lines)
if stderr_path:
stderr_line_count += self.send_logs(
task_id, _print_file(stderr_path, stderr_line_count)
)
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += self.send_logs(task_id, printed_lines)
return status, stop_reason
@@ -1024,7 +1188,8 @@ class Worker(ServiceCommandSection):
success = False
if not success:
raise ValueError("Failed applying git diff:\n{}\n\nERROR! Failed applying git diff, see diff above.".format(diff))
raise ValueError("Failed applying git diff:\n{}\n\n"
"ERROR! Failed applying git diff, see diff above.".format(diff))
@resolve_names
def build(
@@ -1049,10 +1214,15 @@ class Worker(ServiceCommandSection):
execution = self.get_execution_info(current_task)
try:
requirements = current_task.script.requirements
except AttributeError:
if self._session.config.get("agent.package_manager.force_repo_requirements_txt", False):
requirements = None
print("[package_manager.force_repo_requirements_txt=true] "
"Skipping requirements, using repository \"requirements.txt\" ")
else:
try:
requirements = current_task.script.requirements
except AttributeError:
requirements = None
if not python_version:
try:
@@ -1063,8 +1233,8 @@ class Worker(ServiceCommandSection):
except:
python_version = None
venv_folder, requirements_manager = self.install_virtualenv(venv_dir=target,
requested_python_version=python_version)
venv_folder, requirements_manager = self.install_virtualenv(
venv_dir=target, requested_python_version=python_version, execution_info=execution)
if self._default_pip:
if install_globally and self.global_package_api:
@@ -1219,7 +1389,8 @@ class Worker(ServiceCommandSection):
print("Cloning task id={}".format(task_id))
current_task = self._session.api_client.tasks.get_by_id(
self._session.send_api(
tasks_api.CloneRequest(task=current_task.id, new_task_name='Clone of {}'.format(current_task.name))
tasks_api.CloneRequest(task=current_task.id,
new_task_name='Clone of {}'.format(current_task.name))
).id
)
print("Task cloned, new task id={}".format(current_task.id))
@@ -1269,10 +1440,15 @@ class Worker(ServiceCommandSection):
execution = self.get_execution_info(current_task)
try:
requirements = current_task.script.requirements
except AttributeError:
if self._session.config.get("agent.package_manager.force_repo_requirements_txt", False):
requirements = None
print("[package_manager.force_repo_requirements_txt=true] "
"Skipping requirements, using repository \"requirements.txt\" ")
else:
try:
requirements = current_task.script.requirements
except AttributeError:
requirements = None
try:
python_ver = current_task.script.binary
@@ -1282,8 +1458,8 @@ class Worker(ServiceCommandSection):
except:
python_ver = None
venv_folder, requirements_manager = self.install_virtualenv(standalone_mode=standalone_mode,
requested_python_version=python_ver)
venv_folder, requirements_manager = self.install_virtualenv(
standalone_mode=standalone_mode, requested_python_version=python_ver, execution_info=execution)
if not standalone_mode:
if self._default_pip:
@@ -1308,7 +1484,10 @@ class Worker(ServiceCommandSection):
# do not update the task packages if we are using conda,
# it will most likely make the task environment unreproducible
freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None,
skip_freeze_update = self.is_conda and not self._session.config.get(
"agent.package_manager.conda_full_env_update", False)
freeze = self.freeze_task_environment(current_task.id if not skip_freeze_update else None,
requirements_manager=requirements_manager)
script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix()
@@ -1362,8 +1541,7 @@ class Worker(ServiceCommandSection):
self._update_commit_id(current_task.id, execution, repo_info)
# Add the script CWD to the python path
python_path = get_python_path(script_dir, execution.entry_point, self.package_api) \
if not self.is_conda else None
python_path = get_python_path(script_dir, execution.entry_point, self.package_api, is_conda_env=self.is_conda)
if os.environ.get(ENV_TASK_EXTRA_PYTHON_PATH):
python_path = add_python_path(python_path, os.environ.get(ENV_TASK_EXTRA_PYTHON_PATH))
if python_path:
@@ -1661,7 +1839,7 @@ class Worker(ServiceCommandSection):
def install_requirements_for_package_api(
self, execution, repo_info, requirements_manager, cached_requirements=None, cwd=None, package_api=None,
):
# type: (ExecutionInfo, RepoInfo, RequirementsManager, Optional[dict]) -> None
# type: (ExecutionInfo, RepoInfo, RequirementsManager, Optional[dict], Optional[str], Optional[Any]) -> None
"""
:summary: Install requirements for task script using pip.
:description: A file named "requirements.txt" is looked for in each containing folder between the
@@ -1671,6 +1849,7 @@ class Worker(ServiceCommandSection):
:param repo_info: repository information
:param requirements_manager: requirements manager for task
:param cached_requirements: cached requirements from previous run
:param cwd: current folder
:param package_api: package_api to be used when installing requirements
"""
if package_api:
@@ -1684,12 +1863,13 @@ class Worker(ServiceCommandSection):
package_api.set_selected_package_manager()
# always install cython,
# if we have a specific version in the requirements,
# the CythonRequirement(SimpleSubstitution) will reinstall cython with the specific version
# the PriorityPackageRequirement(SimpleSubstitution) will reinstall cython with the specific version
if not self.is_conda:
package_api.out_of_scope_install_package('Cython')
cached_requirements_failed = False
if cached_requirements and ('pip' in cached_requirements or 'conda' in cached_requirements):
if cached_requirements and (cached_requirements.get('pip') is not None or
cached_requirements.get('conda') is not None):
self.log("Found task requirements section, trying to install")
try:
package_api.load_requirements(cached_requirements)
@@ -1859,8 +2039,9 @@ class Worker(ServiceCommandSection):
)
)
def install_virtualenv(self, venv_dir=None, requested_python_version=None, standalone_mode=False):
# type: (str, str, bool) -> Tuple[Path, RequirementsManager]
def install_virtualenv(
self, venv_dir=None, requested_python_version=None, standalone_mode=False, execution_info=None):
# type: (str, str, bool, ExecutionInfo) -> Tuple[Path, RequirementsManager]
"""
Install a new python virtual environment, removing the old one if exists
:return: virtualenv directory and requirements manager to use with task
@@ -1907,6 +2088,7 @@ class Worker(ServiceCommandSection):
python=executable_version_suffix if self.is_conda else executable_name,
path=venv_dir,
requirements_manager=requirements_manager,
execution_info=execution_info,
)
global_package_manager_params = dict(
@@ -2005,26 +2187,33 @@ class Worker(ServiceCommandSection):
mounted_pip_dl_dir = '/root/.trains/pip-download-cache'
mounted_vcs_cache = '/root/.trains/vcs-cache'
mounted_venv_dir = '/root/.trains/venvs-builds'
host_cache = Path(os.path.expandvars(self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
host_pip_dl = Path(os.path.expandvars(self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
host_vcs_cache = Path(os.path.expandvars(self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
host_cache = Path(os.path.expandvars(
self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
host_pip_dl = Path(os.path.expandvars(
self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
host_vcs_cache = Path(os.path.expandvars(
self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
temp_config.put("sdk.storage.cache.default_base_dir", mounted_cache_dir)
temp_config.put("agent.pip_download_cache.path", mounted_pip_dl_dir)
temp_config.put("agent.vcs_cache.path", mounted_vcs_cache)
temp_config.put("agent.package_manager.system_site_packages", True)
temp_config.put("agent.package_manager.conda_env_as_base_docker", False)
temp_config.put("agent.default_python", "")
temp_config.put("agent.python_binary", "")
temp_config.put("agent.cuda_version", "")
temp_config.put("agent.cudnn_version", "")
temp_config.put("agent.venvs_dir", mounted_venv_dir)
temp_config.put("agent.git_user", (ENV_AGENT_GIT_USER.get() or self._session.config.get("agent.git_user", None)))
temp_config.put("agent.git_pass", (ENV_AGENT_GIT_PASS.get() or self._session.config.get("agent.git_pass", None)))
temp_config.put("agent.git_user", (ENV_AGENT_GIT_USER.get() or
self._session.config.get("agent.git_user", None)))
temp_config.put("agent.git_pass", (ENV_AGENT_GIT_PASS.get() or
self._session.config.get("agent.git_pass", None)))
host_apt_cache = Path(os.path.expandvars(self._session.config.get(
"agent.docker_apt_cache", '~/.trains/apt-cache'))).expanduser().as_posix()
host_pip_cache = Path(os.path.expandvars(self._session.config.get(
"agent.docker_pip_cache", '~/.trains/pip-cache'))).expanduser().as_posix()
host_ssh_cache = mkdtemp(prefix='trains_agent.ssh.')
self._temp_cleanup_list.append(host_ssh_cache)
# make sure all folders are valid
Path(host_apt_cache).mkdir(parents=True, exist_ok=True)
@@ -2042,7 +2231,7 @@ class Worker(ServiceCommandSection):
shutil.copytree(Path('~/.ssh').expanduser().as_posix(), host_ssh_cache)
except Exception:
host_ssh_cache = None
log.warning('Failed creating temporary copy of ~/.ssh for git credential')
self.log.warning('Failed creating temporary copy of ~/.ssh for git credential')
pass
# check if the .git credentials exist:
@@ -2065,6 +2254,7 @@ class Worker(ServiceCommandSection):
extra_shell_script_str = " ; ".join(map(str, cmds)) + " ; "
bash_script = self._session.config.get("agent.docker_init_bash_script", None)
preprocess_bash_script = self._session.config.get("agent.docker_preprocess_bash_script", None)
self.temp_config_path = self.temp_config_path or safe_mkstemp(
suffix=".cfg", prefix=".trains_agent.", text=True, name_only=True
@@ -2085,6 +2275,7 @@ class Worker(ServiceCommandSection):
standalone_mode=self._standalone_mode,
force_current_version=self._force_current_version,
bash_script=bash_script,
preprocess_bash_script=preprocess_bash_script,
)
return temp_config, partial(docker_cmd_functor, docker_cmd, temp_config)
@@ -2098,7 +2289,8 @@ class Worker(ServiceCommandSection):
host_pip_dl, mounted_pip_dl,
host_vcs_cache, mounted_vcs_cache,
standalone_mode=False, extra_docker_arguments=None, extra_shell_script=None,
force_current_version=None, host_git_credentials=None, bash_script=None):
force_current_version=None, host_git_credentials=None,
bash_script=None, preprocess_bash_script=None):
docker = 'docker'
base_cmd = [docker, 'run', '-t']
@@ -2115,7 +2307,7 @@ class Worker(ServiceCommandSection):
if os.environ.get('TRAINS_DOCKER_SKIP_GPUS_FLAG', None):
dockers_nvidia_visible_devices = gpu_devices
else:
base_cmd += ['--gpus', 'device='+gpu_devices, ]
base_cmd += ['--gpus', '\"device={}\"'.format(gpu_devices), ]
# We are using --gpu, so we should not pass NVIDIA_VISIBLE_DEVICES, I think.
# base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES=' + gpu_devices, ]
elif gpu_devices.strip() == 'none':
@@ -2132,7 +2324,8 @@ class Worker(ServiceCommandSection):
base_cmd += [str(a) for a in extra_docker_arguments if a]
# check if running inside a kubernetes
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and os.environ.get('KUBERNETES_PORT')):
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and
os.environ.get('KUBERNETES_PORT')):
# map network to sibling docker, unless we have other network argument
if not any(a.strip().startswith('--network') for a in base_cmd):
try:
@@ -2154,7 +2347,7 @@ class Worker(ServiceCommandSection):
print('Warning: K8S mount missing, ignoring cached folder {}'.format(m))
host_mounts[i] = None
else:
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt)
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt, 1)
host_apt_cache, host_pip_cache, host_pip_dl, host_cache, host_vcs_cache = host_mounts
# copy the configuration file into the mounted folder
@@ -2177,6 +2370,8 @@ class Worker(ServiceCommandSection):
raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))
base_cmd += ['-e', 'TRAINS_WORKER_ID='+worker_id, ]
# update the docker image, so the system knows where it runs
base_cmd += ['-e', 'TRAINS_DOCKER_IMAGE={} {}'.format(docker_image, ' '.join(docker_arguments)).strip()]
# if we are running a RC version, install the same version in the docker
# because the default latest, will be a release version (not RC)
@@ -2200,21 +2395,33 @@ class Worker(ServiceCommandSection):
if not standalone_mode:
if not bash_script:
# Find the highest python version installed, or install from apt-get
# python+pip is the requirement to match
bash_script = [
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
"chown -R root /root/.cache/pip",
"export DEBIAN_FRONTEND=noninteractive",
"apt-get update",
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
"(which {python_single_digit} && {python_single_digit} -m pip --version) || " +
"apt-get install -y {python_single_digit}-pip",
"declare LOCAL_PYTHON",
"for i in {{10..5}}; do which {python_single_digit}.$i && " +
"{python_single_digit}.$i -m pip --version && " +
"export LOCAL_PYTHON=$(which {python_single_digit}.$i) && break ; done",
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y {python_single_digit}-pip",
]
if preprocess_bash_script:
bash_script = preprocess_bash_script + bash_script
docker_bash_script = " ; ".join(bash_script) if not isinstance(bash_script, str) else bash_script
# make sure that if we do not have $LOCAL_PYTHON defined
# we set it to python3
update_scheme += (
docker_bash_script + " ; " +
"{python} -m pip install -U \"pip{pip_version}\" ; " +
"{python} -m pip install -U {trains_agent_wheel} ; ").format(
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON={python} ; " +
"$LOCAL_PYTHON -m pip install -U \"pip{pip_version}\" ; " +
"$LOCAL_PYTHON -m pip install -U {trains_agent_wheel} ; ").format(
python_single_digit=python_version.split('.')[0],
python=python_version, pip_version=PackageManager.get_pip_version(),
trains_agent_wheel=trains_agent_wheel)
@@ -2235,7 +2442,7 @@ class Worker(ServiceCommandSection):
update_scheme +
extra_shell_script +
"cp {} {} ; ".format(DOCKER_ROOT_CONF_FILE, DOCKER_DEFAULT_CONF_FILE) +
"NVIDIA_VISIBLE_DEVICES={nv_visible} {python} -u -m trains_agent ".format(
"NVIDIA_VISIBLE_DEVICES={nv_visible} $LOCAL_PYTHON -u -m trains_agent ".format(
nv_visible=dockers_nvidia_visible_devices, python=python_version)
])
@@ -2266,7 +2473,8 @@ class Worker(ServiceCommandSection):
os.setuid(self.uid)
# create a home folder for our user
trains_agent_home = self._run_as_user_home + '{}'.format('.'+str(Singleton.get_slot()) if Singleton.get_slot() else '')
trains_agent_home = self._run_as_user_home + '{}'.format(
'.'+str(Singleton.get_slot()) if Singleton.get_slot() else '')
try:
home_folder = self._run_as_user_home
rm_tree(home_folder)
@@ -2322,18 +2530,26 @@ class Worker(ServiceCommandSection):
return command, script_dir
def _kill_daemon(self):
worker_id, worker_name = self._generate_worker_id_name()
# Iterate over all running process
for pid, uid, slot, file in sorted(Singleton.get_running_pids(), key=lambda x: x[1] or ''):
# wither we have a match for the worker_id or we just pick the first one
if pid >= 0 and uid is not None and (
(worker_id and uid == worker_id) or
(not worker_id and uid.startswith('{}:'.format(worker_name)))):
# this is us kill it
print('Terminating trains-agent worker_id={} pid={}'.format(uid, pid))
if not terminate_process(pid, timeout=10):
error('Could not terminate process pid={}'.format(pid))
return True
print('Could not find a running trains-agent instance with worker_name={} worker_id={}'.format(
worker_name, worker_id))
return False
def _singleton(self):
# ensure singleton
worker_id = self._session.config["agent.worker_id"]
worker_name = self._session.config["agent.worker_name"]
if not worker_id and os.environ.get('NVIDIA_VISIBLE_DEVICES') is not None:
nvidia_visible_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES')
if nvidia_visible_devices and nvidia_visible_devices.lower() != 'none':
worker_id = '{}:gpu{}'.format(worker_name, nvidia_visible_devices)
elif nvidia_visible_devices == '':
pass
else:
worker_name = '{}:cpu'.format(worker_name)
worker_id, worker_name = self._generate_worker_id_name()
# if we are running in services mode, we allow double register since
# docker-compose will kill instances before they cleanup
@@ -2348,6 +2564,19 @@ class Worker(ServiceCommandSection):
# update folders based on free slot
self._session.create_cache_folders(slot_index=worker_slot)
def _generate_worker_id_name(self):
worker_id = self._session.config["agent.worker_id"]
worker_name = self._session.config["agent.worker_name"]
if not worker_id and os.environ.get('NVIDIA_VISIBLE_DEVICES') is not None:
nvidia_visible_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES')
if nvidia_visible_devices and nvidia_visible_devices.lower() != 'none':
worker_id = '{}:gpu{}'.format(worker_name, nvidia_visible_devices)
elif nvidia_visible_devices == '':
pass
else:
worker_name = '{}:cpu'.format(worker_name)
return worker_id, worker_name
def _resolve_queue_names(self, queues, create_if_missing=False):
if not queues:
default_queue = self._session.send_api(queues_api.GetDefaultRequest())

View File

@@ -122,6 +122,7 @@ PIP_EXTRA_INDICES = [
DEFAULT_PIP_DOWNLOAD_CACHE = normalize_path(CONFIG_DIR, "pip-download-cache")
ENV_AGENT_GIT_USER = EnvironmentConfig('TRAINS_AGENT_GIT_USER')
ENV_AGENT_GIT_PASS = EnvironmentConfig('TRAINS_AGENT_GIT_PASS')
ENV_AGENT_GIT_HOST = EnvironmentConfig('TRAINS_AGENT_GIT_HOST')
ENV_TASK_EXECUTE_AS_USER = 'TRAINS_AGENT_EXEC_USER'
ENV_TASK_EXTRA_PYTHON_PATH = 'TRAINS_AGENT_EXTRA_PYTHON_PATH'
ENV_DOCKER_HOST_MOUNT = EnvironmentConfig('TRAINS_AGENT_K8S_HOST_MOUNT', 'TRAINS_AGENT_DOCKER_HOST_MOUNT')

0
trains_agent/external/__init__.py vendored Normal file
View File

View File

@@ -0,0 +1,22 @@
from .parser import parse # noqa
_MAJOR = 0
_MINOR = 2
_PATCH = 0
def version_tuple():
'''
Returns a 3-tuple of ints that represent the version
'''
return (_MAJOR, _MINOR, _PATCH)
def version():
'''
Returns a string representation of the version
'''
return '%d.%d.%d' % (version_tuple())
__version__ = version()

View File

@@ -0,0 +1,44 @@
import re
# Copied from pip
# https://github.com/pypa/pip/blob/281eb61b09d87765d7c2b92f6982b3fe76ccb0af/pip/index.py#L947
HASH_ALGORITHMS = set(['sha1', 'sha224', 'sha384', 'sha256', 'sha512', 'md5'])
extras_require_search = re.compile(
r'(?P<name>.+)\[(?P<extras>[^\]]+)\]').search
def parse_fragment(fragment_string):
"""Takes a fragment string nd returns a dict of the components"""
fragment_string = fragment_string.lstrip('#')
try:
return dict(
key_value_string.split('=')
for key_value_string in fragment_string.split('&')
)
except ValueError:
raise ValueError(
'Invalid fragment string {fragment_string}'.format(
fragment_string=fragment_string
)
)
def get_hash_info(d):
"""Returns the first matching hashlib name and value from a dict"""
for key in d.keys():
if key.lower() in HASH_ALGORITHMS:
return key, d[key]
return None, None
def parse_extras_require(egg):
if egg is not None:
match = extras_require_search(egg)
if match is not None:
name = match.group('name')
extras = match.group('extras')
return name, [extra.strip() for extra in extras.split(',')]
return egg, []

View File

@@ -0,0 +1,50 @@
import os
import warnings
from .requirement import Requirement
def parse(reqstr):
"""
Parse a requirements file into a list of Requirements
See: pip/req.py:parse_requirements()
:param reqstr: a string or file like object containing requirements
:returns: a *generator* of Requirement objects
"""
filename = getattr(reqstr, 'name', None)
try:
# Python 2.x compatibility
if not isinstance(reqstr, basestring):
reqstr = reqstr.read()
except NameError:
# Python 3.x only
if not isinstance(reqstr, str):
reqstr = reqstr.read()
for line in reqstr.splitlines():
line = line.strip()
if line == '':
continue
elif not line or line.startswith('#'):
# comments are lines that start with # only
continue
elif line.startswith('-r') or line.startswith('--requirement'):
_, new_filename = line.split()
new_file_path = os.path.join(os.path.dirname(filename or '.'),
new_filename)
with open(new_file_path) as f:
for requirement in parse(f):
yield requirement
elif line.startswith('-f') or line.startswith('--find-links') or \
line.startswith('-i') or line.startswith('--index-url') or \
line.startswith('--extra-index-url') or \
line.startswith('--no-index'):
warnings.warn('Private repos not supported. Skipping.')
continue
elif line.startswith('-Z') or line.startswith('--always-unzip'):
warnings.warn('Unused option --always-unzip. Skipping.')
continue
else:
yield Requirement.parse(line)

View File

@@ -0,0 +1,241 @@
from __future__ import unicode_literals
import re
from pkg_resources import Requirement as Req
from .fragment import get_hash_info, parse_fragment, parse_extras_require
from .vcs import VCS, VCS_SCHEMES
URI_REGEX = re.compile(
r'^(?P<scheme>https?|file|ftps?)://(?P<path>[^#]+)'
r'(#(?P<fragment>\S+))?'
)
VCS_REGEX = re.compile(
r'^(?P<scheme>{0})://'.format(r'|'.join(
[scheme.replace('+', r'\+') for scheme in VCS_SCHEMES])) +
r'((?P<login>[^/@]+)@)?'
r'(?P<path>[^#@]+)'
r'(@(?P<revision>[^#]+))?'
r'(#(?P<fragment>\S+))?'
)
# This matches just about everyting
LOCAL_REGEX = re.compile(
r'^((?P<scheme>file)://)?'
r'(?P<path>[^#]+)' +
r'(#(?P<fragment>\S+))?'
)
class Requirement(object):
"""
Represents a single requirementfrom trains_agent.external.requirements_parser.requirement import Requirement
Typically instances of this class are created with ``Requirement.parse``.
For local file requirements, there's no verification that the file
exists. This class attempts to be *dict-like*.
See: http://www.pip-installer.org/en/latest/logic.html
**Members**:
* ``line`` - the actual requirement line being parsed
* ``editable`` - a boolean whether this requirement is "editable"
* ``local_file`` - a boolean whether this requirement is a local file/path
* ``specifier`` - a boolean whether this requirement used a requirement
specifier (eg. "django>=1.5" or "requirements")
* ``vcs`` - a string specifying the version control system
* ``revision`` - a version control system specifier
* ``name`` - the name of the requirement
* ``uri`` - the URI if this requirement was specified by URI
* ``subdirectory`` - the subdirectory fragment of the URI
* ``path`` - the local path to the requirement
* ``hash_name`` - the type of hashing algorithm indicated in the line
* ``hash`` - the hash value indicated by the requirement line
* ``extras`` - a list of extras for this requirement
(eg. "mymodule[extra1, extra2]")
* ``specs`` - a list of specs for this requirement
(eg. "mymodule>1.5,<1.6" => [('>', '1.5'), ('<', '1.6')])
"""
def __init__(self, line):
# Do not call this private method
self.line = line
self.editable = False
self.local_file = False
self.specifier = False
self.vcs = None
self.name = None
self.subdirectory = None
self.uri = None
self.path = None
self.revision = None
self.hash_name = None
self.hash = None
self.extras = []
self.specs = []
def __repr__(self):
return '<Requirement: "{0}">'.format(self.line)
def __getitem__(self, key):
return getattr(self, key)
def keys(self):
return self.__dict__.keys()
@classmethod
def parse_editable(cls, line):
"""
Parses a Requirement from an "editable" requirement which is either
a local project path or a VCS project URI.
See: pip/req.py:from_editable()
:param line: an "editable" requirement
:returns: a Requirement instance for the given line
:raises: ValueError on an invalid requirement
"""
req = cls('-e {0}'.format(line))
req.editable = True
vcs_match = VCS_REGEX.match(line)
local_match = LOCAL_REGEX.match(line)
if vcs_match is not None:
groups = vcs_match.groupdict()
if groups.get('login'):
req.uri = '{scheme}://{login}@{path}'.format(**groups)
else:
req.uri = '{scheme}://{path}'.format(**groups)
req.revision = groups['revision']
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
for vcs in VCS:
if req.uri.startswith(vcs):
req.vcs = vcs
else:
assert local_match is not None, 'This should match everything'
groups = local_match.groupdict()
req.local_file = True
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
req.path = groups['path']
return req
@classmethod
def parse_line(cls, line):
"""
Parses a Requirement from a non-editable requirement.
See: pip/req.py:from_line()
:param line: a "non-editable" requirement
:returns: a Requirement instance for the given line
:raises: ValueError on an invalid requirement
"""
req = cls(line)
vcs_match = VCS_REGEX.match(line)
uri_match = URI_REGEX.match(line)
local_match = LOCAL_REGEX.match(line)
if vcs_match is not None:
groups = vcs_match.groupdict()
if groups.get('login'):
req.uri = '{scheme}://{login}@{path}'.format(**groups)
else:
req.uri = '{scheme}://{path}'.format(**groups)
req.revision = groups['revision']
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
for vcs in VCS:
if req.uri.startswith(vcs):
req.vcs = vcs
elif uri_match is not None:
groups = uri_match.groupdict()
req.uri = '{scheme}://{path}'.format(**groups)
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
if groups['scheme'] == 'file':
req.local_file = True
elif '#egg=' in line:
# Assume a local file match
assert local_match is not None, 'This should match everything'
groups = local_match.groupdict()
req.local_file = True
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
name, extras = parse_extras_require(egg)
req.name = fragment.get('egg')
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
req.path = groups['path']
else:
# This is a requirement specifier.
# Delegate to pkg_resources and hope for the best
req.specifier = True
pkg_req = Req.parse(line)
req.name = pkg_req.unsafe_name
req.extras = list(pkg_req.extras)
req.specs = pkg_req.specs
return req
@classmethod
def parse(cls, line):
"""
Parses a Requirement from a line of a requirement file.
:param line: a line of a requirement file
:returns: a Requirement instance for the given line
:raises: ValueError on an invalid requirement
"""
line = line.lstrip()
if line.startswith('-e') or line.startswith('--editable'):
# Editable installs are either a local project path
# or a VCS project URI
return cls.parse_editable(
re.sub(r'^(-e|--editable=?)\s*', '', line))
elif '@' in line and ('#' not in line or line.index('#') > line.index('@')):
# Allegro bug fix: support 'name @ git+' entries
name, uri = line.split('@', 1)
name = name.strip()
uri = uri.strip()
# noinspection PyBroadException
try:
# check if the name is valid & parsed
Req.parse(name)
# if we are here, name is a valid package name, check if the vcs part is valid
if VCS_REGEX.match(uri):
req = cls.parse_line(uri)
req.name = name
return req
elif URI_REGEX.match(uri):
req = cls.parse_line(uri)
req.name = name
req.line = line
return req
except Exception:
pass
return cls.parse_line(line)

View File

@@ -0,0 +1,30 @@
from __future__ import unicode_literals
VCS = [
'git',
'hg',
'svn',
'bzr',
]
VCS_SCHEMES = [
'git',
'git+https',
'git+ssh',
'git+git',
'hg+http',
'hg+https',
'hg+static-http',
'hg+ssh',
'svn',
'svn+svn',
'svn+http',
'svn+https',
'svn+ssh',
'bzr+http',
'bzr+https',
'bzr+ssh',
'bzr+sftp',
'bzr+ftp',
'bzr+lp',
]

View File

@@ -1,45 +1,115 @@
from __future__ import print_function, division, unicode_literals
import base64
import logging
import os
import re
import subprocess
import tempfile
from copy import deepcopy
import yaml
import json
from time import sleep
from typing import Text, List
from pyhocon import HOCONConverter
from trains_agent.commands.events import Events
from trains_agent.commands.worker import Worker
from trains_agent.errors import APIError
from trains_agent.helper.base import safe_remove_file
from trains_agent.helper.dicts import merge_dicts
from trains_agent.helper.process import get_bash_output
from trains_agent.helper.resource_monitor import ResourceMonitor
from trains_agent.interface.base import ObjectID
class K8sIntegration(Worker):
K8S_PENDING_QUEUE = "k8s_scheduler"
KUBECTL_RUN_CMD = "kubectl run trains_id_{task_id} " \
KUBECTL_APPLY_CMD = "kubectl apply -f"
KUBECTL_RUN_CMD = "kubectl run trains-{queue_name}-id-{task_id} " \
"--image {docker_image} " \
"--restart=Never --replicas=1 " \
"--generator=run-pod/v1"
"--generator=run-pod/v1 " \
"--namespace=trains"
KUBECTL_DELETE_CMD = "kubectl delete pods " \
"--selector=TRAINS=agent " \
"--field-selector=status.phase!=Pending,status.phase!=Running"
"--field-selector=status.phase!=Pending,status.phase!=Running " \
"--namespace=trains"
CONTAINER_BASH_SCRIPT = "apt-get install -y git python-pip && " \
"pip install trains-agent && " \
"python -u -m trains_agent execute --full-monitoring --require-queue --id {}"
BASH_INSTALL_SSH_CMD = [
"apt-get install -y openssh-server",
"mkdir -p /var/run/sshd",
"echo 'root:training' | chpasswd",
"echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
"sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
r"sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd",
"echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY' >> /etc/ssh/sshd_config",
'echo "export VISIBLE=now" >> /etc/profile',
'echo "export PATH=$PATH" >> /etc/profile',
'echo "ldconfig" >> /etc/profile',
"/usr/sbin/sshd -p {port}"]
def __init__(self, k8s_pending_queue_name=None, kubectl_cmd=None, container_bash_script=None, debug=False):
CONTAINER_BASH_SCRIPT = [
"export DEBIAN_FRONTEND='noninteractive'",
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
"chown -R root /root/.cache/pip",
"apt-get update",
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
"declare LOCAL_PYTHON",
"for i in {{10..5}}; do which python3.$i && python3.$i -m pip --version && "
"export LOCAL_PYTHON=$(which python3.$i) && break ; done",
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip",
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3",
"$LOCAL_PYTHON -m pip install trains-agent",
"{extra_bash_init_cmd}",
"$LOCAL_PYTHON -m trains_agent execute --full-monitoring --require-queue --id {task_id}"
]
AGENT_LABEL = "TRAINS=agent"
LIMIT_POD_LABEL = "ai.allegro.agent.serial=pod-{pod_number}"
_edit_hyperparams_version = "2.9"
def __init__(
self,
k8s_pending_queue_name=None,
kubectl_cmd=None,
container_bash_script=None,
debug=False,
ports_mode=False,
num_of_services=20,
user_props_cb=None,
overrides_yaml=None,
template_yaml=None,
trains_conf_file=None,
extra_bash_init_script=None,
):
"""
Initialize the k8s integration glue layer daemon
:param str k8s_pending_queue_name: queue name to use when task is pending in the k8s scheduler
:param str|callable kubectl_cmd: kubectl command line str, supports formating (default: KUBECTL_RUN_CMD)
:param str|callable kubectl_cmd: kubectl command line str, supports formatting (default: KUBECTL_RUN_CMD)
example: "task={task_id} image={docker_image} queue_id={queue_id}"
or a callable function: kubectl_cmd(task_id, docker_image, queue_id, task_data)
:param str container_bash_script: container bash script to be executed in k8s (default: CONTAINER_BASH_SCRIPT)
Notice this string will use format() call, if you have curly brackets they should be doubled { -> {{
Format arguments passed: {task_id} and {extra_bash_init_cmd}
:param bool debug: Switch logging on
:param bool ports_mode: Adds a label to each pod which can be used in services in order to expose ports.
Requires the `num_of_services` parameter.
:param int num_of_services: Number of k8s services configured in the cluster. Required if `port_mode` is True.
(default: 20)
:param callable user_props_cb: An Optional callable allowing additional user properties to be specified
when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
:param str overrides_yaml: YAML file containing the overrides for the pod (optional)
:param str template_yaml: YAML file containing the template for the pod (optional).
If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
:param str trains_conf_file: trains.conf file to be use by the pod itself (optional)
:param str extra_bash_init_script: Additional bash script to run before starting the Task inside the container
"""
super(K8sIntegration, self).__init__()
self.k8s_pending_queue_name = k8s_pending_queue_name or self.K8S_PENDING_QUEUE
@@ -51,12 +121,88 @@ class K8sIntegration(Worker):
if debug:
self.log.logger.disabled = False
self.log.logger.setLevel(logging.INFO)
self.ports_mode = ports_mode
self.num_of_services = num_of_services
self._edit_hyperparams_support = None
self._user_props_cb = user_props_cb
self.trains_conf_file = None
self.overrides_json_string = None
self.template_dict = None
self.extra_bash_init_script = extra_bash_init_script or None
if self.extra_bash_init_script and not isinstance(self.extra_bash_init_script, str):
self.extra_bash_init_script = ' ; '.join(self.extra_bash_init_script) # noqa
self.pod_limits = []
self.pod_requests = []
if overrides_yaml:
with open(os.path.expandvars(os.path.expanduser(str(overrides_yaml))), 'rt') as f:
overrides = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))
if overrides:
containers = overrides.get('spec', {}).get('containers', [])
for c in containers:
resources = {str(k).lower(): v for k, v in c.get('resources', {}).items()}
if not resources:
continue
if resources.get('limits'):
self.pod_limits += ['{}={}'.format(k, v) for k, v in resources['limits'].items()]
if resources.get('requests'):
self.pod_requests += ['{}={}'.format(k, v) for k, v in resources['requests'].items()]
# remove double entries
self.pod_limits = list(set(self.pod_limits))
self.pod_requests = list(set(self.pod_requests))
if self.pod_limits or self.pod_requests:
self.log.warning('Found pod container requests={} limits={}'.format(
self.pod_limits, self.pod_requests))
if containers:
self.log.warning('Removing containers section: {}'.format(overrides['spec'].pop('containers')))
self.overrides_json_string = json.dumps(overrides)
if template_yaml:
with open(os.path.expandvars(os.path.expanduser(str(template_yaml))), 'rt') as f:
self.template_dict = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))
def run_one_task(self, queue: Text, task_id: Text, worker_args=None):
if trains_conf_file:
with open(os.path.expandvars(os.path.expanduser(str(trains_conf_file))), 'rt') as f:
self.trains_conf_file = f.read()
# make sure we use system packages!
self.trains_conf_file += '\nagent.package_manager.system_site_packages=true\n'
def _set_task_user_properties(self, task_id: str, **properties: str):
if self._edit_hyperparams_support is not True:
# either not supported or never tested
if self._edit_hyperparams_support == self._session.api_version:
# tested against latest api_version, not supported
return
if not self._session.check_min_api_version(self._edit_hyperparams_version):
# not supported due to insufficient api_version
self._edit_hyperparams_support = self._session.api_version
return
try:
self._session.get(
service="tasks",
action="edit_hyper_params",
task=task_id,
hyperparams=[
{
"section": "properties",
"name": k,
"value": str(v),
}
for k, v in properties.items()
],
)
# definitely supported
self._runtime_props_support = True
except APIError as error:
if error.code == 404:
self._edit_hyperparams_support = self._session.api_version
def run_one_task(self, queue: Text, task_id: Text, worker_args=None, **_):
print('Pulling task {} launching on kubernetes cluster'.format(task_id))
task_data = self._session.api_client.tasks.get_all(id=[task_id])[0]
# push task into the k8s queue, so we have visibility on pending tasks in the k8s scheduler
try:
print('Pushing task {} into temporary pending queue'.format(task_id))
self._session.api_client.tasks.reset(task_id)
self._session.api_client.tasks.enqueue(task_id, queue=self.k8s_pending_queue_name,
status_reason='k8s pending scheduler')
except Exception as e:
@@ -65,36 +211,231 @@ class K8sIntegration(Worker):
return
if task_data.execution.docker_cmd:
docker_image = task_data.execution.docker_cmd
docker_parts = task_data.execution.docker_cmd
else:
docker_image = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
docker_parts = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
self._session.config.get("agent.default_docker.image", "nvidia/cuda"))
# take the first part, this is the docker image name (not arguments)
docker_image = docker_image.split()[0]
docker_parts = docker_parts.split()
docker_image = docker_parts[0]
docker_args = docker_parts[1:] if len(docker_parts) > 1 else []
create_trains_conf = "echo '{}' >> ~/trains.conf && ".format(
HOCONConverter.to_hocon(self._session.config._config))
# get the trains.conf encoded file
# noinspection PyProtectedMember
hocon_config_encoded = (self.trains_conf_file or self._session._config_file).encode('ascii')
create_trains_conf = "echo '{}' | base64 --decode >> ~/trains.conf".format(
base64.b64encode(
hocon_config_encoded
).decode('ascii')
)
if callable(self.kubectl_cmd):
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data)
if self.ports_mode:
print("Kubernetes looking for available pod to use")
# noinspection PyBroadException
try:
queue_name = self._session.api_client.queues.get_by_id(queue=queue).name
except Exception:
queue_name = 'k8s'
# conform queue name to k8s standards
safe_queue_name = queue_name.lower().strip()
safe_queue_name = re.sub(r'\W+', '', safe_queue_name).replace('_', '').replace('-', '')
# Search for a free pod number
pod_number = 1
while self.ports_mode:
kubectl_cmd_new = "kubectl get pods -l {pod_label},{agent_label} -n trains".format(
pod_label=self.LIMIT_POD_LABEL.format(pod_number=pod_number),
agent_label=self.AGENT_LABEL
)
process = subprocess.Popen(kubectl_cmd_new.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
output = '' if not output else output if isinstance(output, str) else output.decode('utf-8')
error = '' if not error else error if isinstance(error, str) else error.decode('utf-8')
if not output:
# No such pod exist so we can use the pod_number we found
break
if pod_number >= self.num_of_services:
# All pod numbers are taken, exit
self.log.warning(
"kubectl last result: {}\n{}\nAll k8s services are in use, task '{}' "
"will be enqueued back to queue '{}'".format(
error, output, task_id, queue
)
)
self._session.api_client.tasks.reset(task_id)
self._session.api_client.tasks.enqueue(
task_id, queue=queue, status_reason='k8s max pod limit (no free k8s service)')
return
pod_number += 1
labels = ([self.LIMIT_POD_LABEL.format(pod_number=pod_number)] if self.ports_mode else []) + [self.AGENT_LABEL]
if self.ports_mode:
print("Kubernetes scheduling task id={} on pod={}".format(task_id, pod_number))
else:
kubectl_cmd = self.kubectl_cmd.format(task_id=task_id, docker_image=docker_image, queue_id=queue)
print("Kubernetes scheduling task id={}".format(task_id))
# make sure we gave a list
if self.template_dict:
output, error = self._kubectl_apply(
create_trains_conf=create_trains_conf,
labels=labels, docker_image=docker_image, docker_args=docker_args,
task_id=task_id, queue=queue, queue_name=safe_queue_name)
else:
output, error = self._kubectl_run(
create_trains_conf=create_trains_conf,
labels=labels, docker_image=docker_image,
task_data=task_data,
task_id=task_id, queue=queue, queue_name=safe_queue_name)
error = '' if not error else (error if isinstance(error, str) else error.decode('utf-8'))
output = '' if not output else (output if isinstance(output, str) else output.decode('utf-8'))
print('kubectl output:\n{}\n{}'.format(error, output))
if error:
self.log.error("Running kubectl encountered an error: {}".format(error))
user_props = {"k8s-queue": str(queue_name)}
if self.ports_mode:
user_props.update({"k8s-pod-number": pod_number, "k8s-pod-label": labels[0]})
if self._user_props_cb:
# noinspection PyBroadException
try:
custom_props = self._user_props_cb(pod_number) if self.ports_mode else self._user_props_cb()
user_props.update(custom_props)
except Exception:
pass
if user_props:
self._set_task_user_properties(
task_id=task_id,
**user_props
)
def _parse_docker_args(self, docker_args):
# type: (list) -> dict
kube_args = {'env': []}
while docker_args:
cmd = docker_args.pop().strip()
if cmd in ('-e', '--env',):
env = docker_args.pop().strip()
key, value = env.split('=', 1)
kube_args[key] += {key: value}
else:
self.log.warning('skipping docker argument {} (only -e --env supported)'.format(cmd))
return kube_args
def _kubectl_apply(self, create_trains_conf, docker_image, docker_args, labels, queue, task_id, queue_name):
template = deepcopy(self.template_dict)
template.setdefault('apiVersion', 'v1')
template['kind'] = 'Pod'
template.setdefault('metadata', {})
name = 'trains-{queue}-id-{task_id}'.format(queue=queue_name, task_id=task_id)
template['metadata']['name'] = name
template.setdefault('spec', {})
template['spec'].setdefault('containers', [])
if labels:
labels_dict = dict(pair.split('=', 1) for pair in labels)
template['metadata'].setdefault('labels', {})
template['metadata']['labels'].update(labels_dict)
container = self._parse_docker_args(docker_args)
container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
else self.container_bash_script
script_encoded = '\n'.join(
['#!/bin/bash', ] +
[line.format(extra_bash_init_cmd=self.extra_bash_init_script or '', task_id=task_id)
for line in container_bash_script])
create_init_script = \
"echo '{}' | base64 --decode >> ~/__start_agent__.sh ; " \
"/bin/bash ~/__start_agent__.sh".format(
base64.b64encode(
script_encoded.encode('ascii')
).decode('ascii'))
container = merge_dicts(
container,
dict(name=name, image=docker_image,
command=['/bin/bash'],
args=['-c', '{} ; {}'.format(create_trains_conf, create_init_script)])
)
if template['spec']['containers']:
template['spec']['containers'][0] = merge_dicts(template['spec']['containers'][0], container)
else:
template['spec']['containers'].append(container)
fp, yaml_file = tempfile.mkstemp(prefix='trains_k8stmpl_', suffix='.yml')
os.close(fp)
with open(yaml_file, 'wt') as f:
yaml.dump(template, f)
kubectl_cmd = self.KUBECTL_APPLY_CMD.format(
task_id=task_id,
docker_image=docker_image,
queue_id=queue,
)
# make sure we provide a list
if isinstance(kubectl_cmd, str):
kubectl_cmd = kubectl_cmd.split()
kubectl_cmd += ["--labels=TRAINS=agent", "--command", "--", "/bin/sh", "-c",
create_trains_conf + self.container_bash_script.format(task_id)]
# add the template file at the end
kubectl_cmd += [yaml_file]
try:
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
except Exception as ex:
return None, str(ex)
finally:
safe_remove_file(yaml_file)
return output, error
def _kubectl_run(self, create_trains_conf, docker_image, labels, queue, task_data, task_id, queue_name):
if callable(self.kubectl_cmd):
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data, queue_name)
else:
kubectl_cmd = self.kubectl_cmd.format(
queue_name=queue_name,
task_id=task_id,
docker_image=docker_image,
queue_id=queue
)
# make sure we provide a list
if isinstance(kubectl_cmd, str):
kubectl_cmd = kubectl_cmd.split()
if self.overrides_json_string:
kubectl_cmd += ['--overrides=' + self.overrides_json_string]
if self.pod_limits:
kubectl_cmd += ['--limits', ",".join(self.pod_limits)]
if self.pod_requests:
kubectl_cmd += ['--requests', ",".join(self.pod_requests)]
container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
else self.container_bash_script
container_bash_script = ' ; '.join(container_bash_script)
kubectl_cmd += [
"--labels=" + ",".join(labels),
"--command",
"--",
"/bin/sh",
"-c",
"{} ; {}".format(create_trains_conf, container_bash_script.format(
extra_bash_init_cmd=self.extra_bash_init_script, task_id=task_id)),
]
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
self.log.info("K8s scheduling experiment task id={}".format(task_id))
if error:
self.log.error("Running kubectl encountered an error: {}".format(
error if isinstance(error, str) else error.decode()))
return output, error
def run_tasks_loop(self, queues: List[Text], worker_params):
def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):
"""
:summary: Pull and run tasks from queues.
:description: 1. Go through ``queues`` by order.
@@ -108,6 +449,7 @@ class K8sIntegration(Worker):
events_service = self.get_service(Events)
# make sure we have a k8s pending queue
# noinspection PyBroadException
try:
self._session.api_client.queues.create(self.k8s_pending_queue_name)
except Exception:
@@ -119,7 +461,7 @@ class K8sIntegration(Worker):
while True:
# iterate over queues (priority style, queues[0] is highest)
for queue in queues:
# delete old completed /failed pods
# delete old completed / failed pods
get_bash_output(self.KUBECTL_DELETE_CMD)
# get next task in queue
@@ -155,15 +497,20 @@ class K8sIntegration(Worker):
if self._session.config["agent.reload_config"]:
self.reload_config()
def k8s_daemon(self, queues):
def k8s_daemon(self, queue):
"""
Start the k8s Glue service.
This service will be pulling tasks from *queues* and scheduling them for execution using kubectl.
This service will be pulling tasks from *queue* and scheduling them for execution using kubectl.
Notice all scheduled tasks are pushed back into K8S_PENDING_QUEUE,
and popped when execution actually starts. This creates full visibility into the k8s scheduler.
Manually popping a task from the K8S_PENDING_QUEUE,
will cause the k8s scheduler to skip the execution once the scheduled tasks needs to be executed
:param list(str) queues: List of queue names to pull from
:param list(str) queue: queue name to pull from
"""
return self.daemon(queues=queues, log_level=logging.INFO, foreground=True, docker=False)
return self.daemon(queues=[ObjectID(name=queue)] if queue else None,
log_level=logging.INFO, foreground=True, docker=False)
@classmethod
def get_ssh_server_bash(cls, ssh_port_number):
return ' ; '.join(line.format(port=ssh_port_number) for line in cls.BASH_INSTALL_SSH_CMD)

View File

@@ -173,14 +173,32 @@ def normalize_path(*paths):
def safe_remove_file(filename, error_message=None):
# noinspection PyBroadException
try:
os.remove(filename)
if filename:
os.remove(filename)
except Exception:
if error_message:
print(error_message)
def get_python_path(script_dir, entry_point, package_api):
def safe_remove_tree(filename):
if not filename:
return
# noinspection PyBroadException
try:
shutil.rmtree(filename, ignore_errors=True)
except Exception:
pass
# noinspection PyBroadException
try:
os.remove(filename)
except Exception:
pass
def get_python_path(script_dir, entry_point, package_api, is_conda_env=False):
# noinspection PyBroadException
try:
python_path_sep = ';' if is_windows_platform() else ':'
python_path_cmd = package_api.get_python_command(
@@ -192,9 +210,9 @@ def get_python_path(script_dir, entry_point, package_api):
(Path(script_dir) / Path(entry_point)).parent.absolute().as_posix(),
python_path_sep=python_path_sep)
if is_windows_platform():
return python_path.replace('/', '\\') + org_python_path
python_path = python_path.replace('/', '\\')
return python_path + org_python_path
return python_path if is_conda_env else (python_path + org_python_path)
except Exception:
return None
@@ -442,9 +460,9 @@ def chain_map(*args):
return reduce(lambda x, y: x.update(y) or x, args, {})
def check_directory_path(path):
def check_directory_path(path, check_whitespace_in_path=True):
message = 'Could not create directory "{}": {}'
if not is_windows_platform():
if not is_windows_platform() and check_whitespace_in_path:
match = re.search(r'\s', path)
if match:
raise CommandFailedError(
@@ -537,6 +555,7 @@ class ExecutionInfo(NonStrictAttrs):
branch = nullable_string
version_num = nullable_string
tag = nullable_string
docker_cmd = nullable_string
@classmethod
def from_task(cls, task_info):
@@ -554,6 +573,12 @@ class ExecutionInfo(NonStrictAttrs):
execution.entry_point = entry_point
execution.working_dir = working_dir or ""
# noinspection PyBroadException
try:
execution.docker_cmd = task_info.execution.docker_cmd
except Exception:
pass
return execution

View File

@@ -9,7 +9,7 @@ from attr import attrs, attrib
import six
from six import binary_type, text_type
from trains_agent.helper.base import nonstrict_in_place_sort, create_tree
from trains_agent.helper.base import nonstrict_in_place_sort
def print_text(text, newline=True):
@@ -22,15 +22,21 @@ def print_text(text, newline=True):
sys.stdout.write(data)
def decode_binary_lines(binary_lines, encoding='utf-8'):
def decode_binary_lines(binary_lines, encoding='utf-8', replace_cr=False, overwrite_cr=False):
# decode per line, if we failed decoding skip the line
lines = []
for b in binary_lines:
# noinspection PyBroadException
try:
l = b.decode(encoding=encoding, errors='replace').replace('\r', '\n')
except:
l = ''
lines.append(l + '\n' if l and l[-1] != '\n' else l)
line = b.decode(encoding=encoding, errors='replace')
if replace_cr:
line = line.replace('\r', '\n')
elif overwrite_cr:
cr_lines = line.split('\r')
line = cr_lines[-1] if cr_lines[-1] or len(cr_lines) < 2 else cr_lines[-2]
except Exception:
line = ''
lines.append(line + '\n' if not line or line[-1] != '\n' else line)
return lines

View File

@@ -3,3 +3,15 @@ from typing import Callable, Dict, Any
def filter_keys(filter_, dct): # type: (Callable[[Any], bool], Dict) -> Dict
return {key: value for key, value in dct.items() if filter_(key)}
def merge_dicts(dict1, dict2):
""" Recursively merges dict2 into dict1 """
if not isinstance(dict1, dict) or not isinstance(dict2, dict):
return dict2
for k in dict2:
if k in dict1:
dict1[k] = merge_dicts(dict1[k], dict2[k])
else:
dict1[k] = dict2[k]
return dict1

View File

@@ -200,24 +200,30 @@ class GPUStatCollection(object):
GPUStatCollection.global_processes[nv_process.pid] = \
psutil.Process(pid=nv_process.pid)
ps_process = GPUStatCollection.global_processes[nv_process.pid]
process['username'] = ps_process.username()
# cmdline returns full path;
# as in `ps -o comm`, get short cmdnames.
_cmdline = ps_process.cmdline()
if not _cmdline:
# sometimes, zombie or unknown (e.g. [kworker/8:2H])
process['command'] = '?'
process['full_command'] = ['?']
else:
process['command'] = os.path.basename(_cmdline[0])
process['full_command'] = _cmdline
# Bytes to MBytes
process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
process['cpu_percent'] = ps_process.cpu_percent()
process['cpu_memory_usage'] = \
round((ps_process.memory_percent() / 100.0) *
psutil.virtual_memory().total)
process['pid'] = nv_process.pid
# noinspection PyBroadException
try:
# we do not actually use these, so no point in collecting them
# process['username'] = ps_process.username()
# # cmdline returns full path;
# # as in `ps -o comm`, get short cmdnames.
# _cmdline = ps_process.cmdline()
# if not _cmdline:
# # sometimes, zombie or unknown (e.g. [kworker/8:2H])
# process['command'] = '?'
# process['full_command'] = ['?']
# else:
# process['command'] = os.path.basename(_cmdline[0])
# process['full_command'] = _cmdline
# process['cpu_percent'] = ps_process.cpu_percent()
# process['cpu_memory_usage'] = \
# round((ps_process.memory_percent() / 100.0) *
# psutil.virtual_memory().total)
# Bytes to MBytes
process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
except Exception:
# insufficient permissions
pass
return process
if not GPUStatCollection._gpu_device_info.get(index):
@@ -285,12 +291,13 @@ class GPUStatCollection(object):
# e.g. nvidia-smi reset or reboot the system
pass
# TODO: Do not block if full process info is not requested
time.sleep(0.1)
for process in processes:
pid = process['pid']
cache_process = GPUStatCollection.global_processes[pid]
process['cpu_percent'] = cache_process.cpu_percent()
# we do not actually use these, so no point in collecting them
# # TODO: Do not block if full process info is not requested
# time.sleep(0.1)
# for process in processes:
# pid = process['pid']
# cache_process = GPUStatCollection.global_processes[pid]
# process['cpu_percent'] = cache_process.cpu_percent()
index = N.nvmlDeviceGetIndex(handle)
gpu_info = {

View File

@@ -5,7 +5,7 @@ from contextlib import contextmanager
from typing import Text, Iterable, Union
import six
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines, select_for_platform
from trains_agent.helper.process import Executable, Argv, PathLike
@@ -66,7 +66,20 @@ class PackageManager(object):
pass
def upgrade_pip(self):
return self._install("pip"+self.get_pip_version(), "--upgrade")
result = self._install(
select_for_platform(windows='"pip{}"', linux='pip{}').format(self.get_pip_version()), "--upgrade")
packages = self.run_with_env(('list',), output=True).splitlines()
# p.split is ('pip', 'x.y.z')
pip = [p.split() for p in packages if len(p.split()) == 2 and p.split()[0] == 'pip']
if pip:
# noinspection PyBroadException
try:
from .requirements import MarkerRequirement
pip = pip[0][1].split('.')
MarkerRequirement.pip_new_version = bool(int(pip[0]) >= 20)
except Exception:
pass
return result
def get_python_command(self, extra=()):
# type: (...) -> Executable

View File

@@ -2,8 +2,9 @@ from __future__ import unicode_literals
import json
import re
import shutil
import os
import subprocess
from collections import OrderedDict
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
@@ -14,17 +15,18 @@ import yaml
from time import time
from attr import attrs, attrib, Factory
from pathlib2 import Path
from requirements import parse
from requirements.requirement import Requirement
from trains_agent.external.requirements_parser import parse
from trains_agent.external.requirements_parser.requirement import Requirement
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform, ExecutionInfo
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.helper.package.requirements import SimpleVersion
from trains_agent.session import Session
from .base import PackageManager
from .pip_api.venv import VirtualenvPip
from .requirements import RequirementsManager, MarkerRequirement
from ...backend_api.session.defs import ENV_CONDA_ENV_PACKAGE
package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")
@@ -40,8 +42,8 @@ def _package_diff(path, packages):
class CondaPip(VirtualenvPip):
def __init__(self, source=None, *args, **kwargs):
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe") \
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe")
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
self.source = source
def run_with_env(self, command, output=False, **kwargs):
@@ -61,8 +63,8 @@ class CondaAPI(PackageManager):
MINIMUM_VERSION = "4.3.30"
def __init__(self, session, path, python, requirements_manager):
# type: (Session, PathLike, float, RequirementsManager) -> None
def __init__(self, session, path, python, requirements_manager, execution_info=None, **kwargs):
# type: (Session, PathLike, float, RequirementsManager, ExecutionInfo, Any) -> None
"""
:param python: base python version to use (e.g python3.6)
:param path: path of env
@@ -72,7 +74,15 @@ class CondaAPI(PackageManager):
self.source = None
self.requirements_manager = requirements_manager
self.path = path
self.env_read_only = False
self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
self.conda_env_as_base_docker = \
self.session.config.get('agent.package_manager.conda_env_as_base_docker', None) or \
bool(ENV_CONDA_ENV_PACKAGE.get())
if ENV_CONDA_ENV_PACKAGE.get():
self.conda_pre_build_env_path = ENV_CONDA_ENV_PACKAGE.get()
else:
self.conda_pre_build_env_path = execution_info.docker_cmd if execution_info else None
self.pip = CondaPip(
session=self.session,
source=self.source,
@@ -80,10 +90,15 @@ class CondaAPI(PackageManager):
requirements_manager=self.requirements_manager,
path=self.path,
)
self.conda = (
find_executable("conda")
or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
)
try:
self.conda = (
find_executable("conda") or
Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(
shell=select_for_platform(windows=True, linux=False)).strip()
)
except Exception:
raise ValueError("ERROR: package manager \"conda\" selected, "
"but \'conda\' executable could not be located")
try:
output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as ex:
@@ -111,13 +126,58 @@ class CondaAPI(PackageManager):
def bin(self):
return self.pip.bin
# noinspection SpellCheckingInspection
def upgrade_pip(self):
return self._install("pip" + self.pip.get_pip_version())
# do not change pip version if pre built environement is used
if self.env_read_only:
print('Conda environment in read-only mode, skipping pip upgrade.')
return ''
return self._install(select_for_platform(windows='"pip{}"', linux='pip{}').format(self.pip.get_pip_version()))
def create(self):
"""
Create a new environment
"""
if self.conda_env_as_base_docker and self.conda_pre_build_env_path:
if Path(self.conda_pre_build_env_path).is_dir():
print("Using pre-existing Conda environment from {}".format(self.conda_pre_build_env_path))
self.path = Path(self.conda_pre_build_env_path)
self.source = ("conda", "activate", self.path.as_posix())
self.pip = CondaPip(
session=self.session,
source=self.source,
python=self.python,
requirements_manager=self.requirements_manager,
path=self.path,
)
conda_env = self._get_conda_sh()
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
self.env_read_only = True
return self
elif Path(self.conda_pre_build_env_path).is_file():
print("Restoring Conda environment from {}".format(self.conda_pre_build_env_path))
tar_path = find_executable("tar")
self.path.mkdir(parents=True, exist_ok=True)
output = Argv(
tar_path,
"-xzf",
self.conda_pre_build_env_path,
"-C",
self.path,
).get_output()
self.source = self.pip.source = ("conda", "activate", self.path.as_posix())
conda_env = self._get_conda_sh()
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
# unpack cleanup
print("Fixing prefix in Conda environment {}".format(self.path))
CommandSequence(('source', conda_env.as_posix()),
((self.path / 'bin' / 'conda-unpack').as_posix(), )).get_output()
return self
else:
raise ValueError("Could not restore Conda environment, cannot find {}".format(
self.conda_pre_build_env_path))
output = Argv(
self.conda,
"create",
@@ -133,13 +193,15 @@ class CondaAPI(PackageManager):
self.source = self.pip.source = (
tuple(match.group(1).split()) + (match.group(2),)
if match
else ("activate", self.path)
else ("conda", "activate", self.path.as_posix())
)
conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
conda_env = self._get_conda_sh()
if conda_env.is_file() and not is_windows_platform():
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
# install cuda toolkit
# noinspection PyBroadException
try:
cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
if cuda_version > 0:
@@ -181,6 +243,10 @@ class CondaAPI(PackageManager):
def _install(self, *args):
# type: (*PathLike) -> ()
# if we are in read only mode, do not install anything
if self.env_read_only:
print('Conda environment in read-only mode, skipping package installing: {}'.format(args))
return
channels_args = tuple(
chain.from_iterable(("-c", channel) for channel in self.extra_channels)
)
@@ -208,6 +274,10 @@ class CondaAPI(PackageManager):
return self._install(*packages)
def uninstall_packages(self, *packages):
# if we are in read only mode, do not uninstall anything
if self.env_read_only:
print('Conda environment in read-only mode, skipping package uninstalling: {}'.format(packages))
return ''
return self._run_command(("uninstall", "-p", self.path))
def install_from_file(self, path):
@@ -226,23 +296,158 @@ class CondaAPI(PackageManager):
with self.temp_file("pip_reqs", pip_packages) as reqs:
self.pip.install_from_file(reqs)
def freeze(self):
def freeze(self, freeze_full_environment=False):
requirements = self.pip.freeze()
req_lines = []
conda_lines = []
# noinspection PyBroadException
try:
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
conda_packages_txt = []
requirements_pip = [r.split('==')[0].strip().lower() for r in requirements['pip']]
for pkg in conda_packages:
# skip if this is a pypi package or it is not a python package at all
if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip:
pip_lines = requirements['pip']
conda_packages_json = json.loads(
self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
for r in conda_packages_json:
# check if this is a pypi package, if it is, leave it outside
if not r.get('channel') or r.get('channel') == 'pypi':
name = (r['name'].replace('-', '_'), r['name'])
pip_req_line = [l for l in pip_lines
if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
if pip_req_line and \
('@' not in pip_req_line[0] or
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
req_lines.append(pip_req_line[0])
continue
req_lines.append(
'{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
continue
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
requirements['conda'] = conda_packages_txt
except:
# check if we have it in our required packages
name = r['name']
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
# skip over packages with _
if name.startswith('_'):
continue
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
# make sure we see the conda packages, put them into the pip as well
if conda_lines:
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
requirements['pip'] = req_lines
requirements['conda'] = conda_lines
except Exception:
pass
if freeze_full_environment:
# noinspection PyBroadException
try:
conda_env_json = json.loads(
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
conda_env_json.pop('name', None)
conda_env_json.pop('prefix', None)
conda_env_json.pop('channels', None)
requirements['conda_env_json'] = json.dumps(conda_env_json)
except Exception:
pass
return requirements
def _load_conda_full_env(self, conda_env_dict, requirements):
# noinspection PyBroadException
try:
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
except Exception:
cuda_version = 0
conda_env_dict['channels'] = self.extra_channels
if 'dependencies' not in conda_env_dict:
conda_env_dict['dependencies'] = []
new_dependencies = OrderedDict()
pip_requirements = None
for line in conda_env_dict['dependencies']:
if isinstance(line, dict):
pip_requirements = line.pop('pip', None)
continue
name = line.strip().split('=', 1)[0].lower()
if name == 'pip':
continue
elif name == 'python':
line = 'python={}'.format('.'.join(line.split('=')[1].split('.')[:2]))
elif name == 'tensorflow-gpu' and cuda_version == 0:
line = 'tensorflow={}'.format(line.split('=')[1])
elif name == 'tensorflow' and cuda_version > 0:
line = 'tensorflow-gpu={}'.format(line.split('=')[1])
elif name in ('cupti', 'cudnn'):
# cudatoolkit should pull them based on the cudatoolkit version
continue
elif name.startswith('_'):
continue
new_dependencies[line.split('=', 1)[0].strip()] = line
# fix packages:
conda_env_dict['dependencies'] = list(new_dependencies.values())
with self.temp_file("conda_env", yaml.dump(conda_env_dict), suffix=".yml") as name:
print('Conda: Trying to install requirements:\n{}'.format(conda_env_dict['dependencies']))
result = self._run_command(
("env", "update", "-p", self.path, "--file", name)
)
# check if we need to remove specific packages
bad_req = self._parse_conda_result_bad_packges(result)
if bad_req:
print('failed installing the following conda packages: {}'.format(bad_req))
return False
if pip_requirements:
# create a list of vcs packages that we need to replace in the pip section
vcs_reqs = {}
if 'pip' in requirements:
pip_lines = requirements['pip'].splitlines() \
if isinstance(requirements['pip'], six.string_types) else requirements['pip']
for line in pip_lines:
try:
marker = list(parse(line))
except Exception:
marker = None
if not marker:
continue
m = MarkerRequirement(marker[0])
if m.vcs:
vcs_reqs[m.name] = m
try:
pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
if not r.startswith('pip=') and not r.startswith('virtualenv=')]
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
PackageManager._selected_manager = self.pip
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
except Exception as e:
print(e)
raise e
finally:
PackageManager._selected_manager = self
self.requirements_manager.post_install(self.session)
def load_requirements(self, requirements):
# if we are in read only mode, do not uninstall anything
if self.env_read_only:
print('Conda environment in read-only mode, skipping requirements installation.')
return None
# if we have a full conda environment, use it and pass the pip to pip
if requirements.get('conda_env_json'):
# noinspection PyBroadException
try:
conda_env_json = json.loads(requirements.get('conda_env_json'))
print('Conda restoring full yaml environment')
return self._load_conda_full_env(conda_env_json, requirements)
except Exception:
print('Could not load fully stored conda environment, falling back to requirements')
# create new environment file
conda_env = dict()
conda_env['channels'] = self.extra_channels
@@ -276,6 +481,15 @@ class CondaAPI(PackageManager):
if m.vcs:
pip_requirements.append(m)
continue
# Skip over pip
if m.name in ('pip', 'virtualenv', ):
continue
# python version, only major.minor
if m.name == 'python' and m.specs:
m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
if '.' not in m.specs[0][1]:
continue
conda_supported_req_names.append(m.name.lower())
if m.req.name.lower() == 'matplotlib':
has_matplotlib = True
@@ -303,15 +517,20 @@ class CondaAPI(PackageManager):
continue
m = MarkerRequirement(marker[0])
# skip over local files (we cannot change the version to a local file)
if m.local_file:
continue
m_name = m.name.lower()
if m_name in conda_supported_req_names:
# this package is in the conda list,
# make sure that if we changed version and we match it in conda
conda_supported_req_names.remove(m_name)
## conda_supported_req_names.remove(m_name)
for cr in reqs:
if m_name == cr.name.lower():
if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
# match versions
cr.specs = m.specs
# # conda always likes "-" not "_" but only on pypi packages
# cr.name = cr.name.lower().replace('_', '-')
break
else:
# not in conda, it is a pip package
@@ -319,29 +538,39 @@ class CondaAPI(PackageManager):
if m_name == 'matplotlib':
has_matplotlib = True
# remove any leftover conda packages (they were removed from the pip list)
if conda_supported_req_names:
reqs = [r for r in reqs if r.name.lower() not in conda_supported_req_names]
# Conda requirements Hacks:
if has_matplotlib:
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
# remove specific cudatoolkit, it should have being preinstalled.
# allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them
reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]
if has_torch and cuda_version == 0:
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
# make sure we have no double entries
reqs = list(OrderedDict((r.name, r) for r in reqs).values())
# conform conda packages (version/name)
for r in reqs:
# change _ to - in name but not the prefix _ (as this is conda prefix)
if not r.name.startswith('_') and not requirements.get('conda', None):
r.name = r.name.replace('_', '-')
# remove .post from version numbers, it fails ~= version, and change == to ~=
if r.specs and r.specs[0]:
r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
# conda always likes "-" not "_"
r.req.name = r.req.name.replace('_', '-')
while reqs:
# notice, we give conda more freedom in version selection, to help it choose best combination
conda_env['dependencies'] = [r.tostr() for r in reqs]
def clean_ver(ar):
if not ar.specs:
return ar.tostr()
ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
return ar.tostr()
conda_env['dependencies'] = [clean_ver(r) for r in reqs]
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
result = self._run_command(
@@ -371,12 +600,15 @@ class CondaAPI(PackageManager):
if pip_requirements:
try:
pip_req_str = [r.tostr() for r in pip_requirements]
pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
self.pip.load_requirements('\n'.join(pip_req_str))
PackageManager._selected_manager = self.pip
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
except Exception as e:
print(e)
raise e
finally:
PackageManager._selected_manager = self
self.requirements_manager.post_install(self.session)
return True
@@ -441,8 +673,22 @@ class CondaAPI(PackageManager):
def get_python_command(self, extra=()):
return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
def _get_conda_sh(self):
# type () -> Path
base_conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
if base_conda_env.is_file():
return base_conda_env
for path in os.environ.get('PATH', '').split(select_for_platform(windows=';', linux=':')):
conda = find_executable("conda", path=path)
if not conda:
continue
conda_env = Path(conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
if conda_env.is_file():
return conda_env
return base_conda_env
# enable hashing with cmp=False because pdb fails on unhashable exceptions
# enable hashing with cmp=False because pdb fails on un-hashable exceptions
exception = attrs(str=True, cmp=False)

View File

@@ -1,25 +0,0 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class CythonRequirement(SimpleSubstitution):
name = ("cython", "numpy", )
def __init__(self, *args, **kwargs):
super(CythonRequirement, self).__init__(*args, **kwargs)
def match(self, req):
# match both Cython & cython
return req.name and req.name.lower() in self.name
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
# install Cython before
PackageManager.out_of_scope_install_package(str(req))
return Text(req)

View File

@@ -1,3 +1,4 @@
import re
from collections import OrderedDict
from typing import Text
@@ -21,6 +22,8 @@ class ExternalRequirements(SimpleSubstitution):
return False
if not req.req or not req.req.line or not req.req.line.strip() or req.req.line.strip().startswith('#'):
return False
if req.pip_new_version and not (req.req.editable or req.req.vcs):
return False
return True
def post_install(self, session):
@@ -33,6 +36,9 @@ class ExternalRequirements(SimpleSubstitution):
freeze_base = ''
req_line = req.tostr(markers=False)
if req_line.strip().startswith('-e ') or req_line.strip().startswith('--editable'):
req_line = re.sub(r'^(-e|--editable=?)\s*', '', req_line, count=1)
if req.req.vcs and req_line.startswith('git+'):
try:
url_no_frag = furl(req_line)
@@ -47,22 +53,30 @@ class ExternalRequirements(SimpleSubstitution):
vcs._set_ssh_url()
new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
if new_req_line != req_line:
url_pass = furl(new_req_line).password
furl_line = furl(new_req_line)
print('Replacing original pip vcs \'{}\' with \'{}\''.format(
req_line, new_req_line.replace(url_pass, '****', 1) if url_pass else new_req_line))
req_line,
furl_line.set(password='xxxxxx').tostr() if furl_line.password else new_req_line))
req_line = new_req_line
except Exception:
print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))
PackageManager.out_of_scope_install_package(req_line, "--no-deps")
try:
freeze_post = PackageManager.out_of_scope_freeze() or ''
package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
if package_name and package_name[0] not in self.post_install_req_lookup:
self.post_install_req_lookup[package_name[0]] = req.req.line
except:
pass
if not PackageManager.out_of_scope_install_package(req_line, "--ignore-installed"):
# if we have older pip version we have to make sure we replace back the package name with the
# git repository link. In new versions this is supported and we get "package @ git+https://..."
if not req.pip_new_version:
PackageManager.out_of_scope_install_package(req_line, "--no-deps")
# noinspection PyBroadException
try:
freeze_post = PackageManager.out_of_scope_freeze() or ''
package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
if package_name and package_name[0] not in self.post_install_req_lookup:
self.post_install_req_lookup[package_name[0]] = req.req.line
except Exception:
pass
# no need to force reinstall, pip will always rebuilt if the package comes from git
# and make sure the required packages are installed (if they are not it will install them)
if not PackageManager.out_of_scope_install_package(req_line):
raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))
def replace(self, req):
@@ -76,10 +90,17 @@ class ExternalRequirements(SimpleSubstitution):
return Text('')
def replace_back(self, list_of_requirements):
if 'pip' in list_of_requirements:
original_requirements = list_of_requirements['pip']
list_of_requirements['pip'] = [r for r in original_requirements
if r not in self.post_install_req_lookup]
list_of_requirements['pip'] += [self.post_install_req_lookup.get(r, '')
for r in self.post_install_req_lookup.keys() if r in original_requirements]
if not list_of_requirements:
return list_of_requirements
for k in list_of_requirements:
# k is either pip/conda
if k not in ('pip', 'conda'):
continue
original_requirements = list_of_requirements[k]
list_of_requirements[k] = [r for r in original_requirements
if r not in self.post_install_req_lookup]
list_of_requirements[k] += [self.post_install_req_lookup.get(r, '')
for r in self.post_install_req_lookup.keys() if r in original_requirements]
return list_of_requirements

View File

@@ -1,32 +0,0 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class HorovodRequirement(SimpleSubstitution):
name = "horovod"
def __init__(self, *args, **kwargs):
super(HorovodRequirement, self).__init__(*args, **kwargs)
self.post_install_req = None
def match(self, req):
# match both horovod
return req.name and self.name == req.name.lower()
def post_install(self, session):
if self.post_install_req:
PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False))
self.post_install_req = None
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
# Store in post req install, and return nothing
self.post_install_req = req
# mark skip package, we will install it in post install hook
return Text('')

View File

@@ -1,6 +1,8 @@
from typing import Any
from pathlib2 import Path
from trains_agent.helper.base import select_for_platform, rm_tree
from trains_agent.helper.base import select_for_platform, rm_tree, ExecutionInfo
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session
@@ -9,8 +11,8 @@ from ..requirements import RequirementsManager
class VirtualenvPip(SystemPip, PackageManager):
def __init__(self, session, python, requirements_manager, path, interpreter=None):
# type: (Session, float, RequirementsManager, PathLike, PathLike) -> ()
def __init__(self, session, python, requirements_manager, path, interpreter=None, execution_info=None, **kwargs):
# type: (Session, float, RequirementsManager, PathLike, PathLike, ExecutionInfo, Any) -> ()
"""
Program interface to virtualenv pip.
Must be given either path to virtualenv or source command.

View File

@@ -0,0 +1,48 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class PostRequirement(SimpleSubstitution):
name = ("horovod", )
optional_package_names = tuple()
def __init__(self, *args, **kwargs):
super(PostRequirement, self).__init__(*args, **kwargs)
self.post_install_req = []
# check if we need to replace the packages:
post_packages = self.config.get('agent.package_manager.post_packages', None)
if post_packages:
self.__class__.name = post_packages
post_optional_packages = self.config.get('agent.package_manager.post_optional_packages', None)
if post_optional_packages:
self.__class__.optional_package_names = post_optional_packages
def match(self, req):
# match both horovod
return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
def post_install(self, session):
for req in self.post_install_req:
if req.name in self.optional_package_names:
# noinspection PyBroadException
try:
PackageManager.out_of_scope_install_package(req.tostr(markers=False))
except Exception:
pass
else:
PackageManager.out_of_scope_install_package(req.tostr(markers=False))
self.post_install_req = []
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
# Store in post req install, and return nothing
self.post_install_req.append(req)
# mark skip package, we will install it in post install hook
return Text('')

View File

@@ -0,0 +1,75 @@
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class PriorityPackageRequirement(SimpleSubstitution):
name = ("cython", "numpy", "setuptools", )
optional_package_names = tuple()
def __init__(self, *args, **kwargs):
super(PriorityPackageRequirement, self).__init__(*args, **kwargs)
# check if we need to replace the packages:
priority_packages = self.config.get('agent.package_manager.priority_packages', None)
if priority_packages:
self.__class__.name = priority_packages
priority_optional_packages = self.config.get('agent.package_manager.priority_optional_packages', None)
if priority_optional_packages:
self.__class__.optional_package_names = priority_optional_packages
def match(self, req):
# match both Cython & cython
return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
if req.name in self.optional_package_names:
# noinspection PyBroadException
try:
if PackageManager.out_of_scope_install_package(str(req)):
return Text(req)
except Exception:
pass
return Text('')
PackageManager.out_of_scope_install_package(str(req))
return Text(req)
class PackageCollectorRequirement(SimpleSubstitution):
"""
This RequirementSubstitution class will allow you to have multiple instances of the same
package, it will output the last one (by order) to be actually used.
"""
name = tuple()
def __init__(self, session, collect_package):
super(PackageCollectorRequirement, self).__init__(session)
self._collect_packages = collect_package or tuple()
self._last_req = None
def match(self, req):
# match package names
return req.name and req.name.lower() in self._collect_packages
def replace(self, req):
"""
Replace a requirement
:raises: ValueError if version is pre-release
"""
self._last_req = req.clone()
return ''
def post_scan_add_req(self):
"""
Allows the RequirementSubstitution to add an extra line/requirements after
the initial requirements scan is completed.
Called only once per requirements.txt object
"""
last_req = self._last_req
self._last_req = None
return last_req

View File

@@ -82,6 +82,8 @@ class SimplePytorchRequirement(SimpleSubstitution):
92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
}
def __init__(self, *args, **kwargs):
@@ -117,20 +119,24 @@ class SimplePytorchRequirement(SimpleSubstitution):
@classmethod
def get_torch_page(cls, cuda_version, nightly=False):
# noinspection PyBroadException
try:
cuda = int(cuda_version)
except:
except Exception:
cuda = 0
if nightly:
# then try the nightly builds, it might be there...
torch_url = cls.nightly_page_lookup_template.format(cuda)
try:
if requests.get(torch_url, timeout=10).ok:
cls.torch_page_lookup[cuda] = torch_url
return cls.torch_page_lookup[cuda], cuda
except Exception:
pass
for c in range(cuda, max(-1, cuda-15), -1):
# then try the nightly builds, it might be there...
torch_url = cls.nightly_page_lookup_template.format(c)
# noinspection PyBroadException
try:
if requests.get(torch_url, timeout=10).ok:
print('Torch nightly CUDA {} download page found'.format(c))
cls.torch_page_lookup[c] = torch_url
return cls.torch_page_lookup[c], c
except Exception:
pass
return
# first check if key is valid
@@ -138,13 +144,16 @@ class SimplePytorchRequirement(SimpleSubstitution):
return cls.torch_page_lookup[cuda], cuda
# then try a new cuda version page
torch_url = cls.page_lookup_template.format(cuda)
try:
if requests.get(torch_url, timeout=10).ok:
cls.torch_page_lookup[cuda] = torch_url
return cls.torch_page_lookup[cuda], cuda
except Exception:
pass
for c in range(cuda, max(-1, cuda-15), -1):
torch_url = cls.page_lookup_template.format(c)
# noinspection PyBroadException
try:
if requests.get(torch_url, timeout=10).ok:
print('Torch CUDA {} download page found'.format(c))
cls.torch_page_lookup[c] = torch_url
return cls.torch_page_lookup[c], c
except Exception:
pass
keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
for k in keys:
@@ -157,7 +166,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
class PytorchRequirement(SimpleSubstitution):
name = "torch"
packages = ("torch", "torchvision", "torchaudio")
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
def __init__(self, *args, **kwargs):
os_name = kwargs.pop("os_override", None)
@@ -235,6 +244,7 @@ class PytorchRequirement(SimpleSubstitution):
py_ver = self.python_major_minor_str.replace('.', '')
url = None
last_v = None
closest_v = None
# search for our package
for l in links_parser.links:
parts = l.split('/')[-1].split('-')
@@ -244,73 +254,94 @@ class PytorchRequirement(SimpleSubstitution):
continue
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
# version ignore .postX suffix (treat as regular version)
# noinspection PyBroadException
try:
v = str(parts[1].split('%')[0].split('+')[0])
except Exception:
continue
if len(parts) < 3 or not parts[2].endswith(py_ver):
continue
if len(parts) < 5 or platform_wheel not in parts[4]:
continue
# update the closest matched version (from above)
if not closest_v:
closest_v = v
elif SimpleVersion.compare_versions(
version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
SimpleVersion.compare_versions(
version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
closest_v = v
# check if this an actual match
if not req.compare_version(v) or \
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
continue
if not parts[2].endswith(py_ver):
continue
if platform_wheel not in parts[4]:
continue
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
last_v = v
# if we found an exact match, use it
# noinspection PyBroadException
try:
if req.specs[0][0] == '==' and \
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
break
except:
except Exception:
pass
return url
return url, last_v or closest_v
def get_url_for_platform(self, req):
# check if package is already installed with system packages
# noinspection PyBroadException
try:
if self.config.get("agent.package_manager.system_site_packages", None):
from pip._internal.commands.show import search_packages_info
installed_torch = list(search_packages_info([req.name]))
# notice the comparision order, the first part will make sure we have a valid installed package
if installed_torch[0]['version'] and req.compare_version(installed_torch[0]['version']):
# notice the comparison order, the first part will make sure we have a valid installed package
if installed_torch and installed_torch[0]['version'] and \
req.compare_version(installed_torch[0]['version']):
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
# package already installed, do nothing
return str(req), True
except:
req.specs = [('==', str(installed_torch[0]['version']))]
return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
except Exception:
pass
# make sure we have a specific version to retrieve
if not req.specs:
req.specs = [('>', '0')]
# noinspection PyBroadException
try:
req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
except:
except Exception:
pass
op, version = req.specs[0]
# assert op == "=="
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url)
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
if not url and self.config.get("agent.package_manager.torch_nightly", None):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
url = self._get_link_from_torch_page(req, torch_url)
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version (never fallback to CPU):
while not url and torch_url_key > 0:
previous_cuda_key = torch_url_key
print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
req, previous_cuda_key, closest_matched_version))
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
if url:
break
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
# never fallback to CPU
if torch_url_key < 1:
print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format(
req, previous_cuda_key))
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
req, self.cuda_version))
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format(
req, previous_cuda_key, torch_url_key))
url = self._get_link_from_torch_page(req, torch_url)
print(
'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
req, previous_cuda_key))
raise ValueError(
'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
else:
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
if not url:
url = PytorchWheel(
@@ -322,6 +353,8 @@ class PytorchRequirement(SimpleSubstitution):
if url:
# normalize url (sometimes we will get ../ which we should not...
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
# print found
print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))
self.log.debug("checking url: %s", url)
return url, requests.head(url, timeout=10).ok
@@ -457,7 +490,13 @@ class PytorchRequirement(SimpleSubstitution):
if req.req.name == parts[0]:
# support for pip >= 20.1
if '@' in line:
lines[i] = '{} # {}'.format(str(req), str(new_req))
# skip if we have nothing to add
if str(req).strip() != str(new_req).strip():
# if this is local file and use the version detection
if req.local_file:
lines[i] = '{}'.format(str(new_req))
else:
lines[i] = '{} # {}'.format(str(req), str(new_req))
else:
lines[i] = '{} # {}'.format(line, str(new_req))
break

View File

@@ -4,7 +4,7 @@ import operator
import os
import re
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from copy import deepcopy, copy
from itertools import chain, starmap
from operator import itemgetter
from os import path
@@ -12,15 +12,15 @@ from typing import Text, List, Type, Optional, Tuple, Dict
from pathlib2 import Path
from pyhocon import ConfigTree
from requirements import parse
# noinspection PyPackageRequirements
from requirements.requirement import Requirement
import six
from trains_agent.definitions import PIP_EXTRA_INDICES
from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session, normalize_cuda_version
from trains_agent.external.requirements_parser import parse
from trains_agent.external.requirements_parser.requirement import Requirement
from .translator import RequirementsTranslator
@@ -35,6 +35,10 @@ class FatalSpecsResolutionError(Exception):
@six.python_2_unicode_compatible
class MarkerRequirement(object):
# if True pip version above 20.x and with support for "package @ scheme://link"
# default is True
pip_new_version = True
def __init__(self, req): # type: (Requirement) -> None
self.req = req
@@ -57,7 +61,7 @@ class MarkerRequirement(object):
elif self.vcs:
# leave the line as is, let pip handle it
if self.line:
parts = [self.line]
return self.line
else:
# let's build the line manually
parts = [
@@ -65,6 +69,10 @@ class MarkerRequirement(object):
'@{}'.format(self.revision) if self.revision else '',
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
]
elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
# package @ file:///example.com/somewheel.whl
# leave the line as is, let pip handle it
return self.line
else:
parts = [self.uri]
@@ -73,6 +81,9 @@ class MarkerRequirement(object):
return ''.join(parts)
def clone(self):
return MarkerRequirement(copy(self.req))
__str__ = tostr
def __repr__(self):
@@ -138,7 +149,8 @@ class MarkerRequirement(object):
version = self.specs[0][1]
op = (op or self.specs[0][0]).strip()
return SimpleVersion.compare_versions(requested_version, op, version)
return SimpleVersion.compare_versions(
version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
class SimpleVersion:
@@ -177,7 +189,7 @@ class SimpleVersion:
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
@classmethod
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
"""
Compare two versions based on the op operator
returns bool(version_a op version_b)
@@ -188,12 +200,12 @@ class SimpleVersion:
:param str version_b:
:param bool ignore_sub_versions: if true compare only major.minor.patch
(ignore a/b/rc/post/dev in the comparison)
:param int num_parts: number of parts to compare, split by . (dot)
:return bool: version_a op version_b
"""
if not version_b:
return True
num_parts = 3
if op == '~=':
num_parts = max(num_parts, 2)
@@ -326,6 +338,14 @@ class RequirementSubstitution(object):
"""
pass
def post_scan_add_req(self): # type: () -> Optional[MarkerRequirement]
"""
Allows the RequirementSubstitution to add an extra line/requirements after
the initial requirements scan is completed.
Called only once per requirements.txt object
"""
return None
def post_install(self, session):
pass
@@ -480,6 +500,14 @@ class RequirementsManager(object):
)
if not conda:
result = map(self.translator.translate, result)
result = list(result)
# add post scan add requirements call back
for h in self.handlers:
req = h.post_scan_add_req()
if req:
result.append(req.tostr())
return join_lines(result)
def post_install(self, session):
@@ -491,6 +519,9 @@ class RequirementsManager(object):
raise
def replace_back(self, requirements):
if self.translator:
requirements = self.translator.replace_back(requirements)
for h in self.handlers:
try:
requirements = h.replace_back(requirements)

View File

@@ -23,6 +23,7 @@ class RequirementsTranslator(object):
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
self.config = Config()
self.pip = SystemPip(interpreter=interpreter, session=self._session)
self._translate_back = {}
def download(self, url):
self.pip.download_package(url, cache_dir=self.cache_dir)
@@ -60,4 +61,30 @@ class RequirementsTranslator(object):
except Exception:
command.error('Could not download wheel name of "{}"'.format(line))
return line
self._translate_back[str(downloaded)] = line
return downloaded
def replace_back(self, requirements):
if not requirements:
return requirements
for k in requirements:
# k is either pip/conda
if k not in ('pip', 'conda'):
continue
original_requirements = requirements[k]
new_requirements = []
for line in original_requirements:
local_file = [d for d in self._translate_back.keys() if d in line]
if local_file:
local_file = local_file[0]
new_requirements.append(line.replace(local_file, self._translate_back[local_file]))
else:
new_requirements.append(line)
requirements[k] = new_requirements
return requirements

View File

@@ -11,6 +11,7 @@ from copy import deepcopy
from distutils.spawn import find_executable
from itertools import chain, repeat, islice
from os.path import devnull
from time import sleep
from typing import Union, Text, Sequence, Any, TypeVar, Callable
import psutil
@@ -41,6 +42,30 @@ def get_bash_output(cmd, strip=False, stderr=subprocess.STDOUT, stdin=False):
return output if not strip or not output else output.strip()
def terminate_process(pid, timeout=10.):
# noinspection PyBroadException
try:
proc = psutil.Process(pid)
proc.terminate()
cnt = 0
while proc.is_running() and cnt < timeout:
sleep(1.)
cnt += 1
proc.terminate()
cnt = 0
while proc.is_running() and cnt < timeout:
sleep(1.)
cnt += 1
proc.kill()
except Exception:
pass
# noinspection PyBroadException
try:
return not psutil.Process(pid).is_running()
except Exception:
return True
def kill_all_child_processes(pid=None):
# get current process if pid not provided
include_parent = True

View File

@@ -5,7 +5,7 @@ import subprocess
from distutils.spawn import find_executable
from hashlib import md5
from os import environ, getenv
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple, Optional
import attr
from furl import furl
@@ -13,7 +13,7 @@ from pathlib2 import Path
import six
from trains_agent.definitions import ENV_AGENT_GIT_USER, ENV_AGENT_GIT_PASS
from trains_agent.definitions import ENV_AGENT_GIT_USER, ENV_AGENT_GIT_PASS, ENV_AGENT_GIT_HOST
from trains_agent.helper.console import ensure_text, ensure_binary
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import (
@@ -150,12 +150,23 @@ class VCS(object):
"""
Apply patch repository at `location`
"""
self.log.info("applying diff to %s", location)
self.log.info("applying diff to %s" % location)
for match in filter(
None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
):
create_file_if_not_exists(normalize_path(location, match.group("path")))
# noinspection PyBroadException
try:
for match in filter(
None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
):
file_path = None
# noinspection PyBroadException
try:
file_path = normalize_path(location, match.group("path"))
create_file_if_not_exists(file_path)
except Exception:
if file_path:
self.log.warning("failed creating file for git diff (%s)" % file_path)
except Exception:
pass
return_code, errors = self.call_with_stdin(
patch_content, *self.patch_base, cwd=location
@@ -243,8 +254,8 @@ class VCS(object):
return url
@classmethod
def replace_http_url(cls, url):
# type: (Text) -> Text
def replace_http_url(cls, url, port=None):
# type: (Text, Optional[int]) -> Text
"""
Replace HTTPS URL with SSH URL when applicable
"""
@@ -254,7 +265,9 @@ class VCS(object):
parsed_url.username = "git"
parsed_url.password = None
# make sure there is no port in the final url (safe_furl support)
parsed_url.port = None
# the original port was an https port, and we do not know if there is a different ssh port,
# so we have to clear the original port specified (https) and use the default ssh schema port.
parsed_url.port = port or None
url = parsed_url.url
return url
@@ -265,8 +278,14 @@ class VCS(object):
"""
if self.session.config.get('agent.force_git_ssh_protocol', None) and self.url:
parsed_url = furl(self.url)
# only apply to a specific domain (if requested)
config_domain = \
ENV_AGENT_GIT_HOST.get() or self.session.config.get("agent.git_host", None)
if config_domain and config_domain != parsed_url.host:
return
if parsed_url.scheme == "https":
new_url = self.replace_http_url(self.url)
new_url = self.replace_http_url(
self.url, port=self.session.config.get('agent.force_git_ssh_port', None))
if new_url != self.url:
print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format(
self.url, new_url))
@@ -276,11 +295,15 @@ class VCS(object):
if not self.session.config.agent.translate_ssh:
return
ssh_agent_variable = "SSH_AUTH_SOCK"
if not getenv(ssh_agent_variable) and (
(ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and
(ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None))
):
# if we have git_user / git_pass replace ssh credentials with https authentication
if (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and \
(ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None)):
# only apply to a specific domain (if requested)
config_domain = \
ENV_AGENT_GIT_HOST.get() or self.session.config.get("git_host", None)
if config_domain and config_domain != furl(self.url).host:
return
new_url = self.replace_ssh_url(self.url)
if new_url != self.url:
print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
@@ -358,9 +381,10 @@ class VCS(object):
"""
Run command with `input_` as stdin
"""
input_ = input_.encode("latin1")
if not input_.endswith(b"\n"):
input_ += b"\n"
input_ = input_.encode("utf-8")
# always add extra empty line
# (there is no downside, and it solves empty lines issue at end of patch cause corrupt message)
input_ += b"\n"
process = self._call_subprocess(
subprocess.Popen,
argv,
@@ -430,10 +454,12 @@ class VCS(object):
return parsed_url.url
config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)
config_pass = ENV_AGENT_GIT_PASS.get() or config.get("agent.{}_pass".format(cls.executable_name), None)
config_domain = ENV_AGENT_GIT_HOST.get() or config.get("agent.{}_host".format(cls.executable_name), None)
if (
(not (parsed_url.username and parsed_url.password))
and config_user
and config_pass
and (not config_domain or config_domain.lower() == parsed_url.host)
):
parsed_url.set(username=config_user, password=config_pass)
return parsed_url.url
@@ -508,7 +534,7 @@ class Git(VCS):
root=Argv(executable_name, "rev-parse", "--show-toplevel"),
)
patch_base = ("apply",)
patch_base = ("apply", "--unidiff-zero", )
class Hg(VCS):

View File

@@ -0,0 +1,170 @@
import re
from datetime import datetime, timedelta
from typing import List, Tuple, Optional
from trains_agent.backend_config.defs import UptimeConf
DAYS = ["SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"]
PATTERN = re.compile(r"^(?P<hours>[^\s]+)\s(?P<days>[^\s]+)")
def check_runtime(ranges_list, is_uptime=True):
# type: (List[str], bool) -> bool
for entry in ranges_list:
days_list = get_days_list(entry)
if not check_day(days_list):
continue
hours_list = get_hours_list(entry)
if check_hour(hours_list):
return is_uptime
return not is_uptime
def check_hour(hours):
# type: (List[str]) -> bool
return datetime.now().hour in hours
def check_day(days):
# type: (List[str]) -> bool
return datetime.now().strftime("%a").upper() in days
def get_days_list(entry):
# type: (str) -> List[str]
days_intervals = PATTERN.match(entry)["days"].split(",")
days_total = []
for days in days_intervals:
start, end = days.split("-") if "-" in days else (days, days)
try:
days_total.extend(
[*range(DAYS.index(start.upper()), DAYS.index(end.upper()) + 1)]
)
except ValueError:
print(
"Warning: days interval '{}' is invalid, use intervals of the format <start>-<end>."
" make sure to use the abbreviated format SUN-SAT".format(days)
)
continue
return [DAYS[day] for day in days_total]
def get_hours_list(entry):
# type: (str) -> List[str]
hours_intervals = PATTERN.match(entry)["hours"].split(",")
hours_total = []
for hours in hours_intervals:
start, end = get_start_end_hours(hours)
if not (start and end):
continue
hours_total.extend([*range(start, end)])
return hours_total
def get_start_end_hours(hours):
# type: (str) -> Tuple[int, int]
try:
start, end = (
tuple(map(int, hours.split("-"))) if "-" in hours else (int(hours), 24)
)
except Exception as ex:
print(
"Warning: hours interval '{}' is invalid, use intervals of the format <start>-<end>".format(
hours, ex
)
)
start, end = (None, None)
if end == 0:
end = 24
return start, end
def print_uptime_properties(
ranges_list, queues_info, runtime_properties, is_uptime=True
):
# type: (List[str], List[dict], List[dict], bool) -> None
if ranges_list:
uptime_string = ["Working hours {} configurations".format("uptime" if is_uptime else "downtime")]
uptime_string.extend(get_uptime_string(entry) for entry in ranges_list)
else:
uptime_string = ["No uptime/downtime configurations found"]
is_server_forced, server_string = get_runtime_properties_string(runtime_properties)
is_queue_forced, queues_string = get_queues_tags_string(queues_info)
res = list(
filter(
len,
[
"\n".join(uptime_string),
"\nCurrently forced {}".format(is_queue_forced or is_server_forced)
if is_queue_forced or is_server_forced
else " ",
server_string,
queues_string,
],
)
)
print("\n".join(res))
def get_uptime_string(entry):
# type: (str) -> str
res = []
days_list = get_days_list(entry)
hours_intervals = PATTERN.match(entry)["hours"].split(",")
for hours in hours_intervals:
start, end = get_start_end_hours(hours)
if not (start and end):
continue
res.append(
" - {}:00-{}:59 on {}".format(start, end - 1, ' and '.join(days_list))
if not (start == end)
else ""
)
return "\n".join(res)
def get_runtime_properties_string(runtime_properties):
# type: (List[dict]) -> Tuple[Optional[str], str]
server_string = []
force_flag = next(
(prop for prop in runtime_properties if prop["key"] == UptimeConf.worker_key),
None,
)
is_server_forced = None
if force_flag:
is_server_forced = force_flag["value"].upper()
expiry_hour = (
(datetime.now() + timedelta(seconds=force_flag["expiry"])).strftime("%H:%M")
if force_flag["expiry"]
else None
)
expires = " expires at {}".format(expiry_hour) if expiry_hour else ""
server_string.append(
" - Server runtime property '{}: {}'{}".format(force_flag['key'], force_flag['value'], expires)
)
return is_server_forced, "\n".join(server_string)
def get_queues_tags_string(queues_info):
# type: (List[dict]) -> Tuple[Optional[str], str]
queues_string = []
is_queue_forced = None
for queue in queues_info:
if "force_workers:off" in queue.get("tags", []):
tag = "force_workers:off"
is_queue_forced = is_queue_forced or "OFF"
elif "force_workers:on" in queue.get("tags", []):
tag = "force_workers:on"
is_queue_forced = "ON"
else:
tag = None
tagged = " (tagged '{}')'".format(tag) if tag else ""
queues_string.append(
" - Listening to queue '{}'{}".format(queue.get('name', ''), tagged)
)
return is_queue_forced, "\n".join(queues_string)

View File

@@ -4,6 +4,8 @@ from time import sleep
from glob import glob
from tempfile import gettempdir, NamedTemporaryFile
from typing import List, Tuple, Optional
from trains_agent.definitions import ENV_DOCKER_HOST_MOUNT
from trains_agent.helper.base import warning
@@ -37,6 +39,10 @@ class Singleton(object):
except:
pass
@classmethod
def get_lock_filename(cls):
return os.path.join(cls._get_temp_folder(), cls._lock_file_name)
@classmethod
def register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
"""
@@ -47,7 +53,7 @@ class Singleton(object):
:return: (str worker_id, int slot_number) Return None value on instance already running
"""
# try to lock file
lock_file = os.path.join(cls._get_temp_folder(), cls._lock_file_name)
lock_file = cls.get_lock_filename()
timeout = 0
while os.path.exists(lock_file):
if timeout > cls._lock_timeout:
@@ -79,30 +85,42 @@ class Singleton(object):
return ret
@classmethod
def _register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
if cls.worker_id:
return cls.worker_id, cls.instance_slot
# make sure we have a unique name
instance_num = 0
def get_running_pids(cls):
# type: () -> List[Tuple[int, Optional[str], Optional[int], str]]
temp_folder = cls._get_temp_folder()
files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
slots = {}
pids = []
for file in files:
parts = file.split(cls.sep)
parts = os.path.basename(file).split(cls.sep)
# noinspection PyBroadException
try:
pid = int(parts[1])
if not psutil.pid_exists(pid):
pid = -1
except Exception:
# something is wrong, use non existing pid and delete the file
pid = -1
uid, slot = None, None
# noinspection PyBroadException
try:
with open(file, 'r') as f:
uid, slot = str(f.read()).split('\n')
slot = int(slot)
except Exception:
pass
pids.append((pid, uid, slot, file))
return pids
@classmethod
def _register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
if cls.worker_id:
return cls.worker_id, cls.instance_slot
# make sure we have a unique name
instance_num = 0
slots = {}
for pid, uid, slot, file in cls.get_running_pids():
worker = None
if api_client and ENV_DOCKER_HOST_MOUNT.get() and uid:
try:
@@ -111,7 +129,7 @@ class Singleton(object):
worker = None
# count active instances and delete dead files
if not worker and not psutil.pid_exists(pid):
if not worker and pid < 0:
# delete the file
try:
os.remove(os.path.join(file))
@@ -165,3 +183,9 @@ class Singleton(object):
@classmethod
def get_slot(cls):
return cls.instance_slot or 0
@classmethod
def get_pid_file(cls):
if not cls._pid_file:
return None
return cls._pid_file.name

View File

@@ -68,6 +68,10 @@ DAEMON_ARGS = dict({
'dest': 'queues',
'type': foreign_object_id('queues'),
},
'--order-fairness': {
'help': 'Pull from each queue in a round-robin order, instead of priority order.',
'action': 'store_true',
},
'--standalone-mode': {
'help': 'Do not use any network connects, assume everything is pre-installed',
'action': 'store_true',
@@ -85,7 +89,28 @@ DAEMON_ARGS = dict({
'action': 'store_true',
'aliases': ['-d'],
},
'--stop': {
'help': 'Stop the running agent (based on the same set of arguments)',
'action': 'store_true',
},
'--uptime': {
'help': 'Specify uptime for trains-agent in "<hours> <days>" format. for example, use "17-20 TUE" to set '
'Tuesday\'s uptime to 17-20'
'Note: Make sure to have only one of uptime/downtime configuration and not both.',
'nargs': '*',
'default': None,
},
'--downtime': {
'help': 'Specify uptime for trains-agent in "<hours> <days>" format. for example, use "09-13 TUE" to set '
'Tuesday\'s downtime to 09-13'
'Note: Make sure to have only on of uptime/downtime configuration and not both.',
'nargs': '*',
'default': None,
},
'--status': {
'help': 'Print the worker\'s schedule (uptime properties, server\'s runtime properties and listening queues)',
'action': 'store_true',
},
}, **WORKER_ARGS)
COMMANDS = {

View File

@@ -73,9 +73,11 @@ class Session(_Session):
os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file
if not Path(config_file).is_file():
raise ValueError("Could not open configuration file: {}".format(config_file))
cpu_only = kwargs.get('cpu_only')
if cpu_only:
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none'
if kwargs.get('gpus') and not os.environ.get('KUBERNETES_SERVICE_HOST') \
and not os.environ.get('KUBERNETES_PORT'):
# CUDA_VISIBLE_DEVICES does not support 'all'
@@ -84,6 +86,7 @@ class Session(_Session):
os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')
else:
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')
if kwargs.get('only_load_config'):
from trains_agent.backend_api.config import load
self.config = load()

View File

@@ -1 +1 @@
__version__ = '0.15.2rc0'
__version__ = '0.16.3'