Mirror of https://github.com/clearml/clearml-agent (synced 2025-06-26 18:16:15 +00:00)
Compare commits
82 Commits
| SHA1 |
|---|
| dd42423482 |
| 69eb25db1f |
| a41ea52f87 |
| 259113c989 |
| 1afa3a3914 |
| 448e23825c |
| b0c0f41f62 |
| d2c5fb6512 |
| b89cf4ec23 |
| 74b646af9e |
| 0cf485f7a9 |
| ea63e4f66e |
| 58eb5fbd5f |
| a8c543ef7b |
| 64e198a57a |
| de332b9e6b |
| 60eeff292d |
| 52f30b306a |
| 6df0f81ca0 |
| 40b3c1502d |
| a61265effe |
| 92efea6b76 |
| 216b3e2179 |
| 293a92f486 |
| 6bad2b5352 |
| a09a638b9c |
| 24f57270ed |
| 1b7964ce98 |
| 5a510882b8 |
| 601ed03198 |
| 90fe4570b9 |
| 92fc8e838f |
| 89a3020c5e |
| fc3e47b67e |
| b2a80ca314 |
| 14655f19a0 |
| 47092c47db |
| 8e6fce8d63 |
| 3c514e3418 |
| 8a425b100b |
| eb942cfedd |
| 0a7fc06108 |
| 0ae35afa76 |
| a2156e73bf |
| 9fe77f3c28 |
| 6f078afafd |
| 15f4aa613e |
| 7cd9fa6c41 |
| 234d5fac2c |
| 6cbfb96ff8 |
| 6e54e55c31 |
| 3ff85b7b85 |
| 5640489f57 |
| 8135a6facf |
| b6ae4f211d |
| a56f032ec4 |
| 075736de20 |
| d8543c892e |
| ca0870b048 |
| c7a739fafa |
| 7170296162 |
| 3bed0ef33c |
| d419fa1e4f |
| 31a56c71bd |
| 28f47419b0 |
| 6a24da2849 |
| 782668fd21 |
| aaf8d802e7 |
| ca89a1e322 |
| 121dec2a62 |
| 4aacf9005e |
| 6b333202e9 |
| ce6831368f |
| e4111c830b |
| 52c1772b04 |
| 699d13bbb3 |
| 2c8d7d3d9a |
| b13cc1e8e7 |
| 17d2bf2a3e |
| 94997f9c88 |
| c6d998c4df |
| f8ea445339 |
18
README.md
@@ -227,6 +227,14 @@ The **Trains Agent** will first try to pull jobs from the `important_jobs` queue
 
 Adding queues, managing job order within a queue and moving jobs between queues, is available using the Web UI, see example on our [open server](https://demoapp.trains.allegro.ai/workers-and-queues/queues)
 
+#### Stopping the Trains Agent
+
+To stop a **Trains Agent** running in the background, run the same command line used to start the agent, with `--stop` appended.
+For example, to stop the first of the same-machine, single-GPU agents shown above:
+```bash
+trains-agent daemon --detached --gpus 0 --queue default --docker nvidia/cuda --stop
+```
+
 ## How do I create an experiment on the Trains Server? <a name="from-scratch"></a>
 * Integrate [Trains](https://github.com/allegroai/trains) with your code
 * Execute the code on your machine (Manually / PyCharm / Jupyter Notebook)
@@ -272,18 +280,18 @@ trains-agent daemon --services-mode --detached --queue services --create-queue -
 
 ## AutoML and Orchestration Pipelines <a name="automl-pipes"></a>
 The Trains Agent can also be used to implement AutoML orchestration and Experiment Pipelines in conjunction with the Trains package.
 
-Sample AutoML & Orchestration examples can be found in the Trains [example/automl](https://github.com/allegroai/trains/tree/master/examples/automl) folder.
+Sample AutoML & Orchestration examples can be found in the Trains [example/automation](https://github.com/allegroai/trains/tree/master/examples/automation) folder.
 
 AutoML examples
-- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/automl/automl_base_template_keras_simple.py)
+- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/optimization/hyper-parameter-optimization/base_template_keras_simple.py)
   - In order to create an experiment-template in the system, this code must be executed once manually
-- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automl/automl_random_search_example.py)
+- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automation/manual_random_param_search_example.py)
   - This example will create multiple copies of the Keras experiment-template, with different hyper-parameter combinations
 
 Experiment Pipeline examples
-- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/task_piping_example.py)
+- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/task_piping_example.py)
   - This example will "process data", and once done, will launch a copy of the 'second step' experiment-template
-- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/toy_base_task.py)
+- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/toy_base_task.py)
   - In order to create an experiment-template in the system, this code must be executed once manually
 
 ## License
@@ -5,8 +5,17 @@ WORKDIR /usr/agent
 
 COPY . /usr/agent
 
+ENV LC_ALL=en_US.UTF-8
+ENV LANG=en_US.UTF-8
+ENV LANGUAGE=en_US.UTF-8
+ENV PYTHONIOENCODING=UTF-8
+
 RUN apt-get update
 RUN apt-get dist-upgrade -y
+RUN apt-get install -y locales
+
+RUN locale-gen en_US.UTF-8
+
 RUN apt-get install -y curl python3-pip git
 RUN curl -sSL https://get.docker.com/ | sh
 RUN python3 -m pip install -U pip
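For context, building and running this image follows the usual Docker pattern; the tag and address below are placeholders, not values from this repo:

```bash
docker build -t trains-agent-services .
docker run --rm -e TRAINS_HOST_IP=10.0.0.5 \
    -v /var/run/docker.sock:/var/run/docker.sock \
    trains-agent-services
```

Mounting the Docker socket is what lets the agent inside the container spawn sibling containers using the Docker client that the image installs from get.docker.com.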
@@ -11,4 +11,4 @@ TRAINS_API_HOST=${TRAINS_API_HOST:-"http://$TRAINS_HOST_IP:8008"}
 echo $TRAINS_FILES_HOST $TRAINS_WEB_HOST $TRAINS_API_HOST 1>&2
 
 python3 -m pip install -q -U "trains-agent${TRAINS_AGENT_UPDATE_VERSION}"
-trains-agent daemon --services-mode --queue services --create-queue --docker $TRAINS_AGENT_DEFAULT_BASE_DOCKER --cpu-only $TRAINS_AGENT_EXTRA_ARGS
+trains-agent daemon --services-mode --queue services --create-queue --docker "$TRAINS_AGENT_DEFAULT_BASE_DOCKER" --cpu-only $TRAINS_AGENT_EXTRA_ARGS
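The quoting added above matters once the base docker value carries extra arguments; a minimal sketch of the difference (the variable value is invented for illustration):

```bash
TRAINS_AGENT_DEFAULT_BASE_DOCKER="ubuntu:18.04 --ipc=host"
printf '[%s]\n' $TRAINS_AGENT_DEFAULT_BASE_DOCKER    # two words: [ubuntu:18.04] [--ipc=host]
printf '[%s]\n' "$TRAINS_AGENT_DEFAULT_BASE_DOCKER"  # one word:  [ubuntu:18.04 --ipc=host]
```

With the quotes, `--docker` receives the full value as a single argument instead of having it word-split by the shell.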
@@ -17,9 +17,14 @@ agent {
 # leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
 git_user=""
 git_pass=""
+# Limit credentials to a single domain, for example: github.com,
+# all other domains will use public access (no user/pass). Default: always send user/pass for any VCS domain
+git_host=""
 
 # Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
 force_git_ssh_protocol: false
+# Force a specific SSH port when converting http to ssh links (the domain is kept the same)
+# force_git_ssh_port: ""
 
 # unique name of this worker, if None, created based on hostname:process_id
 # Overridden with os environment: TRAINS_WORKER_NAME
@@ -57,6 +62,24 @@ agent {
 
 # additional conda channels to use when installing with conda package manager
 conda_channels: ["pytorch", "conda-forge", ]
+# conda_full_env_update: false
+# conda_env_as_base_docker: false
+
+# set the priority packages to be installed before the rest of the required packages
+# priority_packages: ["cython", "numpy", "setuptools", ]
+
+# set the optional priority packages to be installed before the rest of the required packages,
+# In case a package installation fails, the package will be ignored,
+# and the virtual environment process will continue
+# priority_optional_packages: ["pygobject", ]
+
+# set the post packages to be installed after all the rest of the required packages
+# post_packages: ["horovod", ]
+
+# set the optional post packages to be installed after all the rest of the required packages,
+# In case a package installation fails, the package will be ignored,
+# and the virtual environment process will continue
+# post_optional_packages: []
 
 # set to True to support torch nightly build installation,
 # notice: torch nightly builds are ephemeral and are deleted from time to time
@@ -146,6 +169,9 @@ sdk {
 # X images are stored in the upload destination for each matplotlib plot title.
 matplotlib_untitled_history_size: 100
 
+# Limit the number of digits after the dot in plot reporting (reducing plot report size)
+# plot_max_num_digits: 5
+
 # Settings for generated debug images
 images {
     format: JPEG
75
examples/k8s_glue_example.py
Normal file
@@ -0,0 +1,75 @@
+"""
+This example assumes you have preconfigured services with selectors in the form of
+"ai.allegro.agent.serial=pod-<number>" and a targetPort of 10022.
+The K8sIntegration component will label each pod accordingly.
+"""
+from argparse import ArgumentParser
+
+from trains_agent.glue.k8s import K8sIntegration
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--queue", type=str, help="Queue to pull tasks from"
+    )
+    parser.add_argument(
+        "--ports-mode", action='store_true', default=False,
+        help="Ports-Mode will add a label to the pod which can be used as service, in order to expose ports"
+    )
+    parser.add_argument(
+        "--num-of-services", type=int, default=20,
+        help="Specify the number of k8s services to be used. Use only with ports-mode."
+    )
+    parser.add_argument(
+        "--base-port", type=int,
+        help="Used in conjunction with ports-mode, specifies the base port exposed by the services. "
+             "For pod #X, the port will be <base-port>+X"
+    )
+    parser.add_argument(
+        "--gateway-address", type=str, default=None,
+        help="Used in conjunction with ports-mode, specify the external address of the k8s ingress / ELB"
+    )
+    parser.add_argument(
+        "--pod-trains-conf", type=str,
+        help="Configuration file to be used by the pod itself (if not provided, current configuration is used)"
+    )
+    parser.add_argument(
+        "--overrides-yaml", type=str,
+        help="YAML file containing pod overrides to be used when launching a new pod"
+    )
+    parser.add_argument(
+        "--template-yaml", type=str,
+        help="YAML file containing pod template. If provided pod will be scheduled with kubectl apply "
+             "and overrides are ignored, otherwise it will be scheduled with kubectl run"
+    )
+    parser.add_argument(
+        "--ssh-server-port", type=int, default=0,
+        help="If non-zero, every pod will also start an SSH server on the selected port (default: zero, not active)"
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    user_props_cb = None
+    if args.ports_mode and args.base_port:
+        def k8s_user_props_cb(pod_number):
+            user_prop = {"k8s-pod-port": args.base_port + pod_number}
+            if args.gateway_address:
+                user_prop["k8s-gateway-address"] = args.gateway_address
+            return user_prop
+        user_props_cb = k8s_user_props_cb
+
+    k8s = K8sIntegration(
+        ports_mode=args.ports_mode, num_of_services=args.num_of_services, user_props_cb=user_props_cb,
+        overrides_yaml=args.overrides_yaml, trains_conf_file=args.pod_trains_conf, template_yaml=args.template_yaml,
+        extra_bash_init_script=K8sIntegration.get_ssh_server_bash(
+            ssh_port_number=args.ssh_server_port) if args.ssh_server_port else None
+    )
+    k8s.k8s_daemon(args.queue)
+
+
+if __name__ == "__main__":
+    main()
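Assuming the preconfigured services and selectors described in the module docstring exist, a plausible invocation would be the following (the queue name is a placeholder):

```bash
python examples/k8s_glue_example.py --queue k8s_scheduler --ports-mode --num-of-services 20 --base-port 10022
```

With `--ports-mode` and `--base-port`, the callback defined in the script attaches a `k8s-pod-port` user property of 10022+X to the task running on pod #X.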
@@ -13,7 +13,6 @@ pyjwt>=1.6.4
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.20.0
requirements_parser>=0.2.0
six>=1.11.0
tqdm>=4.19.5
typing>=3.6.4
@@ -13,9 +13,12 @@
 # leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
 # git_user: ""
 # git_pass: ""
+# git_host: ""
 
 # Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
 force_git_ssh_protocol: false
+# Force a specific SSH port when converting http to ssh links (the domain is kept the same)
+# force_git_ssh_port: 0
 
 # Set the python version to use when creating the virtual environment and launching the experiment
 # Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6"
@@ -44,6 +47,26 @@
 # additional conda channels to use when installing with conda package manager
 conda_channels: ["defaults", "conda-forge", "pytorch", ]
 
+# If set to true, Task's "installed packages" are ignored,
+# and the repository's "requirements.txt" is used instead
+# force_repo_requirements_txt: false
+
+# set the priority packages to be installed before the rest of the required packages
+# priority_packages: ["cython", "numpy", "setuptools", ]
+
+# set the optional priority packages to be installed before the rest of the required packages,
+# In case a package installation fails, the package will be ignored,
+# and the virtual environment process will continue
+# priority_optional_packages: ["pygobject", ]
+
+# set the post packages to be installed after all the rest of the required packages
+# post_packages: ["horovod", ]
+
+# set the optional post packages to be installed after all the rest of the required packages,
+# In case a package installation fails, the package will be ignored,
+# and the virtual environment process will continue
+# post_optional_packages: []
+
 # set to True to support torch nightly build installation,
 # notice: torch nightly builds are ephemeral and are deleted from time to time
 torch_nightly: false,
@@ -86,6 +109,22 @@
 # optional shell script to run in docker when started before the experiment is started
 # extra_docker_shell_script: ["apt-get install -y bindfs", ]
 
+# optional uptime configuration, make sure to use only one of 'uptime/downtime' and not both.
+# If uptime is specified, agent will actively poll (and execute) tasks in the time-spans defined here.
+# Outside of the specified time-spans, the agent will be idle.
+# Defined using a list of items of the format: "<hours> <days>".
+# hours - use values 0-23, single values would count as start hour and end at midnight.
+# days - use days in abbreviated format (SUN-SAT)
+# use '-' for ranges and ',' to separate singular values.
+# for example, to enable the workers every Sunday and Tuesday between 17:00-20:00 set uptime to:
+# uptime: ["17-20 SUN,TUE"]
+
+# optional downtime configuration, can be used only when uptime is not used.
+# If downtime is specified, agent will be idle in the time-spans defined here.
+# Outside of the specified time-spans, the agent will actively poll (and execute) tasks.
+# Use the same format as described above for uptime
+# downtime: []
+
 # set to true in order to force "docker pull" before running an experiment using a docker image.
 # This makes sure the docker image is updated.
 docker_force_pull: false
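Combining the grammar described above, a hypothetical schedule (values invented for illustration) would look like:

```
# uptime: ["8-17 MON-FRI", "10-23 SAT,SUN"]
```

This keeps the worker pulling tasks on weekday working hours plus weekend late mornings and evenings, and idle otherwise.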
@@ -109,6 +148,16 @@
 # "(which {python_single_digit} && {python_single_digit} -m pip --version) || apt-get install -y {python_single_digit}-pip",
 # ]
 
+# set the preprocessing bash script to execute at the startup of any docker.
+# all lines will be executed regardless of their exit code.
+# docker_preprocess_bash_script = [
+#     "echo \"starting docker\"",
+# ]
+
+# If False replace \r with \n and display full console output
+# default is True, report a single \r line in a sequence of consecutive lines, per 5 seconds.
+# suppress_carriage_return: true
 
 # cuda versions used for solving pytorch wheel packages
 # should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
 # cuda_version: 10.1
@@ -31,12 +31,18 @@
 # X images are stored in the upload destination for each matplotlib plot title.
 matplotlib_untitled_history_size: 100
 
+# Limit the number of digits after the dot in plot reporting (reducing plot report size)
+# plot_max_num_digits: 5
+
 # Settings for generated debug images
 images {
     format: JPEG
     quality: 87
     subsampling: 0
 }
 
+# Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to true, each series should have its own graph)
+tensorboard_single_series_per_graph: false
 }
 
 network {
@@ -117,11 +123,11 @@
 
 log {
     # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
-    null_log_propagate: False
+    null_log_propagate: false
     task_log_buffer_capacity: 66
 
     # disable urllib info and lower levels
-    disable_urllib3_info: True
+    disable_urllib3_info: true
 }
 
 development {
@@ -131,14 +137,30 @@
     task_reuse_time_window_in_hours: 72.0
 
     # Run VCS repository detection asynchronously
-    vcs_repo_detect_async: True
+    vcs_repo_detect_async: true
 
     # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
     # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
-    store_uncommitted_code_diff_on_train: True
+    store_uncommitted_code_diff: true
 
     # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
-    support_stopping: True
+    support_stopping: true
 
+    # Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
+    default_output_uri: ""
+
+    # Default auto generated requirements optimize for smaller requirements
+    # If True, analyze the entire repository regardless of the entry point.
+    # If False, first analyze the entry point script; if it does not contain references to other local files,
+    # do not analyze the entire repository.
+    force_analyze_entire_repo: false
+
+    # If set to true, *trains* update message will not be printed to the console
+    # this value can be overwritten with os environment variable TRAINS_SUPPRESS_UPDATE_MESSAGE=1
+    suppress_update_message: false
+
+    # If this flag is true (default is false), instead of analyzing the code with Pigar, analyze with `pip freeze`
+    detect_with_pip_freeze: false
+
     # Development mode worker
     worker {
@@ -149,7 +171,11 @@
         ping_period_sec: 30
 
         # Log all stdout & stderr
-        log_stdout: True
+        log_stdout: true
 
+        # compatibility feature, report memory usage for the entire machine
+        # default (false), report only on the running process and its sub-processes
+        report_global_mem_used: false
     }
 }
 }
@@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
 ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
 ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
 ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", type=bool, default=True)
+ENV_CONDA_ENV_PACKAGE = EnvEntry("TRAINS_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
@@ -40,6 +40,7 @@ class Session(TokenManager):
     _session_requests = 0
     _session_initial_timeout = (3.0, 10.)
     _session_timeout = (10.0, 30.)
+    _session_initial_connect_retry = 4
     _write_session_data_size = 15000
     _write_session_timeout = (30.0, 30.)
@@ -96,7 +97,7 @@ class Session(TokenManager):
         else:
             self.config = load()
         if initialize_logging:
-            self.config.initialize_logging()
+            self.config.initialize_logging(debug=kwargs.get('debug', False))
 
         token_expiration_threshold_sec = self.config.get(
             "auth.token_expiration_threshold_sec", 60
@@ -134,7 +135,6 @@ class Session(TokenManager):
             "api.http.retries", ConfigTree()
         ).as_plain_ordered_dict()
         http_retries_config["status_forcelist"] = self._retry_codes
-        self.__http_session = get_http_session_with_retry(**http_retries_config)
 
         self.__worker = worker or gethostname()
@@ -144,7 +144,14 @@ class Session(TokenManager):
 
         self.client = client or "api-{}".format(__version__)
 
+        # limit the reconnect retries, so we get an error if we are starting the session
+        http_no_retries_config = dict(**http_retries_config)
+        http_no_retries_config['connect'] = self._session_initial_connect_retry
+        self.__http_session = get_http_session_with_retry(**http_no_retries_config)
+        # try to connect with the server
         self.refresh_token()
+        # create the default session with many retries
+        self.__http_session = get_http_session_with_retry(**http_retries_config)
 
         # update api version from server response
         try:
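The change above is a two-phase retry pattern: first contact the server with a small retry budget so a misconfigured address fails fast, then switch to a patient session for normal operation. A standalone sketch of the same idea using plain requests/urllib3 (not the project's `get_http_session_with_retry` helper):

```python
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def session_with_retry(connect_retries):
    # build a requests session whose transport retries failed connects
    session = requests.Session()
    retry = Retry(connect=connect_retries, backoff_factor=0.5,
                  status_forcelist=(500, 502, 503, 504))
    session.mount("http://", HTTPAdapter(max_retries=retry))
    session.mount("https://", HTTPAdapter(max_retries=retry))
    return session

probe = session_with_retry(connect_retries=4)    # fail fast on first contact
steady = session_with_retry(connect_retries=10)  # patient session afterwards
```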
@@ -427,16 +434,15 @@ class Session(TokenManager):
     @classmethod
     def get_api_server_host(cls, config=None):
         if not config:
-            from ...config import config_obj
-            config = config_obj
+            return None
 
         return ENV_HOST.get(default=(config.get("api.api_server", None) or
                                      config.get("api.host", None) or cls.default_host))
 
     @classmethod
     def get_app_server_host(cls, config=None):
         if not config:
-            from ...config import config_obj
-            config = config_obj
+            return None
 
         # get from config/environment
         web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
@@ -463,8 +469,8 @@ class Session(TokenManager):
     @classmethod
     def get_files_server_host(cls, config=None):
         if not config:
-            from ...config import config_obj
-            config = config_obj
+            return None
 
         # get from config/environment
         files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
         if files_host:
@@ -546,6 +552,9 @@ class Session(TokenManager):
         else:
             raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
                              "exception: {}".format(res, ex))
+    except requests.ConnectionError as ex:
+        raise ValueError('Connection Error: it seems *api_server* is misconfigured. '
+                         'Is this the TRAINS API server {} ?'.format('/'.join(ex.request.url.split('/')[:3])))
     except Exception as ex:
         raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
@@ -190,7 +190,7 @@ class Config(object):
     def reload(self):
         self.replace(self._reload())
 
-    def initialize_logging(self):
+    def initialize_logging(self, debug=False):
         logging_config = self._config.get("logging", None)
         if not logging_config:
             return False
@@ -217,6 +217,8 @@ class Config(object):
         )
         for logger in loggers:
             handlers = logger.get("handlers", None)
+            if debug:
+                logger['level'] = 'DEBUG'
             if not handlers:
                 continue
             logger["handlers"] = [h for h in handlers if h not in deleted]
@@ -46,6 +46,15 @@ class Environment(object):
     local = 'local'
 
 
+class UptimeConf(object):
+    min_api_version = "2.10"
+    queue_tag_on = "force_workers:on"
+    queue_tag_off = "force_workers:off"
+    worker_key = "force"
+    worker_value_off = ["off"]
+    worker_value_on = ["on"]
+
+
 CONFIG_FILE_EXTENSION = '.conf'
@@ -142,6 +142,7 @@ def main():
     with open(str(conf_file), 'wt') as f:
         header = '# TRAINS-AGENT configuration file\n' \
                  'api {\n' \
+                 '    # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
                  '    api_server: %s\n' \
                  '    web_server: %s\n' \
                  '    files_server: %s\n' \
@@ -12,12 +12,12 @@ import sys
 import shutil
 import traceback
 from collections import defaultdict
-from copy import copy, deepcopy
+from copy import deepcopy
 from datetime import datetime
 from distutils.spawn import find_executable
 from functools import partial
 from itertools import chain
-from tempfile import mkdtemp, gettempdir
+from tempfile import mkdtemp, NamedTemporaryFile
 from time import sleep, time
 from typing import Text, Optional, Any, Tuple
@@ -30,12 +30,11 @@ from pathlib2 import Path
 from pyhocon import ConfigTree, ConfigFactory
 from six.moves.urllib.parse import quote
 
+from trains_agent.backend_config.defs import UptimeConf
 from trains_agent.helper.check_update import start_check_update_daemon
 from trains_agent.commands.base import resolve_names, ServiceCommandSection
 from trains_agent.definitions import (
     WORKER_ALREADY_REGISTERED,
     ENVIRONMENT_SDK_PARAMS,
     INVALID_WORKER_ID,
     PROGRAM_NAME,
     DEFAULT_VENV_UPDATE_URL,
     ENV_TASK_EXECUTE_AS_USER,
@@ -66,12 +65,12 @@ from trains_agent.helper.base import (
     get_python_path,
     is_linux_platform,
     rm_file,
-    add_python_path)
+    add_python_path, safe_remove_tree, )
 from trains_agent.helper.console import ensure_text, print_text, decode_binary_lines
 from trains_agent.helper.os.daemonize import daemonize_process
 from trains_agent.helper.package.base import PackageManager
 from trains_agent.helper.package.conda_api import CondaAPI
 from trains_agent.helper.package.horovod_req import HorovodRequirement
 from trains_agent.helper.package.post_req import PostRequirement
 from trains_agent.helper.package.external_req import ExternalRequirements
 from trains_agent.helper.package.pip_api.system import SystemPip
 from trains_agent.helper.package.pip_api.venv import VirtualenvPip
@@ -89,21 +88,21 @@ from trains_agent.helper.process import (
     get_bash_output,
     shutdown_docker_process,
     get_docker_id,
-    commit_docker
+    commit_docker, terminate_process,
 )
 from trains_agent.helper.package.cython_req import CythonRequirement
 from trains_agent.helper.package.priority_req import PriorityPackageRequirement, PackageCollectorRequirement
 from trains_agent.helper.repo import clone_repository_cached, RepoInfo, VCS
 from trains_agent.helper.resource_monitor import ResourceMonitor
+from trains_agent.helper.runtime_verification import check_runtime, print_uptime_properties
 from trains_agent.session import Session
 from trains_agent.helper.singleton import Singleton
 
 from .events import Events
 
 log = logging.getLogger(__name__)
 
+DOCKER_ROOT_CONF_FILE = "/root/trains.conf"
+DOCKER_DEFAULT_CONF_FILE = "/root/default_trains.conf"
 
 
 @attr.s
 class LiteralScriptManager(object):
     """
@@ -122,7 +121,24 @@ class LiteralScriptManager(object):
         if not script:
             return False
         diff = script.diff
-        return diff and not diff.strip().lower().startswith("diff ")
+        if not diff:
+            return False
+
+        # test git diff prefix
+        if diff.lstrip().lower().startswith("diff "):
+            return False
+
+        # test git submodule prefix
+        # noinspection PyBroadException
+        try:
+            if diff.lstrip().lower().startswith("submodule ") and \
+                    diff.splitlines()[1].lstrip().lower().startswith("diff "):
+                return False
+        except Exception:
+            pass
+
+        # none of the above
+        return True
 
     @staticmethod
     def write(task, directory, entry_point=None):
@@ -151,10 +167,11 @@ class LiteralScriptManager(object):
         Create notebook file in appropriate location
         :return: directory and script path
         """
+        log = logging.getLogger(__name__)
         if repo_info and repo_info.root:
             location = Path(repo_info.root, execution.working_dir)
         else:
-            if execution.working_dir:
+            if execution.working_dir and execution.working_dir.strip() != '.':
                 log.warning(
                     "found task with `script.working_dir` (`%s`) but without `script.repository`, ignoring",
                     execution.working_dir,
@@ -212,6 +229,7 @@ class TaskStopSignal(object):
         statuses.stopped,
         statuses.failed,
         statuses.published,
+        statuses.queued,
     ]
     default = TaskStopReason.no_stop
     stopping_message = "stopping"
@@ -263,7 +281,7 @@ class TaskStopSignal(object):
             )
             return TaskStopReason.stopped
 
-        if status in self.unexpected_statuses:  ## and "worker" not in message:
+        if status in self.unexpected_statuses:  # ## and "worker" not in message:
             self.command.log("unexpected status change, task will terminate")
             return TaskStopReason.status_changed
 
@@ -301,9 +319,11 @@ class Worker(ServiceCommandSection):
 
     _requirement_substitutions = (
         PytorchRequirement,
         CythonRequirement,
         HorovodRequirement,
+        PriorityPackageRequirement,
+        PostRequirement,
         ExternalRequirements,
         partial(PackageCollectorRequirement, collect_package=['trains']),
+        partial(PackageCollectorRequirement, collect_package=['clearml']),
     )
 
     # poll queues every _polling_interval seconds
@@ -319,6 +339,7 @@ class Worker(ServiceCommandSection):
 
     _run_as_user_home = '/trains_agent_home'
     _docker_fixed_user_cache = '/trains_agent_cache'
+    _temp_cleanup_list = []
 
     @property
     def service(self):
@@ -332,6 +353,8 @@ class Worker(ServiceCommandSection):
     @staticmethod
     def register_signal_handler():
         def handler(*_):
+            for f in Worker._temp_cleanup_list + [Singleton.get_pid_file()]:
+                safe_remove_tree(f)
             raise Sigterm()
 
         signal.signal(signal.SIGTERM, handler)
@@ -372,7 +395,7 @@ class Worker(ServiceCommandSection):
         self.temp_config_path = None
         self.queues = ()
         self.venv_folder = None  # type: Optional[Text]
-        self.package_api = None  # type: PackageManager
+        self.package_api = None  # type: Optional[PackageManager]
         self.global_package_api = None
 
         self.is_venv_update = self._session.config.agent.venv_update.enabled
@@ -388,6 +411,15 @@ class Worker(ServiceCommandSection):
         self._standalone_mode = None
         self._services_mode = None
         self._force_current_version = None
+        self._redirected_stdout_file_no = None
+        self._uptime_config = self._session.config.get("agent.uptime", None)
+        self._downtime_config = self._session.config.get("agent.downtime", None)
+        self._suppress_cr = self._session.config.get("agent.suppress_carriage_return", True)
+
+        # True - supported
+        # None - not initialized
+        # str - not supported, version string indicates last server version
+        self._runtime_props_support = None
 
     @classmethod
     def _verify_command_states(cls, kwargs):
@@ -433,7 +465,7 @@ class Worker(ServiceCommandSection):
         pass
 
     def run_one_task(self, queue, task_id, worker_args, docker=None):
-        # type: (Text, Text, WorkerParams) -> ()
+        # type: (Text, Text, WorkerParams, Optional[Text]) -> ()
         """
         Run one task pulled from queue.
         :param queue: ID of queue that task was pulled from
@@ -497,7 +529,7 @@ class Worker(ServiceCommandSection):
             full_docker_cmd = self.docker_image_func(docker_image=docker_image, docker_arguments=docker_arguments)
             try:
                 self._session.send_api(
-                    tasks_api.EditRequest(task_id, force=True, execution=dict(
+                    tasks_api.EditRequest(task_id, force=True, execution=dict(  # noqa
                         docker_cmd=' '.join([docker_image] + docker_arguments) if docker_arguments else docker_image)))
             except Exception:
                 pass
@@ -509,6 +541,10 @@ class Worker(ServiceCommandSection):
                 '--full-monitoring' if self._services_mode else '--disable-monitoring',
                 '--standalone-mode' if self._standalone_mode else '',
                 task_id)
 
+            # send the actual used command line to the backend
+            self.send_logs(task_id=task_id, lines=['Executing: {}\n'.format(full_docker_cmd)], level="INFO")
+
             cmd = Argv(*full_docker_cmd)
             print('Running Docker:\n{}\n'.format(str(cmd)))
         else:
@@ -569,12 +605,12 @@ class Worker(ServiceCommandSection):
             else:
                 self.handle_task_termination(task_id, status, stop_signal_status)
             # remove temp files after we sent everything to the backend
-            safe_remove_file(temp_stdout_name)
-            safe_remove_file(temp_stderr_name)
             if self.docker_image_func:
                 shutdown_docker_process(docker_cmd_contains='--id {}\'\"'.format(task_id))
+            safe_remove_file(temp_stdout_name)
+            safe_remove_file(temp_stderr_name)
 
-    def run_tasks_loop(self, queues, worker_params):
+    def run_tasks_loop(self, queues, worker_params, priority_order=True):
         """
         :summary: Pull and run tasks from queues.
         :description: 1. Go through ``queues`` by order.
@@ -584,16 +620,28 @@ class Worker(ServiceCommandSection):
         :type queues: list of ``Text``
         :param worker_params: Worker command line arguments
         :type worker_params: ``trains_agent.helper.process.WorkerParams``
+        :param priority_order: If True, pull in priority order, always starting from the first queue.
+            If False, pull from each queue once, in a round robin manner
+        :type priority_order: bool
         """
 
         if not self._daemon_foreground:
             print('Starting infinite task polling loop...')
 
         _last_machine_update_ts = 0
         while True:
+            queue_tags = None
+            runtime_props = None
             # iterate over queues (priority style, queues[0] is highest)
             for queue in queues:
+
+                if queue_tags is None or runtime_props is None:
+                    queue_tags, runtime_props = self.get_worker_properties(queues)
+
+                if not self.should_be_currently_active(queue_tags[queue], runtime_props):
+                    continue
+
                 # get next task in queue
                 try:
                     response = self._session.send_api(
@@ -614,6 +662,16 @@ class Worker(ServiceCommandSection):
                     print("No tasks in queue {}".format(queue))
                     continue
 
+                # clear output log if we start a new Task
+                if not worker_params.debug and self._redirected_stdout_file_no is not None and \
+                        self._redirected_stdout_file_no > 2:
+                    # noinspection PyBroadException
+                    try:
+                        os.lseek(self._redirected_stdout_file_no, 0, 0)
+                        os.ftruncate(self._redirected_stdout_file_no, 0)
+                    except:
+                        pass
+
                 self.send_logs(
                     task_id=task_id,
                     lines=["task {} pulled from {} by worker {}\n".format(task_id, queue, self.worker_id)],
@@ -622,7 +680,14 @@ class Worker(ServiceCommandSection):
                 self.report_monitor(ResourceMonitor.StatusReport(queues=queues, queue=queue, task=task_id))
                 self.run_one_task(queue, task_id, worker_params)
                 self.report_monitor(ResourceMonitor.StatusReport(queues=self.queues))
-                break
+
+                queue_tags = None
+                runtime_props = None
+
+                # if we are using priority start pulling from the first always,
+                # if we are doing round robin, pull from the next one
+                if priority_order:
+                    break
             else:
                 # sleep and retry polling
                 if self._daemon_foreground or worker_params.debug:
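The `priority_order` flag selects between two classic scheduling shapes; a standalone sketch of both (simplified, not the agent's actual loop, and `get_task` is a stand-in):

```python
import time

queues = ["high", "medium", "low"]

def poll(get_task, priority_order=True):
    while True:
        pulled = False
        for q in queues:
            if get_task(q):
                pulled = True
                if priority_order:
                    break  # restart from the highest-priority queue
        if not pulled:
            time.sleep(5)  # nothing anywhere, back off before retrying
```

With `priority_order=True` the first queue is always drained first; with `False` each queue gets one pull attempt per cycle (round robin).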
@@ -632,6 +697,77 @@ class Worker(ServiceCommandSection):
             if self._session.config["agent.reload_config"]:
                 self.reload_config()
 
+    def get_worker_properties(self, queue_ids):
+        queue_tags = {
+            q.id: {'name': q.name, 'tags': q.tags}
+            for q in self._session.send_api(
+                queues_api.GetAllRequest(id=queue_ids, only_fields=["id", "tags"])
+            ).queues
+        }
+        runtime_props = self.get_runtime_properties()
+        return queue_tags, runtime_props
+
+    def get_runtime_properties(self):
+        if self._runtime_props_support is not True:
+            # either not supported or never tested
+            if self._runtime_props_support == self._session.api_version:
+                # tested against latest api_version, not supported
+                return []
+            if not self._session.check_min_api_version(UptimeConf.min_api_version):
+                # not supported due to insufficient api_version
+                self._runtime_props_support = self._session.api_version
+                return []
+        try:
+            res = self.get("get_runtime_properties", worker=self.worker_id)["runtime_properties"]
+            # definitely supported
+            self._runtime_props_support = True
+            return res
+        except APIError:
+            self._runtime_props_support = self._session.api_version
+            return []
+
+    def should_be_currently_active(self, current_queue, runtime_properties):
+        """
+        Checks if a worker is active according to queue tags, worker's runtime properties and uptime schedule.
+        """
+        if UptimeConf.queue_tag_off in current_queue['tags']:
+            self.log.debug("Queue {} is tagged '{}', worker will not pull tasks".format(
+                current_queue['name'], UptimeConf.queue_tag_off)
+            )
+            return False
+        if UptimeConf.queue_tag_on in current_queue['tags']:
+            self.log.debug("Queue {} is tagged '{}', worker will pull tasks".format(
+                current_queue['name'], UptimeConf.queue_tag_on)
+            )
+            return True
+        force_flag = next(
+            (prop for prop in runtime_properties if prop["key"] == UptimeConf.worker_key), None
+        )
+        if force_flag:
+            if force_flag["value"].lower() in UptimeConf.worker_value_off:
+                self.log.debug("worker has the following runtime property: '{}'. worker will not pull tasks".format(
+                    force_flag)
+                )
+                return False
+            elif force_flag["value"].lower() in UptimeConf.worker_value_on:
+                self.log.debug("worker has the following runtime property: '{}'. worker will pull tasks".format(
+                    force_flag)
+                )
+                return True
+            else:
+                print(
+                    "Warning: invalid runtime_property '{}: {}' supported values are: '{}/{}', ignoring".format(
+                        force_flag["key"], force_flag["value"], UptimeConf.worker_value_on, UptimeConf.worker_value_off
+                    )
+                )
+        if self._uptime_config:
+            self.log.debug("following uptime configurations")
+            return check_runtime(self._uptime_config)
+        if self._downtime_config:
+            self.log.debug("following downtime configurations")
+            return check_runtime(self._downtime_config, is_uptime=False)
+        return True
+
     def reload_config(self):
         try:
             reloaded = self._session.reload()
@@ -649,7 +785,7 @@ class Worker(ServiceCommandSection):
 
     def check(self, **_):
         try:
-            check_directory_path(str(Path(".").resolve()))
+            check_directory_path(str(Path(".").resolve()), check_whitespace_in_path=False)
         except OSError as e:
             if e.errno == errno.ENOENT:
                 raise CommandFailedError("current working directory does not exist")
@@ -674,7 +810,7 @@ class Worker(ServiceCommandSection):
 
         self._session.print_configuration()
 
-    def daemon(self, queues, log_level, foreground=False, docker=False, detached=False, **kwargs):
+    def daemon(self, queues, log_level, foreground=False, docker=False, detached=False, order_fairness=False, **kwargs):
         # if we do not need to create queues, make sure they are valid
         # match previous behaviour when we validated queue names before everything else
         queues = self._resolve_queue_names(queues, create_if_missing=kwargs.get('create_queue', False))
@@ -685,6 +821,34 @@ class Worker(ServiceCommandSection):
         if self._services_mode:
             kwargs = self._verify_command_states(kwargs)
         docker = docker or kwargs.get('docker')
+        self._uptime_config = kwargs.get('uptime', None) or self._uptime_config
+        self._downtime_config = kwargs.get('downtime', None) or self._downtime_config
+        if self._uptime_config and self._downtime_config:
+            self.log.error(
+                "Both uptime and downtime were specified when only one of them could be used. Both will be ignored."
+            )
+            self._uptime_config = None
+            self._downtime_config = None
+
+        # We are not running a daemon, we are killing one.
+        # find the pid, send termination signal and leave
+        if kwargs.get('stop', False):
+            return 1 if not self._kill_daemon() else 0
+
+        queues_info = [
+            q.to_dict()
+            for q in self._session.send_api(
+                queues_api.GetAllRequest(id=queues)
+            ).queues
+        ]
+
+        if kwargs.get('status', False):
+            runtime_properties = self.get_runtime_properties()
+            if self._downtime_config:
+                print_uptime_properties(self._downtime_config, queues_info, runtime_properties, is_uptime=False)
+            else:
+                print_uptime_properties(self._uptime_config, queues_info, runtime_properties)
+            return 1
+
         # make sure we only have a single instance,
         # also make sure we set worker_id properly and cache folders
@@ -697,12 +861,6 @@ class Worker(ServiceCommandSection):
         self.log.debug("starting resource monitor thread")
         print("Worker \"{}\" - ".format(self.worker_id), end='')
 
-        queues_info = [
-            self._session.send_api(
-                queues_api.GetByIdRequest(queue)
-            ).queue.to_dict()
-            for queue in queues
-        ]
         columns = ("id", "name", "tags")
         print("Listening to queues:")
         print_table(queues_info, columns=columns, titles=columns)
@@ -711,9 +869,8 @@ class Worker(ServiceCommandSection):
         self._register(queues)
 
         # create temp config file with current configuration
-        self.temp_config_path = safe_mkstemp(
-            suffix=".cfg", prefix=".trains_agent.", text=True, name_only=True
-        )
+        self.temp_config_path = NamedTemporaryFile(
+            suffix=".cfg", prefix=".trains_agent.", mode='w+t').name
 
         # print docker image
         if docker is not False and docker is not None:
@@ -753,6 +910,9 @@ class Worker(ServiceCommandSection):
             )
         )
 
+        if not self._session.debug_mode:
+            self._temp_cleanup_list.append(name)
+
         if not detached:
             # redirect std out/err to new file
             sys.stdout = sys.stderr = out_file
@@ -760,6 +920,7 @@ class Worker(ServiceCommandSection):
             # in detached mode
             # fully detach stdin.stdout/stderr and leave main process, running in the background
             daemonize_process(out_file.fileno())
+            self._redirected_stdout_file_no = out_file.fileno()
             # make sure we update the singleton lock file to the new pid
             Singleton.update_pid_file()
             # reprint headers to std file (we are now inside the daemon process)
@@ -779,6 +940,7 @@ class Worker(ServiceCommandSection):
                     debug=self._session.debug_mode,
                     trace=self._session.trace,
                 ),
+                priority_order=not order_fairness,
             )
         except Exception:
             tb = six.text_type(traceback.format_exc())
@@ -842,22 +1004,27 @@ class Worker(ServiceCommandSection):
             stop_signal=None,  # type: Optional[TaskStopSignal]
             **kwargs  # type: Any
     ):
-        # type: (...) -> Tuple[Optional[int], TaskStopReason]
-        def _print_file(file_path, prev_line_count):
+        # type: (...) -> Tuple[Optional[int], Optional[TaskStopReason]]
+        def _print_file(file_path, prev_pos=0):
             with open(file_path, "rb") as f:
+                f.seek(prev_pos)
                 binary_text = f.read()
-            if not binary_text:
-                return []
+                pos = f.tell()
             # skip the previously printed lines,
-            blines = binary_text.split(b'\n')[prev_line_count:]
+            blines = binary_text.split(b'\n') if binary_text else []
             if not blines:
-                return blines
-            return decode_binary_lines(blines if blines[-1] else blines[:-1])
+                return blines, pos
+            return (
+                decode_binary_lines(blines if blines[-1] else blines[:-1],
+                                    replace_cr=not self._suppress_cr,
+                                    overwrite_cr=self._suppress_cr),
+                pos
+            )
 
         stdout = open(stdout_path, "wt")
         stderr = open(stderr_path, "wt") if stderr_path else stdout
-        stdout_line_count, stdout_last_lines = 0, []
-        stderr_line_count, stderr_last_lines = 0, []
+        stdout_line_count, stdout_pos_count, stdout_last_lines = 0, 0, []
+        stderr_line_count, stderr_pos_count, stderr_last_lines = 0, 0, []
         service_mode_internal_agent_started = None
         stopping = False
         status = None
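The rewritten `_print_file` tracks a byte offset instead of a line count, so each poll reads only the bytes appended since the previous poll. A standalone sketch of the same tailing pattern (names are illustrative, not the agent's API):

```python
def read_new_lines(file_path, prev_pos=0):
    # return (decoded_lines, new_position) for content appended since prev_pos
    with open(file_path, "rb") as f:
        f.seek(prev_pos)
        binary_text = f.read()
        pos = f.tell()
    lines = binary_text.split(b"\n") if binary_text else []
    if not lines:
        return [], pos
    # a trailing empty element means the text ended on a newline; drop it
    if not lines[-1]:
        lines = lines[:-1]
    return [line.decode("utf-8", errors="replace") for line in lines], pos

# usage: poll in a loop, carrying the offset forward
# pos = 0
# while running:
#     new_lines, pos = read_new_lines("stdout.txt", pos)
```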
@@ -896,7 +1063,7 @@ class Worker(ServiceCommandSection):
                 stderr.flush()
 
                 # get diff from previous poll
-                printed_lines = _print_file(stdout_path, stdout_line_count)
+                printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
                 if self._services_mode and not stopping and not status:
                     # if the internal agent started, we stop logging, it will take over logging.
                     # if the internal agent started running the task itself, it will return status==0,
@@ -906,13 +1073,10 @@ class Worker(ServiceCommandSection):
                     if status is not None:
                         stop_reason = 'Service started'
 
-                stdout_line_count += self.send_logs(
-                    task_id, printed_lines
-                )
+                stdout_line_count += self.send_logs(task_id, printed_lines)
                 if stderr_path:
-                    stderr_line_count += self.send_logs(
-                        task_id, _print_file(stderr_path, stderr_line_count)
-                    )
+                    printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
+                    stderr_line_count += self.send_logs(task_id, printed_lines)
 
             except subprocess.CalledProcessError as ex:
                 # non zero return code
except subprocess.CalledProcessError as ex:
                 raise
             except Exception:
                 # we should not get here, but better safe than sorry
-                stdout_line_count += self.send_logs(task_id, _print_file(stdout_path, stdout_line_count))
+                printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
+                stdout_line_count += self.send_logs(task_id, printed_lines)
                 if stderr_path:
-                    stderr_line_count += self.send_logs(task_id, _print_file(stderr_path, stderr_line_count))
+                    printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
+                    stderr_line_count += self.send_logs(task_id, printed_lines)
                 stop_reason = 'Exception occurred'
                 status = -1
@@ -939,13 +1105,11 @@ class Worker(ServiceCommandSection):
         stderr.close()
 
         # Send last lines
-        stdout_line_count += self.send_logs(
-            task_id, _print_file(stdout_path, stdout_line_count)
-        )
+        printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
+        stdout_line_count += self.send_logs(task_id, printed_lines)
         if stderr_path:
-            stderr_line_count += self.send_logs(
-                task_id, _print_file(stderr_path, stderr_line_count)
-            )
+            printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
+            stderr_line_count += self.send_logs(task_id, printed_lines)
 
         return status, stop_reason
@@ -1024,7 +1188,8 @@ class Worker(ServiceCommandSection):
             success = False
 
         if not success:
-            raise ValueError("Failed applying git diff:\n{}\n\nERROR! Failed applying git diff, see diff above.".format(diff))
+            raise ValueError("Failed applying git diff:\n{}\n\n"
+                             "ERROR! Failed applying git diff, see diff above.".format(diff))
 
     @resolve_names
     def build(
@@ -1049,10 +1214,15 @@ class Worker(ServiceCommandSection):
 
         execution = self.get_execution_info(current_task)
 
-        try:
-            requirements = current_task.script.requirements
-        except AttributeError:
-            requirements = None
+        if self._session.config.get("agent.package_manager.force_repo_requirements_txt", False):
+            requirements = None
+            print("[package_manager.force_repo_requirements_txt=true] "
+                  "Skipping requirements, using repository \"requirements.txt\" ")
+        else:
+            try:
+                requirements = current_task.script.requirements
+            except AttributeError:
+                requirements = None
 
         if not python_version:
             try:
@@ -1063,8 +1233,8 @@ class Worker(ServiceCommandSection):
         except:
             python_version = None
 
-        venv_folder, requirements_manager = self.install_virtualenv(venv_dir=target,
-                                                                    requested_python_version=python_version)
+        venv_folder, requirements_manager = self.install_virtualenv(
+            venv_dir=target, requested_python_version=python_version, execution_info=execution)
 
         if self._default_pip:
             if install_globally and self.global_package_api:
@@ -1219,7 +1389,8 @@ class Worker(ServiceCommandSection):
         print("Cloning task id={}".format(task_id))
         current_task = self._session.api_client.tasks.get_by_id(
             self._session.send_api(
-                tasks_api.CloneRequest(task=current_task.id, new_task_name='Clone of {}'.format(current_task.name))
+                tasks_api.CloneRequest(task=current_task.id,
+                                       new_task_name='Clone of {}'.format(current_task.name))
             ).id
         )
         print("Task cloned, new task id={}".format(current_task.id))
@@ -1269,10 +1440,15 @@ class Worker(ServiceCommandSection):
 
         execution = self.get_execution_info(current_task)
 
-        try:
-            requirements = current_task.script.requirements
-        except AttributeError:
-            requirements = None
+        if self._session.config.get("agent.package_manager.force_repo_requirements_txt", False):
+            requirements = None
+            print("[package_manager.force_repo_requirements_txt=true] "
+                  "Skipping requirements, using repository \"requirements.txt\" ")
+        else:
+            try:
+                requirements = current_task.script.requirements
+            except AttributeError:
+                requirements = None
 
         try:
             python_ver = current_task.script.binary
@@ -1282,8 +1458,8 @@ class Worker(ServiceCommandSection):
         except:
             python_ver = None
 
-        venv_folder, requirements_manager = self.install_virtualenv(standalone_mode=standalone_mode,
-                                                                    requested_python_version=python_ver)
+        venv_folder, requirements_manager = self.install_virtualenv(
+            standalone_mode=standalone_mode, requested_python_version=python_ver, execution_info=execution)
 
         if not standalone_mode:
             if self._default_pip:
@@ -1308,7 +1484,10 @@ class Worker(ServiceCommandSection):
 
         # do not update the task packages if we are using conda,
         # it will most likely make the task environment unreproducible
-        freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None,
+        skip_freeze_update = self.is_conda and not self._session.config.get(
+            "agent.package_manager.conda_full_env_update", False)
+
+        freeze = self.freeze_task_environment(current_task.id if not skip_freeze_update else None,
                                               requirements_manager=requirements_manager)
         script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix()
@@ -1362,8 +1541,7 @@ class Worker(ServiceCommandSection):
         self._update_commit_id(current_task.id, execution, repo_info)
 
         # Add the script CWD to the python path
-        python_path = get_python_path(script_dir, execution.entry_point, self.package_api) \
-            if not self.is_conda else None
+        python_path = get_python_path(script_dir, execution.entry_point, self.package_api, is_conda_env=self.is_conda)
         if os.environ.get(ENV_TASK_EXTRA_PYTHON_PATH):
             python_path = add_python_path(python_path, os.environ.get(ENV_TASK_EXTRA_PYTHON_PATH))
         if python_path:
@@ -1661,7 +1839,7 @@ class Worker(ServiceCommandSection):
     def install_requirements_for_package_api(
         self, execution, repo_info, requirements_manager, cached_requirements=None, cwd=None, package_api=None,
     ):
-        # type: (ExecutionInfo, RepoInfo, RequirementsManager, Optional[dict]) -> None
+        # type: (ExecutionInfo, RepoInfo, RequirementsManager, Optional[dict], Optional[str], Optional[Any]) -> None
         """
         :summary: Install requirements for task script using pip.
         :description: A file named "requirements.txt" is looked for in each containing folder between the
@@ -1671,6 +1849,7 @@ class Worker(ServiceCommandSection):
         :param repo_info: repository information
         :param requirements_manager: requirements manager for task
         :param cached_requirements: cached requirements from previous run
         :param cwd: current folder
+        :param package_api: package_api to be used when installing requirements
         """
         if package_api:
@@ -1684,12 +1863,13 @@ class Worker(ServiceCommandSection):
         package_api.set_selected_package_manager()
         # always install cython,
         # if we have a specific version in the requirements,
-        # the CythonRequirement(SimpleSubstitution) will reinstall cython with the specific version
+        # the PriorityPackageRequirement(SimpleSubstitution) will reinstall cython with the specific version
         if not self.is_conda:
             package_api.out_of_scope_install_package('Cython')
 
         cached_requirements_failed = False
-        if cached_requirements and ('pip' in cached_requirements or 'conda' in cached_requirements):
+        if cached_requirements and (cached_requirements.get('pip') is not None or
+                                    cached_requirements.get('conda') is not None):
             self.log("Found task requirements section, trying to install")
             try:
                 package_api.load_requirements(cached_requirements)
@@ -1859,8 +2039,9 @@ class Worker(ServiceCommandSection):
             )
         )
 
-    def install_virtualenv(self, venv_dir=None, requested_python_version=None, standalone_mode=False):
-        # type: (str, str, bool) -> Tuple[Path, RequirementsManager]
+    def install_virtualenv(
+            self, venv_dir=None, requested_python_version=None, standalone_mode=False, execution_info=None):
+        # type: (str, str, bool, ExecutionInfo) -> Tuple[Path, RequirementsManager]
         """
         Install a new python virtual environment, removing the old one if exists
         :return: virtualenv directory and requirements manager to use with task
@@ -1907,6 +2088,7 @@ class Worker(ServiceCommandSection):
             python=executable_version_suffix if self.is_conda else executable_name,
             path=venv_dir,
             requirements_manager=requirements_manager,
+            execution_info=execution_info,
         )
 
         global_package_manager_params = dict(
@@ -2005,26 +2187,33 @@ class Worker(ServiceCommandSection):
         mounted_pip_dl_dir = '/root/.trains/pip-download-cache'
         mounted_vcs_cache = '/root/.trains/vcs-cache'
         mounted_venv_dir = '/root/.trains/venvs-builds'
-        host_cache = Path(os.path.expandvars(self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
-        host_pip_dl = Path(os.path.expandvars(self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
-        host_vcs_cache = Path(os.path.expandvars(self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
+        host_cache = Path(os.path.expandvars(
+            self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
+        host_pip_dl = Path(os.path.expandvars(
+            self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
+        host_vcs_cache = Path(os.path.expandvars(
+            self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
         temp_config.put("sdk.storage.cache.default_base_dir", mounted_cache_dir)
         temp_config.put("agent.pip_download_cache.path", mounted_pip_dl_dir)
         temp_config.put("agent.vcs_cache.path", mounted_vcs_cache)
         temp_config.put("agent.package_manager.system_site_packages", True)
+        temp_config.put("agent.package_manager.conda_env_as_base_docker", False)
         temp_config.put("agent.default_python", "")
         temp_config.put("agent.python_binary", "")
         temp_config.put("agent.cuda_version", "")
         temp_config.put("agent.cudnn_version", "")
         temp_config.put("agent.venvs_dir", mounted_venv_dir)
-        temp_config.put("agent.git_user", (ENV_AGENT_GIT_USER.get() or self._session.config.get("agent.git_user", None)))
-        temp_config.put("agent.git_pass", (ENV_AGENT_GIT_PASS.get() or self._session.config.get("agent.git_pass", None)))
+        temp_config.put("agent.git_user", (ENV_AGENT_GIT_USER.get() or
+                                           self._session.config.get("agent.git_user", None)))
+        temp_config.put("agent.git_pass", (ENV_AGENT_GIT_PASS.get() or
+                                           self._session.config.get("agent.git_pass", None)))
 
         host_apt_cache = Path(os.path.expandvars(self._session.config.get(
             "agent.docker_apt_cache", '~/.trains/apt-cache'))).expanduser().as_posix()
         host_pip_cache = Path(os.path.expandvars(self._session.config.get(
             "agent.docker_pip_cache", '~/.trains/pip-cache'))).expanduser().as_posix()
         host_ssh_cache = mkdtemp(prefix='trains_agent.ssh.')
+        self._temp_cleanup_list.append(host_ssh_cache)
 
         # make sure all folders are valid
         Path(host_apt_cache).mkdir(parents=True, exist_ok=True)
@@ -2042,7 +2231,7 @@ class Worker(ServiceCommandSection):
|
||||
shutil.copytree(Path('~/.ssh').expanduser().as_posix(), host_ssh_cache)
|
||||
except Exception:
|
||||
host_ssh_cache = None
|
||||
log.warning('Failed creating temporary copy of ~/.ssh for git credential')
|
||||
self.log.warning('Failed creating temporary copy of ~/.ssh for git credential')
|
||||
pass
|
||||
|
||||
# check if the .git credentials exist:
|
||||
@@ -2065,6 +2254,7 @@ class Worker(ServiceCommandSection):
|
||||
extra_shell_script_str = " ; ".join(map(str, cmds)) + " ; "
|
||||
|
||||
bash_script = self._session.config.get("agent.docker_init_bash_script", None)
|
||||
preprocess_bash_script = self._session.config.get("agent.docker_preprocess_bash_script", None)
|
||||
|
||||
self.temp_config_path = self.temp_config_path or safe_mkstemp(
|
||||
suffix=".cfg", prefix=".trains_agent.", text=True, name_only=True
|
||||
@@ -2085,6 +2275,7 @@ class Worker(ServiceCommandSection):
|
||||
standalone_mode=self._standalone_mode,
|
||||
force_current_version=self._force_current_version,
|
||||
bash_script=bash_script,
|
||||
preprocess_bash_script=preprocess_bash_script,
|
||||
)
|
||||
return temp_config, partial(docker_cmd_functor, docker_cmd, temp_config)
|
||||
|
||||
@@ -2098,7 +2289,8 @@ class Worker(ServiceCommandSection):
|
||||
host_pip_dl, mounted_pip_dl,
|
||||
host_vcs_cache, mounted_vcs_cache,
|
||||
standalone_mode=False, extra_docker_arguments=None, extra_shell_script=None,
|
||||
force_current_version=None, host_git_credentials=None, bash_script=None):
|
||||
force_current_version=None, host_git_credentials=None,
|
||||
bash_script=None, preprocess_bash_script=None):
|
||||
docker = 'docker'
|
||||
|
||||
base_cmd = [docker, 'run', '-t']
|
||||
@@ -2115,7 +2307,7 @@ class Worker(ServiceCommandSection):
|
||||
if os.environ.get('TRAINS_DOCKER_SKIP_GPUS_FLAG', None):
|
||||
dockers_nvidia_visible_devices = gpu_devices
|
||||
else:
|
||||
base_cmd += ['--gpus', 'device='+gpu_devices, ]
|
||||
base_cmd += ['--gpus', '\"device={}\"'.format(gpu_devices), ]
|
||||
# We are using --gpu, so we should not pass NVIDIA_VISIBLE_DEVICES, I think.
|
||||
# base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES=' + gpu_devices, ]
|
||||
elif gpu_devices.strip() == 'none':
|
||||
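The quoting change above matters once the command list is flattened into a single shell string; a minimal sketch (assuming `gpu_devices = "0,1"`, which is not taken from this diff) of the flag before and after:

```python
# Minimal sketch of the --gpus change above (assumed input: gpu_devices = "0,1").
gpu_devices = "0,1"

old_flag = ['--gpus', 'device=' + gpu_devices]            # --gpus device=0,1
new_flag = ['--gpus', '\"device={}\"'.format(gpu_devices)]  # --gpus "device=0,1"

# When the command list is later joined into one shell string, the quotes keep
# the comma-separated device list attached to the --gpus argument.
print(' '.join(['docker', 'run', '-t'] + new_flag))
```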
@@ -2132,7 +2324,8 @@ class Worker(ServiceCommandSection):
base_cmd += [str(a) for a in extra_docker_arguments if a]

# check if running inside a kubernetes
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and os.environ.get('KUBERNETES_PORT')):
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and
os.environ.get('KUBERNETES_PORT')):
# map network to sibling docker, unless we have other network argument
if not any(a.strip().startswith('--network') for a in base_cmd):
try:
@@ -2154,7 +2347,7 @@ class Worker(ServiceCommandSection):
print('Warning: K8S mount missing, ignoring cached folder {}'.format(m))
host_mounts[i] = None
else:
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt)
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt, 1)
host_apt_cache, host_pip_cache, host_pip_dl, host_cache, host_vcs_cache = host_mounts

# copy the configuration file into the mounted folder
@@ -2177,6 +2370,8 @@ class Worker(ServiceCommandSection):
raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))

base_cmd += ['-e', 'TRAINS_WORKER_ID='+worker_id, ]
# update the docker image, so the system knows where it runs
base_cmd += ['-e', 'TRAINS_DOCKER_IMAGE={} {}'.format(docker_image, ' '.join(docker_arguments)).strip()]

# if we are running a RC version, install the same version in the docker
# because the default latest, will be a release version (not RC)
@@ -2200,21 +2395,33 @@ class Worker(ServiceCommandSection):

if not standalone_mode:
if not bash_script:
# Find the highest python version installed, or install from apt-get
# python+pip is the requirement to match
bash_script = [
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
"chown -R root /root/.cache/pip",
"export DEBIAN_FRONTEND=noninteractive",
"apt-get update",
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
"(which {python_single_digit} && {python_single_digit} -m pip --version) || " +
"apt-get install -y {python_single_digit}-pip",
"declare LOCAL_PYTHON",
"for i in {{10..5}}; do which {python_single_digit}.$i && " +
"{python_single_digit}.$i -m pip --version && " +
"export LOCAL_PYTHON=$(which {python_single_digit}.$i) && break ; done",
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y {python_single_digit}-pip",
]

if preprocess_bash_script:
bash_script = preprocess_bash_script + bash_script

docker_bash_script = " ; ".join(bash_script) if not isinstance(bash_script, str) else bash_script

# make sure that if we do not have $LOCAL_PYTHON defined
# we set it to python3
update_scheme += (
docker_bash_script + " ; " +
"{python} -m pip install -U \"pip{pip_version}\" ; " +
"{python} -m pip install -U {trains_agent_wheel} ; ").format(
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON={python} ; " +
"$LOCAL_PYTHON -m pip install -U \"pip{pip_version}\" ; " +
"$LOCAL_PYTHON -m pip install -U {trains_agent_wheel} ; ").format(
python_single_digit=python_version.split('.')[0],
python=python_version, pip_version=PackageManager.get_pip_version(),
trains_agent_wheel=trains_agent_wheel)
@@ -2235,7 +2442,7 @@ class Worker(ServiceCommandSection):
update_scheme +
extra_shell_script +
"cp {} {} ; ".format(DOCKER_ROOT_CONF_FILE, DOCKER_DEFAULT_CONF_FILE) +
"NVIDIA_VISIBLE_DEVICES={nv_visible} {python} -u -m trains_agent ".format(
"NVIDIA_VISIBLE_DEVICES={nv_visible} $LOCAL_PYTHON -u -m trains_agent ".format(
nv_visible=dockers_nvidia_visible_devices, python=python_version)
])

@@ -2266,7 +2473,8 @@ class Worker(ServiceCommandSection):
os.setuid(self.uid)

# create a home folder for our user
trains_agent_home = self._run_as_user_home + '{}'.format('.'+str(Singleton.get_slot()) if Singleton.get_slot() else '')
trains_agent_home = self._run_as_user_home + '{}'.format(
'.'+str(Singleton.get_slot()) if Singleton.get_slot() else '')
try:
home_folder = self._run_as_user_home
rm_tree(home_folder)
@@ -2322,18 +2530,26 @@ class Worker(ServiceCommandSection):

return command, script_dir

def _kill_daemon(self):
worker_id, worker_name = self._generate_worker_id_name()
# Iterate over all running processes
for pid, uid, slot, file in sorted(Singleton.get_running_pids(), key=lambda x: x[1] or ''):
# either we have a match for the worker_id or we just pick the first one
if pid >= 0 and uid is not None and (
(worker_id and uid == worker_id) or
(not worker_id and uid.startswith('{}:'.format(worker_name)))):
# this is us, kill it
print('Terminating trains-agent worker_id={} pid={}'.format(uid, pid))
if not terminate_process(pid, timeout=10):
error('Could not terminate process pid={}'.format(pid))
return True
print('Could not find a running trains-agent instance with worker_name={} worker_id={}'.format(
worker_name, worker_id))
return False

def _singleton(self):
# ensure singleton
worker_id = self._session.config["agent.worker_id"]
worker_name = self._session.config["agent.worker_name"]
if not worker_id and os.environ.get('NVIDIA_VISIBLE_DEVICES') is not None:
nvidia_visible_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES')
if nvidia_visible_devices and nvidia_visible_devices.lower() != 'none':
worker_id = '{}:gpu{}'.format(worker_name, nvidia_visible_devices)
elif nvidia_visible_devices == '':
pass
else:
worker_name = '{}:cpu'.format(worker_name)
worker_id, worker_name = self._generate_worker_id_name()

# if we are running in services mode, we allow double register since
# docker-compose will kill instances before they cleanup
@@ -2348,6 +2564,19 @@ class Worker(ServiceCommandSection):
# update folders based on free slot
self._session.create_cache_folders(slot_index=worker_slot)

def _generate_worker_id_name(self):
worker_id = self._session.config["agent.worker_id"]
worker_name = self._session.config["agent.worker_name"]
if not worker_id and os.environ.get('NVIDIA_VISIBLE_DEVICES') is not None:
nvidia_visible_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES')
if nvidia_visible_devices and nvidia_visible_devices.lower() != 'none':
worker_id = '{}:gpu{}'.format(worker_name, nvidia_visible_devices)
elif nvidia_visible_devices == '':
pass
else:
worker_name = '{}:cpu'.format(worker_name)
return worker_id, worker_name

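The worker-id rule above is easiest to see in isolation; a hedged standalone sketch, with the session config lookups replaced by plain arguments:

```python
import os

def generate_worker_id_name(worker_id, worker_name):
    # Standalone sketch of the naming rule in _generate_worker_id_name above;
    # worker_id/worker_name stand in for the agent.worker_* config values.
    nvidia_visible_devices = os.environ.get('NVIDIA_VISIBLE_DEVICES')
    if not worker_id and nvidia_visible_devices is not None:
        if nvidia_visible_devices and nvidia_visible_devices.lower() != 'none':
            worker_id = '{}:gpu{}'.format(worker_name, nvidia_visible_devices)
        elif nvidia_visible_devices == '':
            pass
        else:
            worker_name = '{}:cpu'.format(worker_name)
    return worker_id, worker_name

os.environ['NVIDIA_VISIBLE_DEVICES'] = '0,1'
print(generate_worker_id_name('', 'my-host'))  # ('my-host:gpu0,1', 'my-host')
```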
def _resolve_queue_names(self, queues, create_if_missing=False):
if not queues:
default_queue = self._session.send_api(queues_api.GetDefaultRequest())

@@ -122,6 +122,7 @@ PIP_EXTRA_INDICES = [
DEFAULT_PIP_DOWNLOAD_CACHE = normalize_path(CONFIG_DIR, "pip-download-cache")
ENV_AGENT_GIT_USER = EnvironmentConfig('TRAINS_AGENT_GIT_USER')
ENV_AGENT_GIT_PASS = EnvironmentConfig('TRAINS_AGENT_GIT_PASS')
ENV_AGENT_GIT_HOST = EnvironmentConfig('TRAINS_AGENT_GIT_HOST')
ENV_TASK_EXECUTE_AS_USER = 'TRAINS_AGENT_EXEC_USER'
ENV_TASK_EXTRA_PYTHON_PATH = 'TRAINS_AGENT_EXTRA_PYTHON_PATH'
ENV_DOCKER_HOST_MOUNT = EnvironmentConfig('TRAINS_AGENT_K8S_HOST_MOUNT', 'TRAINS_AGENT_DOCKER_HOST_MOUNT')

trains_agent/external/__init__.py (vendored, new file, 0 lines added)
trains_agent/external/requirements_parser/__init__.py (vendored, new file, 22 lines added)
@@ -0,0 +1,22 @@
from .parser import parse # noqa

_MAJOR = 0
_MINOR = 2
_PATCH = 0


def version_tuple():
'''
Returns a 3-tuple of ints that represent the version
'''
return (_MAJOR, _MINOR, _PATCH)


def version():
'''
Returns a string representation of the version
'''
return '%d.%d.%d' % (version_tuple())


__version__ = version()
trains_agent/external/requirements_parser/fragment.py (vendored, new file, 44 lines added)
@@ -0,0 +1,44 @@
import re

# Copied from pip
# https://github.com/pypa/pip/blob/281eb61b09d87765d7c2b92f6982b3fe76ccb0af/pip/index.py#L947
HASH_ALGORITHMS = set(['sha1', 'sha224', 'sha384', 'sha256', 'sha512', 'md5'])

extras_require_search = re.compile(
r'(?P<name>.+)\[(?P<extras>[^\]]+)\]').search


def parse_fragment(fragment_string):
"""Takes a fragment string and returns a dict of the components"""
fragment_string = fragment_string.lstrip('#')

try:
return dict(
key_value_string.split('=')
for key_value_string in fragment_string.split('&')
)
except ValueError:
raise ValueError(
'Invalid fragment string {fragment_string}'.format(
fragment_string=fragment_string
)
)


def get_hash_info(d):
"""Returns the first matching hashlib name and value from a dict"""
for key in d.keys():
if key.lower() in HASH_ALGORITHMS:
return key, d[key]

return None, None


def parse_extras_require(egg):
if egg is not None:
match = extras_require_search(egg)
if match is not None:
name = match.group('name')
extras = match.group('extras')
return name, [extra.strip() for extra in extras.split(',')]
return egg, []
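These helpers split a pip URL fragment into egg name, extras, subdirectory and hash. A short usage sketch (the fragment string itself is made up for illustration):

```python
from trains_agent.external.requirements_parser.fragment import (
    parse_fragment, get_hash_info, parse_extras_require)

# Hypothetical fragment, as it would appear after '#' in a requirement URI.
fragment = parse_fragment('#egg=mypkg[gpu,tests]&subdirectory=src&sha256=deadbeef')
print(fragment)
# {'egg': 'mypkg[gpu,tests]', 'subdirectory': 'src', 'sha256': 'deadbeef'}
print(parse_extras_require(fragment['egg']))  # ('mypkg', ['gpu', 'tests'])
print(get_hash_info(fragment))                # ('sha256', 'deadbeef')
```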
trains_agent/external/requirements_parser/parser.py (vendored, new file, 50 lines added)
@@ -0,0 +1,50 @@
import os
import warnings

from .requirement import Requirement


def parse(reqstr):
"""
Parse a requirements file into a list of Requirements

See: pip/req.py:parse_requirements()

:param reqstr: a string or file-like object containing requirements
:returns: a *generator* of Requirement objects
"""
filename = getattr(reqstr, 'name', None)
try:
# Python 2.x compatibility
if not isinstance(reqstr, basestring):
reqstr = reqstr.read()
except NameError:
# Python 3.x only
if not isinstance(reqstr, str):
reqstr = reqstr.read()

for line in reqstr.splitlines():
line = line.strip()
if line == '':
continue
elif not line or line.startswith('#'):
# comments are lines that start with # only
continue
elif line.startswith('-r') or line.startswith('--requirement'):
_, new_filename = line.split()
new_file_path = os.path.join(os.path.dirname(filename or '.'),
new_filename)
with open(new_file_path) as f:
for requirement in parse(f):
yield requirement
elif line.startswith('-f') or line.startswith('--find-links') or \
line.startswith('-i') or line.startswith('--index-url') or \
line.startswith('--extra-index-url') or \
line.startswith('--no-index'):
warnings.warn('Private repos not supported. Skipping.')
continue
elif line.startswith('-Z') or line.startswith('--always-unzip'):
warnings.warn('Unused option --always-unzip. Skipping.')
continue
else:
yield Requirement.parse(line)
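`parse` is a generator, so it can be drained lazily or collected into a list; a small usage sketch (spec ordering may vary with the installed pkg_resources):

```python
from trains_agent.external.requirements_parser import parse

reqs = """
# pinned packages
numpy==1.18.1
torch>=1.4,<1.6
"""

for req in parse(reqs):
    # each item is a Requirement; specs holds (operator, version) pairs
    print(req.name, req.specs)
# numpy [('==', '1.18.1')]
# torch [('>=', '1.4'), ('<', '1.6')]   (order may vary)
```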
trains_agent/external/requirements_parser/requirement.py (vendored, new file, 241 lines added)
@@ -0,0 +1,241 @@
from __future__ import unicode_literals
import re
from pkg_resources import Requirement as Req

from .fragment import get_hash_info, parse_fragment, parse_extras_require
from .vcs import VCS, VCS_SCHEMES


URI_REGEX = re.compile(
r'^(?P<scheme>https?|file|ftps?)://(?P<path>[^#]+)'
r'(#(?P<fragment>\S+))?'
)

VCS_REGEX = re.compile(
r'^(?P<scheme>{0})://'.format(r'|'.join(
[scheme.replace('+', r'\+') for scheme in VCS_SCHEMES])) +
r'((?P<login>[^/@]+)@)?'
r'(?P<path>[^#@]+)'
r'(@(?P<revision>[^#]+))?'
r'(#(?P<fragment>\S+))?'
)

# This matches just about everything
LOCAL_REGEX = re.compile(
r'^((?P<scheme>file)://)?'
r'(?P<path>[^#]+)' +
r'(#(?P<fragment>\S+))?'
)


class Requirement(object):
"""
Represents a single requirement

Typically instances of this class are created with ``Requirement.parse``.
For local file requirements, there's no verification that the file
exists. This class attempts to be *dict-like*.

See: http://www.pip-installer.org/en/latest/logic.html

**Members**:

* ``line`` - the actual requirement line being parsed
* ``editable`` - a boolean whether this requirement is "editable"
* ``local_file`` - a boolean whether this requirement is a local file/path
* ``specifier`` - a boolean whether this requirement used a requirement
specifier (eg. "django>=1.5" or "requirements")
* ``vcs`` - a string specifying the version control system
* ``revision`` - a version control system specifier
* ``name`` - the name of the requirement
* ``uri`` - the URI if this requirement was specified by URI
* ``subdirectory`` - the subdirectory fragment of the URI
* ``path`` - the local path to the requirement
* ``hash_name`` - the type of hashing algorithm indicated in the line
* ``hash`` - the hash value indicated by the requirement line
* ``extras`` - a list of extras for this requirement
(eg. "mymodule[extra1, extra2]")
* ``specs`` - a list of specs for this requirement
(eg. "mymodule>1.5,<1.6" => [('>', '1.5'), ('<', '1.6')])
"""

def __init__(self, line):
# Do not call this private method
self.line = line
self.editable = False
self.local_file = False
self.specifier = False
self.vcs = None
self.name = None
self.subdirectory = None
self.uri = None
self.path = None
self.revision = None
self.hash_name = None
self.hash = None
self.extras = []
self.specs = []

def __repr__(self):
return '<Requirement: "{0}">'.format(self.line)

def __getitem__(self, key):
return getattr(self, key)

def keys(self):
return self.__dict__.keys()

@classmethod
def parse_editable(cls, line):
"""
Parses a Requirement from an "editable" requirement which is either
a local project path or a VCS project URI.

See: pip/req.py:from_editable()

:param line: an "editable" requirement
:returns: a Requirement instance for the given line
:raises: ValueError on an invalid requirement
"""

req = cls('-e {0}'.format(line))
req.editable = True
vcs_match = VCS_REGEX.match(line)
local_match = LOCAL_REGEX.match(line)

if vcs_match is not None:
groups = vcs_match.groupdict()
if groups.get('login'):
req.uri = '{scheme}://{login}@{path}'.format(**groups)
else:
req.uri = '{scheme}://{path}'.format(**groups)
req.revision = groups['revision']
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
for vcs in VCS:
if req.uri.startswith(vcs):
req.vcs = vcs
else:
assert local_match is not None, 'This should match everything'
groups = local_match.groupdict()
req.local_file = True
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
req.path = groups['path']

return req

@classmethod
def parse_line(cls, line):
"""
Parses a Requirement from a non-editable requirement.

See: pip/req.py:from_line()

:param line: a "non-editable" requirement
:returns: a Requirement instance for the given line
:raises: ValueError on an invalid requirement
"""

req = cls(line)

vcs_match = VCS_REGEX.match(line)
uri_match = URI_REGEX.match(line)
local_match = LOCAL_REGEX.match(line)

if vcs_match is not None:
groups = vcs_match.groupdict()
if groups.get('login'):
req.uri = '{scheme}://{login}@{path}'.format(**groups)
else:
req.uri = '{scheme}://{path}'.format(**groups)
req.revision = groups['revision']
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
for vcs in VCS:
if req.uri.startswith(vcs):
req.vcs = vcs
elif uri_match is not None:
groups = uri_match.groupdict()
req.uri = '{scheme}://{path}'.format(**groups)
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
req.name, req.extras = parse_extras_require(egg)
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
if groups['scheme'] == 'file':
req.local_file = True
elif '#egg=' in line:
# Assume a local file match
assert local_match is not None, 'This should match everything'
groups = local_match.groupdict()
req.local_file = True
if groups['fragment']:
fragment = parse_fragment(groups['fragment'])
egg = fragment.get('egg')
name, extras = parse_extras_require(egg)
req.name = fragment.get('egg')
req.hash_name, req.hash = get_hash_info(fragment)
req.subdirectory = fragment.get('subdirectory')
req.path = groups['path']
else:
# This is a requirement specifier.
# Delegate to pkg_resources and hope for the best
req.specifier = True
pkg_req = Req.parse(line)
req.name = pkg_req.unsafe_name
req.extras = list(pkg_req.extras)
req.specs = pkg_req.specs
return req

@classmethod
def parse(cls, line):
"""
Parses a Requirement from a line of a requirement file.

:param line: a line of a requirement file
:returns: a Requirement instance for the given line
:raises: ValueError on an invalid requirement
"""
line = line.lstrip()
if line.startswith('-e') or line.startswith('--editable'):
# Editable installs are either a local project path
# or a VCS project URI
return cls.parse_editable(
re.sub(r'^(-e|--editable=?)\s*', '', line))
elif '@' in line and ('#' not in line or line.index('#') > line.index('@')):
# Allegro bug fix: support 'name @ git+' entries
name, uri = line.split('@', 1)
name = name.strip()
uri = uri.strip()
# noinspection PyBroadException
try:
# check if the name is valid & parsed
Req.parse(name)
# if we are here, name is a valid package name, check if the vcs part is valid
if VCS_REGEX.match(uri):
req = cls.parse_line(uri)
req.name = name
return req
elif URI_REGEX.match(uri):
req = cls.parse_line(uri)
req.name = name
req.line = line
return req
except Exception:
pass

return cls.parse_line(line)
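`parse` layers the Allegro 'name @ URI' fix on top of the stock line parser; a sketch of both paths (the package names and URL are invented):

```python
from trains_agent.external.requirements_parser.requirement import Requirement

# Plain specifier path: delegated to pkg_resources.
r = Requirement.parse('django>=1.5,<2.0')
print(r.name, r.specifier, r.specs)  # django True [('>=', '1.5'), ('<', '2.0')]

# 'name @ vcs-uri' path (the Allegro fix): name comes from the left-hand side,
# while the VCS scheme and @revision are parsed out of the right-hand side.
r = Requirement.parse('mypkg @ git+https://github.com/example/mypkg.git@v1.0')
print(r.name, r.vcs, r.revision)     # mypkg git v1.0
```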
trains_agent/external/requirements_parser/vcs.py (vendored, new file, 30 lines added)
@@ -0,0 +1,30 @@
from __future__ import unicode_literals

VCS = [
'git',
'hg',
'svn',
'bzr',
]

VCS_SCHEMES = [
'git',
'git+https',
'git+ssh',
'git+git',
'hg+http',
'hg+https',
'hg+static-http',
'hg+ssh',
'svn',
'svn+svn',
'svn+http',
'svn+https',
'svn+ssh',
'bzr+http',
'bzr+https',
'bzr+ssh',
'bzr+sftp',
'bzr+ftp',
'bzr+lp',
]
@@ -1,45 +1,114 @@
from __future__ import print_function, division, unicode_literals

import base64
import logging
import os
import subprocess
import tempfile
from copy import deepcopy

import yaml
import json
from time import sleep
from typing import Text, List

from pyhocon import HOCONConverter

from trains_agent.commands.events import Events
from trains_agent.commands.worker import Worker
from trains_agent.errors import APIError
from trains_agent.helper.base import safe_remove_file
from trains_agent.helper.dicts import merge_dicts
from trains_agent.helper.process import get_bash_output
from trains_agent.helper.resource_monitor import ResourceMonitor
from trains_agent.interface.base import ObjectID


class K8sIntegration(Worker):
K8S_PENDING_QUEUE = "k8s_scheduler"

KUBECTL_RUN_CMD = "kubectl run trains_id_{task_id} " \
KUBECTL_APPLY_CMD = "kubectl apply -f"

KUBECTL_RUN_CMD = "kubectl run trains-id-{task_id} " \
"--image {docker_image} " \
"--restart=Never --replicas=1 " \
"--generator=run-pod/v1"
"--generator=run-pod/v1 " \
"--namespace=trains"

KUBECTL_DELETE_CMD = "kubectl delete pods " \
"--selector=TRAINS=agent " \
"--field-selector=status.phase!=Pending,status.phase!=Running"
"--field-selector=status.phase!=Pending,status.phase!=Running " \
"--namespace=trains"

CONTAINER_BASH_SCRIPT = "apt-get install -y git python-pip && " \
"pip install trains-agent && " \
"python -u -m trains_agent execute --full-monitoring --require-queue --id {}"
BASH_INSTALL_SSH_CMD = [
"apt-get install -y openssh-server",
"mkdir -p /var/run/sshd",
"echo 'root:training' | chpasswd",
"echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
"sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
r"sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd",
"echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY' >> /etc/ssh/sshd_config",
'echo "export VISIBLE=now" >> /etc/profile',
'echo "export PATH=$PATH" >> /etc/profile',
'echo "ldconfig" >> /etc/profile',
"/usr/sbin/sshd -p {port}"]

def __init__(self, k8s_pending_queue_name=None, kubectl_cmd=None, container_bash_script=None, debug=False):
CONTAINER_BASH_SCRIPT = [
"export DEBIAN_FRONTEND='noninteractive'",
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
"chown -R root /root/.cache/pip",
"apt-get update",
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
"declare LOCAL_PYTHON",
"for i in {{10..5}}; do which python3.$i && python3.$i -m pip --version && "
"export LOCAL_PYTHON=$(which python3.$i) && break ; done",
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip",
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3",
"$LOCAL_PYTHON -m pip install trains-agent",
"{extra_bash_init_cmd}",
"$LOCAL_PYTHON -m trains_agent execute --full-monitoring --require-queue --id {task_id}"
]

AGENT_LABEL = "TRAINS=agent"
LIMIT_POD_LABEL = "ai.allegro.agent.serial=pod-{pod_number}"

_edit_hyperparams_version = "2.9"

def __init__(
self,
k8s_pending_queue_name=None,
kubectl_cmd=None,
container_bash_script=None,
debug=False,
ports_mode=False,
num_of_services=20,
user_props_cb=None,
overrides_yaml=None,
template_yaml=None,
trains_conf_file=None,
extra_bash_init_script=None,
):
"""
Initialize the k8s integration glue layer daemon

:param str k8s_pending_queue_name: queue name to use when task is pending in the k8s scheduler
:param str|callable kubectl_cmd: kubectl command line str, supports formating (default: KUBECTL_RUN_CMD)
:param str|callable kubectl_cmd: kubectl command line str, supports formatting (default: KUBECTL_RUN_CMD)
example: "task={task_id} image={docker_image} queue_id={queue_id}"
or a callable function: kubectl_cmd(task_id, docker_image, queue_id, task_data)
:param str container_bash_script: container bash script to be executed in k8s (default: CONTAINER_BASH_SCRIPT)
Notice this string will use format() call, if you have curly brackets they should be doubled { -> {{
Format arguments passed: {task_id} and {extra_bash_init_cmd}
:param bool debug: Switch logging on
:param bool ports_mode: Adds a label to each pod which can be used in services in order to expose ports.
Requires the `num_of_services` parameter.
:param int num_of_services: Number of k8s services configured in the cluster. Required if `port_mode` is True.
(default: 20)
:param callable user_props_cb: An Optional callable allowing additional user properties to be specified
when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
:param str overrides_yaml: YAML file containing the overrides for the pod (optional)
:param str template_yaml: YAML file containing the template for the pod (optional).
If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
:param str trains_conf_file: trains.conf file to be used by the pod itself (optional)
:param str extra_bash_init_script: Additional bash script to run before starting the Task inside the container
"""
super(K8sIntegration, self).__init__()
self.k8s_pending_queue_name = k8s_pending_queue_name or self.K8S_PENDING_QUEUE
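Taken together, the new constructor arguments describe a full deployment. A hedged wiring sketch (the module path, file paths and queue name are assumptions, not taken from this diff):

```python
from trains_agent.glue.k8s import K8sIntegration  # module path assumed

k8s = K8sIntegration(
    ports_mode=True,           # label pods pod-1..pod-N so services can expose ports
    num_of_services=20,        # must match the number of pre-configured k8s services
    overrides_yaml='~/pod_overrides.yaml',  # hypothetical overrides file
    trains_conf_file='~/trains.conf',       # config pushed into each pod
    extra_bash_init_script='pip install -U pip',
)
k8s.k8s_daemon(queue='k8s_gpu')  # pull from this queue, schedule via kubectl
```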
@@ -51,12 +120,88 @@ class K8sIntegration(Worker):
if debug:
self.log.logger.disabled = False
self.log.logger.setLevel(logging.INFO)
self.ports_mode = ports_mode
self.num_of_services = num_of_services
self._edit_hyperparams_support = None
self._user_props_cb = user_props_cb
self.trains_conf_file = None
self.overrides_json_string = None
self.template_dict = None
self.extra_bash_init_script = extra_bash_init_script or None
if self.extra_bash_init_script and not isinstance(self.extra_bash_init_script, str):
self.extra_bash_init_script = ' ; '.join(self.extra_bash_init_script)  # noqa
self.pod_limits = []
self.pod_requests = []
if overrides_yaml:
with open(os.path.expandvars(os.path.expanduser(str(overrides_yaml))), 'rt') as f:
overrides = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))
if overrides:
containers = overrides.get('spec', {}).get('containers', [])
for c in containers:
resources = {str(k).lower(): v for k, v in c.get('resources', {}).items()}
if not resources:
continue
if resources.get('limits'):
self.pod_limits += ['{}={}'.format(k, v) for k, v in resources['limits'].items()]
if resources.get('requests'):
self.pod_requests += ['{}={}'.format(k, v) for k, v in resources['requests'].items()]
# remove double entries
self.pod_limits = list(set(self.pod_limits))
self.pod_requests = list(set(self.pod_requests))
if self.pod_limits or self.pod_requests:
self.log.warning('Found pod container requests={} limits={}'.format(
self.pod_limits, self.pod_requests))
if containers:
self.log.warning('Removing containers section: {}'.format(overrides['spec'].pop('containers')))
self.overrides_json_string = json.dumps(overrides)
if template_yaml:
with open(os.path.expandvars(os.path.expanduser(str(template_yaml))), 'rt') as f:
self.template_dict = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))

def run_one_task(self, queue: Text, task_id: Text, worker_args=None):
if trains_conf_file:
with open(os.path.expandvars(os.path.expanduser(str(trains_conf_file))), 'rt') as f:
self.trains_conf_file = f.read()
# make sure we use system packages!
self.trains_conf_file += '\nagent.package_manager.system_site_packages=true\n'

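The overrides loader above lifts `resources.limits`/`resources.requests` out of the pod spec and flattens them into `key=value` strings for kubectl's `--limits`/`--requests` flags; a self-contained sketch with a made-up overrides document:

```python
import yaml

# Hypothetical overrides_yaml content, matching the structure the loader expects.
overrides = yaml.safe_load("""
spec:
  containers:
    - resources:
        limits:
          nvidia.com/gpu: 1
        requests:
          memory: 4Gi
""")

pod_limits, pod_requests = [], []
for c in overrides.get('spec', {}).get('containers', []):
    resources = {str(k).lower(): v for k, v in c.get('resources', {}).items()}
    if resources.get('limits'):
        pod_limits += ['{}={}'.format(k, v) for k, v in resources['limits'].items()]
    if resources.get('requests'):
        pod_requests += ['{}={}'.format(k, v) for k, v in resources['requests'].items()]

print(pod_limits)    # ['nvidia.com/gpu=1']
print(pod_requests)  # ['memory=4Gi']
```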
def _set_task_user_properties(self, task_id: str, **properties: str):
if self._edit_hyperparams_support is not True:
# either not supported or never tested
if self._edit_hyperparams_support == self._session.api_version:
# tested against latest api_version, not supported
return
if not self._session.check_min_api_version(self._edit_hyperparams_version):
# not supported due to insufficient api_version
self._edit_hyperparams_support = self._session.api_version
return
try:
self._session.get(
service="tasks",
action="edit_hyper_params",
task=task_id,
hyperparams=[
{
"section": "properties",
"name": k,
"value": str(v),
}
for k, v in properties.items()
],
)
# definitely supported
self._runtime_props_support = True
except APIError as error:
if error.code == 404:
self._edit_hyperparams_support = self._session.api_version

def run_one_task(self, queue: Text, task_id: Text, worker_args=None, **_):
print('Pulling task {} launching on kubernetes cluster'.format(task_id))
task_data = self._session.api_client.tasks.get_all(id=[task_id])[0]

# push task into the k8s queue, so we have visibility on pending tasks in the k8s scheduler
try:
print('Pushing task {} into temporary pending queue'.format(task_id))
self._session.api_client.tasks.reset(task_id)
self._session.api_client.tasks.enqueue(task_id, queue=self.k8s_pending_queue_name,
status_reason='k8s pending scheduler')
except Exception as e:
@@ -65,36 +210,215 @@ class K8sIntegration(Worker):
return

if task_data.execution.docker_cmd:
docker_image = task_data.execution.docker_cmd
docker_parts = task_data.execution.docker_cmd
else:
docker_image = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
docker_parts = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
self._session.config.get("agent.default_docker.image", "nvidia/cuda"))

# take the first part, this is the docker image name (not arguments)
docker_image = docker_image.split()[0]
docker_parts = docker_parts.split()
docker_image = docker_parts[0]
docker_args = docker_parts[1:] if len(docker_parts) > 1 else []

create_trains_conf = "echo '{}' >> ~/trains.conf && ".format(
HOCONConverter.to_hocon(self._session.config._config))
# get the trains.conf encoded file
# noinspection PyProtectedMember
hocon_config_encoded = (self.trains_conf_file or self._session._config_file).encode('ascii')
create_trains_conf = "echo '{}' | base64 --decode >> ~/trains.conf".format(
base64.b64encode(
hocon_config_encoded
).decode('ascii')
)
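Switching from a raw `echo` of the HOCON text to a base64 payload sidesteps shell quoting: the single-quoted argument is now plain base64 no matter what the config contains. A small round-trip sketch (the config content is invented):

```python
import base64

# Hypothetical config content with characters that would break a naive echo '...'.
conf_text = 'agent.git_user: "bob"\nagent.git_pass: "it\'s secret"\n'

encoded = base64.b64encode(conf_text.encode('ascii')).decode('ascii')
shell_cmd = "echo '{}' | base64 --decode >> ~/trains.conf".format(encoded)

# The single-quoted payload is pure base64, safe to embed in the pod command.
print(shell_cmd)
```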
if callable(self.kubectl_cmd):
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data)
if self.ports_mode:
print("Kubernetes looking for available pod to use")

# Search for a free pod number
pod_number = 1
while self.ports_mode:
kubectl_cmd_new = "kubectl get pods -l {pod_label},{agent_label} -n trains".format(
pod_label=self.LIMIT_POD_LABEL.format(pod_number=pod_number),
agent_label=self.AGENT_LABEL
)
process = subprocess.Popen(kubectl_cmd_new.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
output = '' if not output else output if isinstance(output, str) else output.decode('utf-8')
error = '' if not error else error if isinstance(error, str) else error.decode('utf-8')

if not output:
# No such pod exists so we can use the pod_number we found
break
if pod_number >= self.num_of_services:
# All pod numbers are taken, exit
self.log.warning(
"kubectl last result: {}\n{}\nAll k8s services are in use, task '{}' "
"will be enqueued back to queue '{}'".format(
error, output, task_id, queue
)
)
self._session.api_client.tasks.reset(task_id)
self._session.api_client.tasks.enqueue(task_id, queue=queue)
return
pod_number += 1

labels = ([self.LIMIT_POD_LABEL.format(pod_number=pod_number)] if self.ports_mode else []) + [self.AGENT_LABEL]

if self.ports_mode:
print("Kubernetes scheduling task id={} on pod={}".format(task_id, pod_number))
else:
kubectl_cmd = self.kubectl_cmd.format(task_id=task_id, docker_image=docker_image, queue_id=queue)
print("Kubernetes scheduling task id={}".format(task_id))

# make sure we gave a list
if self.template_dict:
output, error = self._kubectl_apply(
create_trains_conf=create_trains_conf,
labels=labels, docker_image=docker_image, docker_args=docker_args,
task_id=task_id, queue=queue)
else:
output, error = self._kubectl_run(
create_trains_conf=create_trains_conf,
labels=labels, docker_image=docker_image,
task_data=task_data,
task_id=task_id, queue=queue)

error = '' if not error else (error if isinstance(error, str) else error.decode('utf-8'))
output = '' if not output else (output if isinstance(output, str) else output.decode('utf-8'))
print('kubectl output:\n{}\n{}'.format(error, output))

if error:
self.log.error("Running kubectl encountered an error: {}".format(error))
elif self.ports_mode:
user_props = {"k8s-pod-number": pod_number, "k8s-pod-label": labels[0]}
if self._user_props_cb:
# noinspection PyBroadException
try:
custom_props = self._user_props_cb(pod_number) if self.ports_mode else self._user_props_cb()
user_props.update(custom_props)
except Exception:
pass
self._set_task_user_properties(
task_id=task_id,
**user_props
)

def _parse_docker_args(self, docker_args):
# type: (list) -> dict
kube_args = {'env': []}
while docker_args:
cmd = docker_args.pop().strip()
if cmd in ('-e', '--env',):
env = docker_args.pop().strip()
key, value = env.split('=', 1)
kube_args[key] += {key: value}
else:
self.log.warning('skipping docker argument {} (only -e --env supported)'.format(cmd))
return kube_args

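Note that `_parse_docker_args` as shown pops from the end of the list (so a value is seen before its `-e` flag) and `kube_args[key] += {key: value}` would raise a `KeyError`; the apparent intent is to accumulate name/value pairs under `'env'`. A hedged corrected sketch of that intent, not the committed code:

```python
def parse_docker_args(docker_args):
    # type: (list) -> dict
    # Corrected sketch of the likely intent of _parse_docker_args above:
    # walk the arguments in order and collect -e/--env pairs under 'env',
    # in the name/value form a k8s container spec expects.
    kube_args = {'env': []}
    args = list(docker_args)
    while args:
        cmd = args.pop(0).strip()
        if cmd in ('-e', '--env'):
            key, value = args.pop(0).strip().split('=', 1)
            kube_args['env'].append({'name': key, 'value': value})
        else:
            print('skipping docker argument {} (only -e/--env supported)'.format(cmd))
    return kube_args

print(parse_docker_args(['-e', 'MY_VAR=1', '--network', 'host']))
# {'env': [{'name': 'MY_VAR', 'value': '1'}]}
```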
def _kubectl_apply(self, create_trains_conf, docker_image, docker_args, labels, queue, task_id):
template = deepcopy(self.template_dict)
template.setdefault('apiVersion', 'v1')
template['kind'] = 'Pod'
template.setdefault('metadata', {})
name = 'trains-id-{task_id}'.format(task_id=task_id)
template['metadata']['name'] = name
template.setdefault('spec', {})
template['spec'].setdefault('containers', [])
if labels:
labels_dict = dict(pair.split('=', 1) for pair in labels)
template['metadata'].setdefault('labels', {})
template['metadata']['labels'].update(labels_dict)
container = self._parse_docker_args(docker_args)

container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
else self.container_bash_script

script_encoded = '\n'.join(
['#!/bin/bash', ] +
[line.format(extra_bash_init_cmd=self.extra_bash_init_script or '', task_id=task_id)
for line in container_bash_script])

create_init_script = \
"echo '{}' | base64 --decode >> ~/__start_agent__.sh ; " \
"/bin/bash ~/__start_agent__.sh".format(
base64.b64encode(
script_encoded.encode('ascii')
).decode('ascii'))

container = merge_dicts(
container,
dict(name=name, image=docker_image,
command=['/bin/bash'],
args=['-c', '{} ; {}'.format(create_trains_conf, create_init_script)])
)

if template['spec']['containers']:
template['spec']['containers'][0] = merge_dicts(template['spec']['containers'][0], container)
else:
template['spec']['containers'].append(container)

fp, yaml_file = tempfile.mkstemp(prefix='trains_k8stmpl_', suffix='.yml')
os.close(fp)
with open(yaml_file, 'wt') as f:
yaml.dump(template, f)

kubectl_cmd = self.KUBECTL_APPLY_CMD.format(
task_id=task_id,
docker_image=docker_image,
queue_id=queue,
)
# make sure we provide a list
if isinstance(kubectl_cmd, str):
kubectl_cmd = kubectl_cmd.split()

kubectl_cmd += ["--labels=TRAINS=agent", "--command", "--", "/bin/sh", "-c",
create_trains_conf + self.container_bash_script.format(task_id)]
# add the template file at the end
kubectl_cmd += [yaml_file]
try:
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
except Exception as ex:
return None, str(ex)
finally:
safe_remove_file(yaml_file)

return output, error

def _kubectl_run(self, create_trains_conf, docker_image, labels, queue, task_data, task_id):
if callable(self.kubectl_cmd):
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data)
else:
kubectl_cmd = self.kubectl_cmd.format(
task_id=task_id,
docker_image=docker_image,
queue_id=queue
)
# make sure we provide a list
if isinstance(kubectl_cmd, str):
kubectl_cmd = kubectl_cmd.split()

if self.overrides_json_string:
kubectl_cmd += ['--overrides=' + self.overrides_json_string]

if self.pod_limits:
kubectl_cmd += ['--limits', ",".join(self.pod_limits)]
if self.pod_requests:
kubectl_cmd += ['--requests', ",".join(self.pod_requests)]

container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
else self.container_bash_script
container_bash_script = ' ; '.join(container_bash_script)

kubectl_cmd += [
"--labels=" + ",".join(labels),
"--command",
"--",
"/bin/sh",
"-c",
"{} ; {}".format(create_trains_conf, container_bash_script.format(
extra_bash_init_cmd=self.extra_bash_init_script, task_id=task_id)),
]
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
self.log.info("K8s scheduling experiment task id={}".format(task_id))
if error:
self.log.error("Running kubectl encountered an error: {}".format(
error if isinstance(error, str) else error.decode()))
return output, error

def run_tasks_loop(self, queues: List[Text], worker_params):
def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):
"""
:summary: Pull and run tasks from queues.
:description: 1. Go through ``queues`` by order.
@@ -108,6 +432,7 @@ class K8sIntegration(Worker):
events_service = self.get_service(Events)

# make sure we have a k8s pending queue
# noinspection PyBroadException
try:
self._session.api_client.queues.create(self.k8s_pending_queue_name)
except Exception:
@@ -119,7 +444,7 @@ class K8sIntegration(Worker):
while True:
# iterate over queues (priority style, queues[0] is highest)
for queue in queues:
# delete old completed /failed pods
# delete old completed / failed pods
get_bash_output(self.KUBECTL_DELETE_CMD)

# get next task in queue
@@ -155,15 +480,20 @@ class K8sIntegration(Worker):
if self._session.config["agent.reload_config"]:
self.reload_config()

def k8s_daemon(self, queues):
def k8s_daemon(self, queue):
"""
Start the k8s Glue service.
This service will be pulling tasks from *queues* and scheduling them for execution using kubectl.
This service will be pulling tasks from *queue* and scheduling them for execution using kubectl.
Notice all scheduled tasks are pushed back into K8S_PENDING_QUEUE,
and popped when execution actually starts. This creates full visibility into the k8s scheduler.
Manually popping a task from the K8S_PENDING_QUEUE,
will cause the k8s scheduler to skip the execution once the scheduled task needs to be executed

:param list(str) queues: List of queue names to pull from
:param list(str) queue: queue name to pull from
"""
return self.daemon(queues=queues, log_level=logging.INFO, foreground=True, docker=False)
return self.daemon(queues=[ObjectID(name=queue)] if queue else None,
log_level=logging.INFO, foreground=True, docker=False)

@classmethod
def get_ssh_server_bash(cls, ssh_port_number):
return ' ; '.join(line.format(port=ssh_port_number) for line in cls.BASH_INSTALL_SSH_CMD)

@@ -173,14 +173,32 @@ def normalize_path(*paths):


def safe_remove_file(filename, error_message=None):
# noinspection PyBroadException
try:
os.remove(filename)
if filename:
os.remove(filename)
except Exception:
if error_message:
print(error_message)


def get_python_path(script_dir, entry_point, package_api):
def safe_remove_tree(filename):
if not filename:
return
# noinspection PyBroadException
try:
shutil.rmtree(filename, ignore_errors=True)
except Exception:
pass
# noinspection PyBroadException
try:
os.remove(filename)
except Exception:
pass


def get_python_path(script_dir, entry_point, package_api, is_conda_env=False):
# noinspection PyBroadException
try:
python_path_sep = ';' if is_windows_platform() else ':'
python_path_cmd = package_api.get_python_command(
@@ -192,9 +210,9 @@ def get_python_path(script_dir, entry_point, package_api):
(Path(script_dir) / Path(entry_point)).parent.absolute().as_posix(),
python_path_sep=python_path_sep)
if is_windows_platform():
return python_path.replace('/', '\\') + org_python_path
python_path = python_path.replace('/', '\\')

return python_path + org_python_path
return python_path if is_conda_env else (python_path + org_python_path)
except Exception:
return None

@@ -442,9 +460,9 @@ def chain_map(*args):
return reduce(lambda x, y: x.update(y) or x, args, {})


def check_directory_path(path):
def check_directory_path(path, check_whitespace_in_path=True):
message = 'Could not create directory "{}": {}'
if not is_windows_platform():
if not is_windows_platform() and check_whitespace_in_path:
match = re.search(r'\s', path)
if match:
raise CommandFailedError(
@@ -537,6 +555,7 @@ class ExecutionInfo(NonStrictAttrs):
branch = nullable_string
version_num = nullable_string
tag = nullable_string
docker_cmd = nullable_string

@classmethod
def from_task(cls, task_info):
@@ -554,6 +573,12 @@ class ExecutionInfo(NonStrictAttrs):
execution.entry_point = entry_point
execution.working_dir = working_dir or ""

# noinspection PyBroadException
try:
execution.docker_cmd = task_info.execution.docker_cmd
except Exception:
pass

return execution



@@ -9,7 +9,7 @@ from attr import attrs, attrib

import six
from six import binary_type, text_type
from trains_agent.helper.base import nonstrict_in_place_sort, create_tree
from trains_agent.helper.base import nonstrict_in_place_sort


def print_text(text, newline=True):
@@ -22,15 +22,21 @@ def print_text(text, newline=True):
sys.stdout.write(data)


def decode_binary_lines(binary_lines, encoding='utf-8'):
def decode_binary_lines(binary_lines, encoding='utf-8', replace_cr=False, overwrite_cr=False):
# decode per line, if we failed decoding skip the line
lines = []
for b in binary_lines:
# noinspection PyBroadException
try:
l = b.decode(encoding=encoding, errors='replace').replace('\r', '\n')
except:
l = ''
lines.append(l + '\n' if l and l[-1] != '\n' else l)
line = b.decode(encoding=encoding, errors='replace')
if replace_cr:
line = line.replace('\r', '\n')
elif overwrite_cr:
cr_lines = line.split('\r')
line = cr_lines[-1] if cr_lines[-1] or len(cr_lines) < 2 else cr_lines[-2]
except Exception:
line = ''
lines.append(line + '\n' if not line or line[-1] != '\n' else line)
return lines


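The new `overwrite_cr` mode targets progress-bar output, where a line is rewritten in place with bare carriage returns and only the last state is worth logging; a sketch of the two modes using the same logic inline:

```python
# Sketch of the two carriage-return modes added above, applied to one raw line.
def decode(b, replace_cr=False, overwrite_cr=False):
    line = b.decode('utf-8', errors='replace')
    if replace_cr:
        line = line.replace('\r', '\n')
    elif overwrite_cr:
        cr_lines = line.split('\r')
        line = cr_lines[-1] if cr_lines[-1] or len(cr_lines) < 2 else cr_lines[-2]
    return line

progress = b'epoch 1: 10%\repoch 1: 50%\repoch 1: 100%'
print(decode(progress, replace_cr=True))    # every intermediate state, one per line
print(decode(progress, overwrite_cr=True))  # 'epoch 1: 100%', only the final state
```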
|
||||
@@ -3,3 +3,15 @@ from typing import Callable, Dict, Any
|
||||
|
||||
def filter_keys(filter_, dct): # type: (Callable[[Any], bool], Dict) -> Dict
|
||||
return {key: value for key, value in dct.items() if filter_(key)}
|
||||
|
||||
|
||||
def merge_dicts(dict1, dict2):
|
||||
""" Recursively merges dict2 into dict1 """
|
||||
if not isinstance(dict1, dict) or not isinstance(dict2, dict):
|
||||
return dict2
|
||||
for k in dict2:
|
||||
if k in dict1:
|
||||
dict1[k] = merge_dicts(dict1[k], dict2[k])
|
||||
else:
|
||||
dict1[k] = dict2[k]
|
||||
return dict1
|
||||
|
||||
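`merge_dicts` recurses into nested dicts with `dict2` winning on leaf conflicts, which is what lets a user-supplied container spec be overlaid onto the generated one in `_kubectl_apply` earlier in this diff; a quick sketch (the pod fields are illustrative):

```python
from trains_agent.helper.dicts import merge_dicts

base = {'name': 'trains-id-123', 'resources': {'limits': {'cpu': 1}}}
user = {'resources': {'limits': {'memory': '4Gi'}}, 'imagePullPolicy': 'Always'}

print(merge_dicts(base, user))
# {'name': 'trains-id-123',
#  'resources': {'limits': {'cpu': 1, 'memory': '4Gi'}},
#  'imagePullPolicy': 'Always'}
# Note: dict1 is modified in place and also returned.
```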
@@ -200,24 +200,30 @@ class GPUStatCollection(object):
|
||||
GPUStatCollection.global_processes[nv_process.pid] = \
|
||||
psutil.Process(pid=nv_process.pid)
|
||||
ps_process = GPUStatCollection.global_processes[nv_process.pid]
|
||||
process['username'] = ps_process.username()
|
||||
# cmdline returns full path;
|
||||
# as in `ps -o comm`, get short cmdnames.
|
||||
_cmdline = ps_process.cmdline()
|
||||
if not _cmdline:
|
||||
# sometimes, zombie or unknown (e.g. [kworker/8:2H])
|
||||
process['command'] = '?'
|
||||
process['full_command'] = ['?']
|
||||
else:
|
||||
process['command'] = os.path.basename(_cmdline[0])
|
||||
process['full_command'] = _cmdline
|
||||
# Bytes to MBytes
|
||||
process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
|
||||
process['cpu_percent'] = ps_process.cpu_percent()
|
||||
process['cpu_memory_usage'] = \
|
||||
round((ps_process.memory_percent() / 100.0) *
|
||||
psutil.virtual_memory().total)
|
||||
process['pid'] = nv_process.pid
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# we do not actually use these, so no point in collecting them
|
||||
# process['username'] = ps_process.username()
|
||||
# # cmdline returns full path;
|
||||
# # as in `ps -o comm`, get short cmdnames.
|
||||
# _cmdline = ps_process.cmdline()
|
||||
# if not _cmdline:
|
||||
# # sometimes, zombie or unknown (e.g. [kworker/8:2H])
|
||||
# process['command'] = '?'
|
||||
# process['full_command'] = ['?']
|
||||
# else:
|
||||
# process['command'] = os.path.basename(_cmdline[0])
|
||||
# process['full_command'] = _cmdline
|
||||
# process['cpu_percent'] = ps_process.cpu_percent()
|
||||
# process['cpu_memory_usage'] = \
|
||||
# round((ps_process.memory_percent() / 100.0) *
|
||||
# psutil.virtual_memory().total)
|
||||
# Bytes to MBytes
|
||||
process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
|
||||
except Exception:
|
||||
# insufficient permissions
|
||||
pass
|
||||
return process
|
||||
|
||||
if not GPUStatCollection._gpu_device_info.get(index):
|
||||
@@ -285,12 +291,13 @@ class GPUStatCollection(object):
|
||||
# e.g. nvidia-smi reset or reboot the system
|
||||
pass
|
||||
|
||||
# TODO: Do not block if full process info is not requested
|
||||
time.sleep(0.1)
|
||||
for process in processes:
|
||||
pid = process['pid']
|
||||
cache_process = GPUStatCollection.global_processes[pid]
|
||||
process['cpu_percent'] = cache_process.cpu_percent()
|
||||
# we do not actually use these, so no point in collecting them
|
||||
# # TODO: Do not block if full process info is not requested
|
||||
# time.sleep(0.1)
|
||||
# for process in processes:
|
||||
# pid = process['pid']
|
||||
# cache_process = GPUStatCollection.global_processes[pid]
|
||||
# process['cpu_percent'] = cache_process.cpu_percent()
|
||||
|
||||
index = N.nvmlDeviceGetIndex(handle)
|
||||
gpu_info = {
|
||||
|
||||
@@ -5,7 +5,7 @@ from contextlib import contextmanager
|
||||
from typing import Text, Iterable, Union
|
||||
|
||||
import six
|
||||
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines
|
||||
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines, select_for_platform
|
||||
from trains_agent.helper.process import Executable, Argv, PathLike
|
||||
|
||||
|
||||
@@ -66,7 +66,20 @@ class PackageManager(object):
|
||||
pass
|
||||
|
||||
def upgrade_pip(self):
|
||||
return self._install("pip"+self.get_pip_version(), "--upgrade")
|
||||
result = self._install(
|
||||
select_for_platform(windows='"pip{}"', linux='pip{}').format(self.get_pip_version()), "--upgrade")
|
||||
packages = self.run_with_env(('list',), output=True).splitlines()
|
||||
# p.split is ('pip', 'x.y.z')
|
||||
pip = [p.split() for p in packages if len(p.split()) == 2 and p.split()[0] == 'pip']
|
||||
if pip:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from .requirements import MarkerRequirement
|
||||
pip = pip[0][1].split('.')
|
||||
MarkerRequirement.pip_new_version = bool(int(pip[0]) >= 20)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
def get_python_command(self, extra=()):
|
||||
# type: (...) -> Executable
|
||||
|
||||
@@ -2,8 +2,9 @@ from __future__ import unicode_literals

import json
import re
import shutil
import os
import subprocess
from collections import OrderedDict
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
@@ -14,17 +15,18 @@ import yaml
from time import time
from attr import attrs, attrib, Factory
from pathlib2 import Path
from requirements import parse
from requirements.requirement import Requirement
from trains_agent.external.requirements_parser import parse
from trains_agent.external.requirements_parser.requirement import Requirement

from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform, ExecutionInfo
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
from trains_agent.helper.package.requirements import SimpleVersion
from trains_agent.session import Session
from .base import PackageManager
from .pip_api.venv import VirtualenvPip
from .requirements import RequirementsManager, MarkerRequirement
from ...backend_api.session.defs import ENV_CONDA_ENV_PACKAGE

package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")

@@ -40,8 +42,8 @@ def _package_diff(path, packages):

class CondaPip(VirtualenvPip):
    def __init__(self, source=None, *args, **kwargs):
        super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe") \
            if is_windows_platform() and kwargs.get('path') else None, **kwargs)
        super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe")
            if is_windows_platform() and kwargs.get('path') else None, **kwargs)
        self.source = source

    def run_with_env(self, command, output=False, **kwargs):
@@ -61,8 +63,8 @@ class CondaAPI(PackageManager):

    MINIMUM_VERSION = "4.3.30"

    def __init__(self, session, path, python, requirements_manager):
        # type: (Session, PathLike, float, RequirementsManager) -> None
    def __init__(self, session, path, python, requirements_manager, execution_info=None, **kwargs):
        # type: (Session, PathLike, float, RequirementsManager, ExecutionInfo, Any) -> None
        """
        :param python: base python version to use (e.g python3.6)
        :param path: path of env
@@ -72,7 +74,15 @@ class CondaAPI(PackageManager):
        self.source = None
        self.requirements_manager = requirements_manager
        self.path = path
        self.env_read_only = False
        self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
        self.conda_env_as_base_docker = \
            self.session.config.get('agent.package_manager.conda_env_as_base_docker', None) or \
            bool(ENV_CONDA_ENV_PACKAGE.get())
        if ENV_CONDA_ENV_PACKAGE.get():
            self.conda_pre_build_env_path = ENV_CONDA_ENV_PACKAGE.get()
        else:
            self.conda_pre_build_env_path = execution_info.docker_cmd if execution_info else None
        self.pip = CondaPip(
            session=self.session,
            source=self.source,
@@ -80,10 +90,15 @@ class CondaAPI(PackageManager):
            requirements_manager=self.requirements_manager,
            path=self.path,
        )
        self.conda = (
            find_executable("conda")
            or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
        )
        try:
            self.conda = (
                find_executable("conda") or
                Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(
                    shell=select_for_platform(windows=True, linux=False)).strip()
            )
        except Exception:
            raise ValueError("ERROR: package manager \"conda\" selected, "
                             "but \'conda\' executable could not be located")
        try:
            output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as ex:
@@ -111,13 +126,58 @@ class CondaAPI(PackageManager):
    def bin(self):
        return self.pip.bin

    # noinspection SpellCheckingInspection
    def upgrade_pip(self):
        return self._install("pip" + self.pip.get_pip_version())
        # do not change pip version if a pre-built environment is used
        if self.env_read_only:
            print('Conda environment in read-only mode, skipping pip upgrade.')
            return ''
        return self._install(select_for_platform(windows='"pip{}"', linux='pip{}').format(self.pip.get_pip_version()))

    def create(self):
        """
        Create a new environment
        """
        if self.conda_env_as_base_docker and self.conda_pre_build_env_path:
            if Path(self.conda_pre_build_env_path).is_dir():
                print("Using pre-existing Conda environment from {}".format(self.conda_pre_build_env_path))
                self.path = Path(self.conda_pre_build_env_path)
                self.source = ("conda", "activate", self.path.as_posix())
                self.pip = CondaPip(
                    session=self.session,
                    source=self.source,
                    python=self.python,
                    requirements_manager=self.requirements_manager,
                    path=self.path,
                )
                conda_env = self._get_conda_sh()
                self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
                self.env_read_only = True
                return self
            elif Path(self.conda_pre_build_env_path).is_file():
                print("Restoring Conda environment from {}".format(self.conda_pre_build_env_path))
                tar_path = find_executable("tar")
                self.path.mkdir(parents=True, exist_ok=True)
                output = Argv(
                    tar_path,
                    "-xzf",
                    self.conda_pre_build_env_path,
                    "-C",
                    self.path,
                ).get_output()

                self.source = self.pip.source = ("conda", "activate", self.path.as_posix())
                conda_env = self._get_conda_sh()
                self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
                # unpack cleanup
                print("Fixing prefix in Conda environment {}".format(self.path))
                CommandSequence(('source', conda_env.as_posix()),
                                ((self.path / 'bin' / 'conda-unpack').as_posix(), )).get_output()
                return self
            else:
                raise ValueError("Could not restore Conda environment, cannot find {}".format(
                    self.conda_pre_build_env_path))

        output = Argv(
            self.conda,
            "create",
@@ -133,13 +193,15 @@ class CondaAPI(PackageManager):
        self.source = self.pip.source = (
            tuple(match.group(1).split()) + (match.group(2),)
            if match
            else ("activate", self.path)
            else ("conda", "activate", self.path.as_posix())
        )
        conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'

        conda_env = self._get_conda_sh()
        if conda_env.is_file() and not is_windows_platform():
            self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)

        # install cuda toolkit
        # noinspection PyBroadException
        try:
            cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
            if cuda_version > 0:
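The restore branch above untars an archive and then runs `bin/conda-unpack`, which matches an archive produced by the conda-pack tool. A hedged sketch of producing such an archive (assuming `pip install conda-pack`; the environment name and output path are illustrative, not from this diff):

```python
# Sketch: build a relocatable conda env archive that the restore branch
# above (tar -xzf ... then bin/conda-unpack) can consume.
# Assumes conda-pack is installed; names and paths are illustrative.
import conda_pack

conda_pack.pack(
    name="my_training_env",           # an existing conda environment
    output="my_training_env.tar.gz",  # archive to point ENV_CONDA_ENV_PACKAGE at
)
```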
@@ -181,6 +243,10 @@ class CondaAPI(PackageManager):

    def _install(self, *args):
        # type: (*PathLike) -> ()
        # if we are in read only mode, do not install anything
        if self.env_read_only:
            print('Conda environment in read-only mode, skipping package installing: {}'.format(args))
            return
        channels_args = tuple(
            chain.from_iterable(("-c", channel) for channel in self.extra_channels)
        )
@@ -208,6 +274,10 @@ class CondaAPI(PackageManager):
        return self._install(*packages)

    def uninstall_packages(self, *packages):
        # if we are in read only mode, do not uninstall anything
        if self.env_read_only:
            print('Conda environment in read-only mode, skipping package uninstalling: {}'.format(packages))
            return ''
        return self._run_command(("uninstall", "-p", self.path))

    def install_from_file(self, path):
@@ -226,23 +296,158 @@ class CondaAPI(PackageManager):
        with self.temp_file("pip_reqs", pip_packages) as reqs:
            self.pip.install_from_file(reqs)

    def freeze(self):
    def freeze(self, freeze_full_environment=False):
        requirements = self.pip.freeze()
        req_lines = []
        conda_lines = []

        # noinspection PyBroadException
        try:
            conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
            conda_packages_txt = []
            requirements_pip = [r.split('==')[0].strip().lower() for r in requirements['pip']]
            for pkg in conda_packages:
                # skip if this is a pypi package or it is not a python package at all
                if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip:
            pip_lines = requirements['pip']
            conda_packages_json = json.loads(
                self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
            for r in conda_packages_json:
                # check if this is a pypi package, if it is, leave it outside
                if not r.get('channel') or r.get('channel') == 'pypi':
                    name = (r['name'].replace('-', '_'), r['name'])
                    pip_req_line = [l for l in pip_lines
                                    if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
                    if pip_req_line and \
                            ('@' not in pip_req_line[0] or
                             not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
                        req_lines.append(pip_req_line[0])
                        continue

                    req_lines.append(
                        '{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
                    continue
                    conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
            requirements['conda'] = conda_packages_txt
        except:

                # check if we have it in our required packages
                name = r['name']
                # hack support pytorch/torch different naming convention
                if name == 'pytorch':
                    name = 'torch'
                # skip over packages with _
                if name.startswith('_'):
                    continue
                conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
            # make sure we see the conda packages, put them into the pip as well
            if conda_lines:
                req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines

            requirements['pip'] = req_lines
            requirements['conda'] = conda_lines
        except Exception:
            pass

        if freeze_full_environment:
            # noinspection PyBroadException
            try:
                conda_env_json = json.loads(
                    self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
                conda_env_json.pop('name', None)
                conda_env_json.pop('prefix', None)
                conda_env_json.pop('channels', None)
                requirements['conda_env_json'] = json.dumps(conda_env_json)
            except Exception:
                pass

        return requirements
    def _load_conda_full_env(self, conda_env_dict, requirements):
        # noinspection PyBroadException
        try:
            cuda_version = int(self.session.config.get('agent.cuda_version', 0))
        except Exception:
            cuda_version = 0

        conda_env_dict['channels'] = self.extra_channels
        if 'dependencies' not in conda_env_dict:
            conda_env_dict['dependencies'] = []
        new_dependencies = OrderedDict()
        pip_requirements = None
        for line in conda_env_dict['dependencies']:
            if isinstance(line, dict):
                pip_requirements = line.pop('pip', None)
                continue
            name = line.strip().split('=', 1)[0].lower()
            if name == 'pip':
                continue
            elif name == 'python':
                line = 'python={}'.format('.'.join(line.split('=')[1].split('.')[:2]))
            elif name == 'tensorflow-gpu' and cuda_version == 0:
                line = 'tensorflow={}'.format(line.split('=')[1])
            elif name == 'tensorflow' and cuda_version > 0:
                line = 'tensorflow-gpu={}'.format(line.split('=')[1])
            elif name in ('cupti', 'cudnn'):
                # cudatoolkit should pull them based on the cudatoolkit version
                continue
            elif name.startswith('_'):
                continue
            new_dependencies[line.split('=', 1)[0].strip()] = line

        # fix packages:
        conda_env_dict['dependencies'] = list(new_dependencies.values())

        with self.temp_file("conda_env", yaml.dump(conda_env_dict), suffix=".yml") as name:
            print('Conda: Trying to install requirements:\n{}'.format(conda_env_dict['dependencies']))
            result = self._run_command(
                ("env", "update", "-p", self.path, "--file", name)
            )

        # check if we need to remove specific packages
        bad_req = self._parse_conda_result_bad_packges(result)
        if bad_req:
            print('failed installing the following conda packages: {}'.format(bad_req))
            return False

        if pip_requirements:
            # create a list of vcs packages that we need to replace in the pip section
            vcs_reqs = {}
            if 'pip' in requirements:
                pip_lines = requirements['pip'].splitlines() \
                    if isinstance(requirements['pip'], six.string_types) else requirements['pip']
                for line in pip_lines:
                    try:
                        marker = list(parse(line))
                    except Exception:
                        marker = None
                    if not marker:
                        continue

                    m = MarkerRequirement(marker[0])
                    if m.vcs:
                        vcs_reqs[m.name] = m
            try:
                pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
                               if not r.startswith('pip=') and not r.startswith('virtualenv=')]
                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
                PackageManager._selected_manager = self.pip
                self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
            except Exception as e:
                print(e)
                raise e
            finally:
                PackageManager._selected_manager = self

        self.requirements_manager.post_install(self.session)

    def load_requirements(self, requirements):
        # if we are in read-only mode, do not install anything
        if self.env_read_only:
            print('Conda environment in read-only mode, skipping requirements installation.')
            return None

        # if we have a full conda environment, use it and pass the pip to pip
        if requirements.get('conda_env_json'):
            # noinspection PyBroadException
            try:
                conda_env_json = json.loads(requirements.get('conda_env_json'))
                print('Conda restoring full yaml environment')
                return self._load_conda_full_env(conda_env_json, requirements)
            except Exception:
                print('Could not load fully stored conda environment, falling back to requirements')

        # create new environment file
        conda_env = dict()
        conda_env['channels'] = self.extra_channels
@@ -276,6 +481,15 @@ class CondaAPI(PackageManager):
            if m.vcs:
                pip_requirements.append(m)
                continue
            # Skip over pip
            if m.name in ('pip', 'virtualenv', ):
                continue
            # python version, only major.minor
            if m.name == 'python' and m.specs:
                m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
                if '.' not in m.specs[0][1]:
                    continue

            conda_supported_req_names.append(m.name.lower())
            if m.req.name.lower() == 'matplotlib':
                has_matplotlib = True
@@ -303,15 +517,20 @@ class CondaAPI(PackageManager):
                continue

            m = MarkerRequirement(marker[0])
            # skip over local files (we cannot change the version to a local file)
            if m.local_file:
                continue
            m_name = m.name.lower()
            if m_name in conda_supported_req_names:
                # this package is in the conda list,
                # make sure that if we changed version and we match it in conda
                conda_supported_req_names.remove(m_name)
                ## conda_supported_req_names.remove(m_name)
                for cr in reqs:
                    if m_name == cr.name.lower():
                    if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
                        # match versions
                        cr.specs = m.specs
                        # # conda always likes "-" not "_" but only on pypi packages
                        # cr.name = cr.name.lower().replace('_', '-')
                        break
            else:
                # not in conda, it is a pip package
@@ -319,29 +538,39 @@ class CondaAPI(PackageManager):
                if m_name == 'matplotlib':
                    has_matplotlib = True

        # remove any leftover conda packages (they were removed from the pip list)
        if conda_supported_req_names:
            reqs = [r for r in reqs if r.name.lower() not in conda_supported_req_names]

        # Conda requirements Hacks:
        if has_matplotlib:
            reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
            reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
            reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))

        # remove specific cudatoolkit, it should have been preinstalled.
        # allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them
        reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]

        if has_torch and cuda_version == 0:
            reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))

        # make sure we have no double entries
        reqs = list(OrderedDict((r.name, r) for r in reqs).values())

        # conform conda packages (version/name)
        for r in reqs:
            # change _ to - in name but not the prefix _ (as this is conda prefix)
            if not r.name.startswith('_') and not requirements.get('conda', None):
                r.name = r.name.replace('_', '-')
            # remove .post from version numbers, it fails ~= version, and change == to ~=
            if r.specs and r.specs[0]:
                r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
            # conda always likes "-" not "_"
            r.req.name = r.req.name.replace('_', '-')

        while reqs:
            # notice, we give conda more freedom in version selection, to help it choose best combination
            conda_env['dependencies'] = [r.tostr() for r in reqs]
            def clean_ver(ar):
                if not ar.specs:
                    return ar.tostr()
                ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
                return ar.tostr()
            conda_env['dependencies'] = [clean_ver(r) for r in reqs]
            with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
                print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
                result = self._run_command(
@@ -371,12 +600,15 @@ class CondaAPI(PackageManager):

        if pip_requirements:
            try:
                pip_req_str = [r.tostr() for r in pip_requirements]
                pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
                self.pip.load_requirements('\n'.join(pip_req_str))
                PackageManager._selected_manager = self.pip
                self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
            except Exception as e:
                print(e)
                raise e
            finally:
                PackageManager._selected_manager = self

        self.requirements_manager.post_install(self.session)
        return True
@@ -441,8 +673,22 @@ class CondaAPI(PackageManager):
    def get_python_command(self, extra=()):
        return CommandSequence(self.source, self.pip.get_python_command(extra=extra))

    def _get_conda_sh(self):
        # type: () -> Path
        base_conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
        if base_conda_env.is_file():
            return base_conda_env
        for path in os.environ.get('PATH', '').split(select_for_platform(windows=';', linux=':')):
            conda = find_executable("conda", path=path)
            if not conda:
                continue
            conda_env = Path(conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
            if conda_env.is_file():
                return conda_env
        return base_conda_env

# enable hashing with cmp=False because pdb fails on unhashable exceptions

# enable hashing with cmp=False because pdb fails on un-hashable exceptions
exception = attrs(str=True, cmp=False)


@@ -1,25 +0,0 @@
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution


class CythonRequirement(SimpleSubstitution):

    name = ("cython", "numpy", )

    def __init__(self, *args, **kwargs):
        super(CythonRequirement, self).__init__(*args, **kwargs)

    def match(self, req):
        # match both Cython & cython
        return req.name and req.name.lower() in self.name

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # install Cython before
        PackageManager.out_of_scope_install_package(str(req))
        return Text(req)
@@ -1,3 +1,4 @@
import re
from collections import OrderedDict
from typing import Text

@@ -21,6 +22,8 @@ class ExternalRequirements(SimpleSubstitution):
            return False
        if not req.req or not req.req.line or not req.req.line.strip() or req.req.line.strip().startswith('#'):
            return False
        if req.pip_new_version and not (req.req.editable or req.req.vcs):
            return False
        return True

    def post_install(self, session):
@@ -33,6 +36,9 @@ class ExternalRequirements(SimpleSubstitution):
                freeze_base = ''

            req_line = req.tostr(markers=False)
            if req_line.strip().startswith('-e ') or req_line.strip().startswith('--editable'):
                req_line = re.sub(r'^(-e|--editable=?)\s*', '', req_line, count=1)

            if req.req.vcs and req_line.startswith('git+'):
                try:
                    url_no_frag = furl(req_line)
@@ -47,22 +53,30 @@ class ExternalRequirements(SimpleSubstitution):
                    vcs._set_ssh_url()
                    new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
                    if new_req_line != req_line:
                        url_pass = furl(new_req_line).password
                        furl_line = furl(new_req_line)
                        print('Replacing original pip vcs \'{}\' with \'{}\''.format(
                            req_line, new_req_line.replace(url_pass, '****', 1) if url_pass else new_req_line))
                            req_line,
                            furl_line.set(password='xxxxxx').tostr() if furl_line.password else new_req_line))
                        req_line = new_req_line
                except Exception:
                    print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))

            PackageManager.out_of_scope_install_package(req_line, "--no-deps")
            try:
                freeze_post = PackageManager.out_of_scope_freeze() or ''
                package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
                if package_name and package_name[0] not in self.post_install_req_lookup:
                    self.post_install_req_lookup[package_name[0]] = req.req.line
            except:
                pass
            if not PackageManager.out_of_scope_install_package(req_line, "--ignore-installed"):
                # if we have an older pip version we have to make sure we replace back the package name with the
                # git repository link. In new versions this is supported and we get "package @ git+https://..."
                if not req.pip_new_version:
                    PackageManager.out_of_scope_install_package(req_line, "--no-deps")
                    # noinspection PyBroadException
                    try:
                        freeze_post = PackageManager.out_of_scope_freeze() or ''
                        package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
                        if package_name and package_name[0] not in self.post_install_req_lookup:
                            self.post_install_req_lookup[package_name[0]] = req.req.line
                    except Exception:
                        pass

                # no need to force reinstall, pip will always rebuild if the package comes from git
                # and make sure the required packages are installed (if they are not it will install them)
                if not PackageManager.out_of_scope_install_package(req_line):
                    raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))

    def replace(self, req):
@@ -76,10 +90,17 @@ class ExternalRequirements(SimpleSubstitution):
        return Text('')

    def replace_back(self, list_of_requirements):
        if 'pip' in list_of_requirements:
            original_requirements = list_of_requirements['pip']
            list_of_requirements['pip'] = [r for r in original_requirements
                                           if r not in self.post_install_req_lookup]
            list_of_requirements['pip'] += [self.post_install_req_lookup.get(r, '')
                                            for r in self.post_install_req_lookup.keys() if r in original_requirements]
        if not list_of_requirements:
            return list_of_requirements

        for k in list_of_requirements:
            # k is either pip/conda
            if k not in ('pip', 'conda'):
                continue

            original_requirements = list_of_requirements[k]
            list_of_requirements[k] = [r for r in original_requirements
                                       if r not in self.post_install_req_lookup]
            list_of_requirements[k] += [self.post_install_req_lookup.get(r, '')
                                        for r in self.post_install_req_lookup.keys() if r in original_requirements]
        return list_of_requirements

@@ -1,32 +0,0 @@
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution


class HorovodRequirement(SimpleSubstitution):

    name = "horovod"

    def __init__(self, *args, **kwargs):
        super(HorovodRequirement, self).__init__(*args, **kwargs)
        self.post_install_req = None

    def match(self, req):
        # match horovod (any case)
        return req.name and self.name == req.name.lower()

    def post_install(self, session):
        if self.post_install_req:
            PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False))
            self.post_install_req = None

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Store in post req install, and return nothing
        self.post_install_req = req
        # mark skip package, we will install it in post install hook
        return Text('')
@@ -1,6 +1,8 @@
from typing import Any

from pathlib2 import Path

from trains_agent.helper.base import select_for_platform, rm_tree
from trains_agent.helper.base import select_for_platform, rm_tree, ExecutionInfo
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session
@@ -9,8 +11,8 @@ from ..requirements import RequirementsManager


class VirtualenvPip(SystemPip, PackageManager):
    def __init__(self, session, python, requirements_manager, path, interpreter=None):
        # type: (Session, float, RequirementsManager, PathLike, PathLike) -> ()
    def __init__(self, session, python, requirements_manager, path, interpreter=None, execution_info=None, **kwargs):
        # type: (Session, float, RequirementsManager, PathLike, PathLike, ExecutionInfo, Any) -> ()
        """
        Program interface to virtualenv pip.
        Must be given either path to virtualenv or source command.

trains_agent/helper/package/post_req.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution


class PostRequirement(SimpleSubstitution):

    name = ("horovod", )
    optional_package_names = tuple()

    def __init__(self, *args, **kwargs):
        super(PostRequirement, self).__init__(*args, **kwargs)
        self.post_install_req = []
        # check if we need to replace the packages:
        post_packages = self.config.get('agent.package_manager.post_packages', None)
        if post_packages:
            self.__class__.name = post_packages
        post_optional_packages = self.config.get('agent.package_manager.post_optional_packages', None)
        if post_optional_packages:
            self.__class__.optional_package_names = post_optional_packages

    def match(self, req):
        # match configured package names (any case)
        return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)

    def post_install(self, session):
        for req in self.post_install_req:
            if req.name in self.optional_package_names:
                # noinspection PyBroadException
                try:
                    PackageManager.out_of_scope_install_package(req.tostr(markers=False))
                except Exception:
                    pass
            else:
                PackageManager.out_of_scope_install_package(req.tostr(markers=False))

        self.post_install_req = []

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Store in post req install, and return nothing
        self.post_install_req.append(req)
        # mark skip package, we will install it in post install hook
        return Text('')
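The class above defers matched packages to a post-install phase. A stripped-down standalone sketch of that pattern (illustrative only; `install_package` is a stand-in, not the agent's real API):

```python
# Standalone sketch of the deferred-install pattern used by PostRequirement.
def install_package(line):
    print("installing:", line)

class DeferredInstall(object):
    names = ("horovod",)

    def __init__(self):
        self.deferred = []

    def replace(self, req_line):
        # drop the line from the main requirements pass, remember it for later
        if req_line.split("==")[0].lower() in self.names:
            self.deferred.append(req_line)
            return ""
        return req_line

    def post_install(self):
        # install the deferred packages after everything else is in place
        for line in self.deferred:
            install_package(line)
        self.deferred = []

handler = DeferredInstall()
remaining = [l for l in ("numpy==1.19.5", "horovod==0.21.0") if handler.replace(l)]
print(remaining)        # ['numpy==1.19.5']
handler.post_install()  # installing: horovod==0.21.0
```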
trains_agent/helper/package/priority_req.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution


class PriorityPackageRequirement(SimpleSubstitution):

    name = ("cython", "numpy", "setuptools", )
    optional_package_names = tuple()

    def __init__(self, *args, **kwargs):
        super(PriorityPackageRequirement, self).__init__(*args, **kwargs)
        # check if we need to replace the packages:
        priority_packages = self.config.get('agent.package_manager.priority_packages', None)
        if priority_packages:
            self.__class__.name = priority_packages
        priority_optional_packages = self.config.get('agent.package_manager.priority_optional_packages', None)
        if priority_optional_packages:
            self.__class__.optional_package_names = priority_optional_packages

    def match(self, req):
        # match both Cython & cython
        return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        if req.name in self.optional_package_names:
            # noinspection PyBroadException
            try:
                if PackageManager.out_of_scope_install_package(str(req)):
                    return Text(req)
            except Exception:
                pass
            return Text('')
        PackageManager.out_of_scope_install_package(str(req))
        return Text(req)


class PackageCollectorRequirement(SimpleSubstitution):
    """
    This RequirementSubstitution class allows multiple instances of the same
    package; it outputs only the last one (by order) to be actually used.
    """
    name = tuple()

    def __init__(self, session, collect_package):
        super(PackageCollectorRequirement, self).__init__(session)
        self._collect_packages = collect_package or tuple()
        self._last_req = None

    def match(self, req):
        # match package names
        return req.name and req.name.lower() in self._collect_packages

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        self._last_req = req.clone()
        return ''

    def post_scan_add_req(self):
        """
        Allows the RequirementSubstitution to add an extra line/requirements after
        the initial requirements scan is completed.
        Called only once per requirements.txt object
        """
        last_req = self._last_req
        self._last_req = None
        return last_req
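PackageCollectorRequirement above keeps only the last occurrence of a collected package. A standalone sketch of that dedup pass (illustrative, not the agent's real scanning code):

```python
# Sketch: keep only the last occurrence of a "collected" package name,
# mirroring PackageCollectorRequirement.replace() + post_scan_add_req().
collect = {"torch"}
requirements = ["numpy==1.19.5", "torch==1.4.0", "pandas==1.1.5", "torch==1.8.1"]

last_seen = None
scanned = []
for line in requirements:
    name = line.split("==")[0].lower()
    if name in collect:
        last_seen = line  # drop for now, remember the latest occurrence
        continue
    scanned.append(line)
if last_seen:
    scanned.append(last_seen)  # re-added once, after the scan completes

print(scanned)  # ['numpy==1.19.5', 'pandas==1.1.5', 'torch==1.8.1']
```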
@@ -82,6 +82,8 @@ class SimplePytorchRequirement(SimpleSubstitution):
        92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
        100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
        101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
        102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
        110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
    }

    def __init__(self, *args, **kwargs):
@@ -117,20 +119,24 @@ class SimplePytorchRequirement(SimpleSubstitution):

    @classmethod
    def get_torch_page(cls, cuda_version, nightly=False):
        # noinspection PyBroadException
        try:
            cuda = int(cuda_version)
        except:
        except Exception:
            cuda = 0

        if nightly:
            # then try the nightly builds, it might be there...
            torch_url = cls.nightly_page_lookup_template.format(cuda)
            try:
                if requests.get(torch_url, timeout=10).ok:
                    cls.torch_page_lookup[cuda] = torch_url
                    return cls.torch_page_lookup[cuda], cuda
            except Exception:
                pass
            for c in range(cuda, max(-1, cuda-15), -1):
                # then try the nightly builds, it might be there...
                torch_url = cls.nightly_page_lookup_template.format(c)
                # noinspection PyBroadException
                try:
                    if requests.get(torch_url, timeout=10).ok:
                        print('Torch nightly CUDA {} download page found'.format(c))
                        cls.torch_page_lookup[c] = torch_url
                        return cls.torch_page_lookup[c], c
                except Exception:
                    pass
            return

        # first check if key is valid
@@ -138,13 +144,16 @@ class SimplePytorchRequirement(SimpleSubstitution):
            return cls.torch_page_lookup[cuda], cuda

        # then try a new cuda version page
        torch_url = cls.page_lookup_template.format(cuda)
        try:
            if requests.get(torch_url, timeout=10).ok:
                cls.torch_page_lookup[cuda] = torch_url
                return cls.torch_page_lookup[cuda], cuda
        except Exception:
            pass
        for c in range(cuda, max(-1, cuda-15), -1):
            torch_url = cls.page_lookup_template.format(c)
            # noinspection PyBroadException
            try:
                if requests.get(torch_url, timeout=10).ok:
                    print('Torch CUDA {} download page found'.format(c))
                    cls.torch_page_lookup[c] = torch_url
                    return cls.torch_page_lookup[c], c
            except Exception:
                pass

        keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
        for k in keys:
@@ -157,7 +166,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
class PytorchRequirement(SimpleSubstitution):

    name = "torch"
    packages = ("torch", "torchvision", "torchaudio")
    packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")

    def __init__(self, *args, **kwargs):
        os_name = kwargs.pop("os_override", None)
@@ -235,6 +244,7 @@ class PytorchRequirement(SimpleSubstitution):
        py_ver = self.python_major_minor_str.replace('.', '')
        url = None
        last_v = None
        closest_v = None
        # search for our package
        for l in links_parser.links:
            parts = l.split('/')[-1].split('-')
@@ -244,73 +254,94 @@ class PytorchRequirement(SimpleSubstitution):
                continue
            # version (ignore +cpu +cu92 etc. + is %2B in the file link)
            # version ignore .postX suffix (treat as regular version)
            # noinspection PyBroadException
            try:
                v = str(parts[1].split('%')[0].split('+')[0])
            except Exception:
                continue
            if len(parts) < 3 or not parts[2].endswith(py_ver):
                continue
            if len(parts) < 5 or platform_wheel not in parts[4]:
                continue
            # update the closest matched version (from above)
            if not closest_v:
                closest_v = v
            elif SimpleVersion.compare_versions(
                    version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
                    SimpleVersion.compare_versions(
                        version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
                closest_v = v
            # check if this is an actual match
            if not req.compare_version(v) or \
                    (last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
                continue
            if not parts[2].endswith(py_ver):
                continue
            if platform_wheel not in parts[4]:
                continue

            url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
            last_v = v
            # if we found an exact match, use it
            # noinspection PyBroadException
            try:
                if req.specs[0][0] == '==' and \
                        SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
                    break
            except:
            except Exception:
                pass

        return url
        return url, last_v or closest_v
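The link matching above keys off the wheel filename layout. A small sketch of that split (the link below is an illustrative example of the torch_stable.html format, not taken from this diff):

```python
# Sketch: how the wheel-link filter above slices a download link.
link = 'cu102/torch-1.8.1%2Bcu102-cp36-cp36m-linux_x86_64.whl'

parts = link.split('/')[-1].split('-')
name = parts[0]                                 # 'torch'
version = parts[1].split('%')[0].split('+')[0]  # '1.8.1' (drop +cu102 / %2Bcu102)
python_tag = parts[2]                           # 'cp36'
platform_tag = parts[4]                         # 'linux_x86_64.whl'
print(name, version, python_tag, platform_tag)
```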
    def get_url_for_platform(self, req):
        # check if package is already installed with system packages
        # noinspection PyBroadException
        try:
            if self.config.get("agent.package_manager.system_site_packages", None):
                from pip._internal.commands.show import search_packages_info
                installed_torch = list(search_packages_info([req.name]))
                # notice the comparision order, the first part will make sure we have a valid installed package
                if installed_torch[0]['version'] and req.compare_version(installed_torch[0]['version']):
                # notice the comparison order, the first part will make sure we have a valid installed package
                if installed_torch and installed_torch[0]['version'] and \
                        req.compare_version(installed_torch[0]['version']):
                    print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
                        req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
                    # package already installed, do nothing
                    return str(req), True
        except:
                    req.specs = [('==', str(installed_torch[0]['version']))]
                    return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
        except Exception:
            pass

        # make sure we have a specific version to retrieve
        if not req.specs:
            req.specs = [('>', '0')]

        # noinspection PyBroadException
        try:
            req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
        except:
        except Exception:
            pass
        op, version = req.specs[0]
        # assert op == "=="

        torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
        url = self._get_link_from_torch_page(req, torch_url)
        url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
        if not url and self.config.get("agent.package_manager.torch_nightly", None):
            torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
            url = self._get_link_from_torch_page(req, torch_url)
            url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
        # try one more time, with a lower cuda version (never fall back to CPU):
        while not url and torch_url_key > 0:
            previous_cuda_key = torch_url_key
            print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
                req, previous_cuda_key, closest_matched_version))
            url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
            if url:
                break
            torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
            # never fall back to CPU
            if torch_url_key < 1:
                print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format(
                    req, previous_cuda_key))
                raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
                    req, self.cuda_version))
                print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format(
                    req, previous_cuda_key, torch_url_key))
                url = self._get_link_from_torch_page(req, torch_url)
                print(
                    'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
                        req, previous_cuda_key))
                raise ValueError(
                    'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
            else:
                print('Trying PyTorch CUDA version {} support'.format(torch_url_key))

        if not url:
            url = PytorchWheel(
@@ -322,6 +353,8 @@ class PytorchRequirement(SimpleSubstitution):
        if url:
            # normalize url (sometimes we will get ../ which we should not)
            url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
            # print found
            print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))

        self.log.debug("checking url: %s", url)
        return url, requests.head(url, timeout=10).ok
@@ -457,7 +490,13 @@ class PytorchRequirement(SimpleSubstitution):
            if req.req.name == parts[0]:
                # support for pip >= 20.1
                if '@' in line:
                    lines[i] = '{} # {}'.format(str(req), str(new_req))
                    # skip if we have nothing to add
                    if str(req).strip() != str(new_req).strip():
                        # if this is a local file, use the version detection
                        if req.local_file:
                            lines[i] = '{}'.format(str(new_req))
                        else:
                            lines[i] = '{} # {}'.format(str(req), str(new_req))
                else:
                    lines[i] = '{} # {}'.format(line, str(new_req))
                break
@@ -4,7 +4,7 @@ import operator
import os
import re
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from copy import deepcopy, copy
from itertools import chain, starmap
from operator import itemgetter
from os import path
@@ -12,15 +12,15 @@ from typing import Text, List, Type, Optional, Tuple, Dict

from pathlib2 import Path
from pyhocon import ConfigTree
from requirements import parse
# noinspection PyPackageRequirements
from requirements.requirement import Requirement

import six
from trains_agent.definitions import PIP_EXTRA_INDICES
from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session, normalize_cuda_version
from trains_agent.external.requirements_parser import parse
from trains_agent.external.requirements_parser.requirement import Requirement

from .translator import RequirementsTranslator


@@ -35,6 +35,10 @@ class FatalSpecsResolutionError(Exception):
@six.python_2_unicode_compatible
class MarkerRequirement(object):

    # if True pip version is above 20.x, with support for "package @ scheme://link"
    # default is True
    pip_new_version = True

    def __init__(self, req):  # type: (Requirement) -> None
        self.req = req

@@ -57,7 +61,7 @@ class MarkerRequirement(object):
        elif self.vcs:
            # leave the line as is, let pip handle it
            if self.line:
                parts = [self.line]
                return self.line
            else:
                # let's build the line manually
                parts = [
@@ -65,6 +69,10 @@ class MarkerRequirement(object):
                    '@{}'.format(self.revision) if self.revision else '',
                    '#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
                ]
        elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
            # package @ file:///example.com/somewheel.whl
            # leave the line as is, let pip handle it
            return self.line
        else:
            parts = [self.uri]

@@ -73,6 +81,9 @@ class MarkerRequirement(object):

        return ''.join(parts)

    def clone(self):
        return MarkerRequirement(copy(self.req))

    __str__ = tostr

    def __repr__(self):
@@ -138,7 +149,8 @@ class MarkerRequirement(object):
        version = self.specs[0][1]
        op = (op or self.specs[0][0]).strip()

        return SimpleVersion.compare_versions(requested_version, op, version)
        return SimpleVersion.compare_versions(
            version_a=requested_version, op=op, version_b=version, num_parts=num_parts)


class SimpleVersion:
@@ -177,7 +189,7 @@ class SimpleVersion:
    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)

    @classmethod
    def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True):
    def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
        """
        Compare two versions based on the op operator
        returns bool(version_a op version_b)
@@ -188,12 +200,12 @@ class SimpleVersion:
        :param str version_b:
        :param bool ignore_sub_versions: if true compare only major.minor.patch
            (ignore a/b/rc/post/dev in the comparison)
        :param int num_parts: number of parts to compare, split by . (dot)
        :return bool: version_a op version_b
        """

        if not version_b:
            return True
        num_parts = 3

        if op == '~=':
            num_parts = max(num_parts, 2)
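A quick worked example of the `num_parts` semantics introduced above (a standalone sketch of the idea, not the agent's actual `compare_versions` implementation):

```python
# Sketch: what a num_parts-limited comparison means in practice.
# Compare only the first `num_parts` dot-separated fields.
def compare_prefix(version_a, version_b, num_parts):
    a = tuple(int(x) for x in version_a.split('.')[:num_parts])
    b = tuple(int(x) for x in version_b.split('.')[:num_parts])
    return a == b

print(compare_prefix('1.8.1', '1.8.0', num_parts=2))  # True  (1.8 == 1.8)
print(compare_prefix('1.8.1', '1.8.0', num_parts=3))  # False (patch differs)
```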
@@ -326,6 +338,14 @@ class RequirementSubstitution(object):
        """
        pass

    def post_scan_add_req(self):  # type: () -> Optional[MarkerRequirement]
        """
        Allows the RequirementSubstitution to add an extra line/requirements after
        the initial requirements scan is completed.
        Called only once per requirements.txt object
        """
        return None

    def post_install(self, session):
        pass

@@ -480,6 +500,14 @@ class RequirementsManager(object):
        )
        if not conda:
            result = map(self.translator.translate, result)

        result = list(result)
        # post-scan "add requirements" callback
        for h in self.handlers:
            req = h.post_scan_add_req()
            if req:
                result.append(req.tostr())

        return join_lines(result)

    def post_install(self, session):
@@ -491,6 +519,9 @@ class RequirementsManager(object):
            raise

    def replace_back(self, requirements):
        if self.translator:
            requirements = self.translator.replace_back(requirements)

        for h in self.handlers:
            try:
                requirements = h.replace_back(requirements)
@@ -23,6 +23,7 @@ class RequirementsTranslator(object):
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        self.config = Config()
        self.pip = SystemPip(interpreter=interpreter, session=self._session)
        self._translate_back = {}

    def download(self, url):
        self.pip.download_package(url, cache_dir=self.cache_dir)
@@ -60,4 +61,30 @@ class RequirementsTranslator(object):
            except Exception:
                command.error('Could not download wheel name of "{}"'.format(line))
                return line

        self._translate_back[str(downloaded)] = line

        return downloaded

    def replace_back(self, requirements):
        if not requirements:
            return requirements

        for k in requirements:
            # k is either pip/conda
            if k not in ('pip', 'conda'):
                continue

            original_requirements = requirements[k]
            new_requirements = []
            for line in original_requirements:
                local_file = [d for d in self._translate_back.keys() if d in line]
                if local_file:
                    local_file = local_file[0]
                    new_requirements.append(line.replace(local_file, self._translate_back[local_file]))
                else:
                    new_requirements.append(line)

            requirements[k] = new_requirements

        return requirements

@@ -11,6 +11,7 @@ from copy import deepcopy
from distutils.spawn import find_executable
from itertools import chain, repeat, islice
from os.path import devnull
from time import sleep
from typing import Union, Text, Sequence, Any, TypeVar, Callable

import psutil
@@ -41,6 +42,30 @@ def get_bash_output(cmd, strip=False, stderr=subprocess.STDOUT, stdin=False):
    return output if not strip or not output else output.strip()


def terminate_process(pid, timeout=10.):
    # noinspection PyBroadException
    try:
        proc = psutil.Process(pid)
        proc.terminate()
        cnt = 0
        while proc.is_running() and cnt < timeout:
            sleep(1.)
            cnt += 1
        proc.terminate()
        cnt = 0
        while proc.is_running() and cnt < timeout:
            sleep(1.)
            cnt += 1
        proc.kill()
    except Exception:
        pass
    # noinspection PyBroadException
    try:
        return not psutil.Process(pid).is_running()
    except Exception:
        return True


def kill_all_child_processes(pid=None):
    # get current process if pid not provided
    include_parent = True
@@ -5,7 +5,7 @@ import subprocess
from distutils.spawn import find_executable
from hashlib import md5
from os import environ, getenv
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple
from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple, Optional

import attr
from furl import furl
@@ -13,7 +13,7 @@ from pathlib2 import Path

import six

from trains_agent.definitions import ENV_AGENT_GIT_USER, ENV_AGENT_GIT_PASS
from trains_agent.definitions import ENV_AGENT_GIT_USER, ENV_AGENT_GIT_PASS, ENV_AGENT_GIT_HOST
from trains_agent.helper.console import ensure_text, ensure_binary
from trains_agent.errors import CommandFailedError
from trains_agent.helper.base import (
@@ -150,12 +150,23 @@ class VCS(object):
        """
        Apply patch to repository at `location`
        """
        self.log.info("applying diff to %s", location)
        self.log.info("applying diff to %s" % location)

        for match in filter(
            None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
        ):
            create_file_if_not_exists(normalize_path(location, match.group("path")))
        # noinspection PyBroadException
        try:
            for match in filter(
                None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
            ):
                file_path = None
                # noinspection PyBroadException
                try:
                    file_path = normalize_path(location, match.group("path"))
                    create_file_if_not_exists(file_path)
                except Exception:
                    if file_path:
                        self.log.warning("failed creating file for git diff (%s)" % file_path)
        except Exception:
            pass

        return_code, errors = self.call_with_stdin(
            patch_content, *self.patch_base, cwd=location
@@ -243,8 +254,8 @@ class VCS(object):
        return url

    @classmethod
    def replace_http_url(cls, url):
        # type: (Text) -> Text
    def replace_http_url(cls, url, port=None):
        # type: (Text, Optional[int]) -> Text
        """
        Replace HTTPS URL with SSH URL when applicable
        """
@@ -254,7 +265,9 @@ class VCS(object):
            parsed_url.username = "git"
            parsed_url.password = None
            # make sure there is no port in the final url (safe_furl support)
            parsed_url.port = None
            # the original port was an https port, and we do not know if there is a different ssh port,
            # so we have to clear the original port specified (https) and use the default ssh schema port.
            parsed_url.port = port or None
            url = parsed_url.url
        return url
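The rewrite implemented above turns an HTTPS clone URL into an SSH one. A standalone sketch of the same transformation with furl (the URL below is illustrative):

```python
# Sketch: HTTPS -> SSH git URL rewrite, as done by replace_http_url() above.
from furl import furl

url = "https://github.com/allegroai/trains-agent.git"
parsed = furl(url)
parsed.scheme = "ssh"
parsed.username = "git"
parsed.password = None
parsed.port = None  # default SSH port; an explicit override could go here

print(parsed.url)  # ssh://git@github.com/allegroai/trains-agent.git
```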
@@ -265,8 +278,14 @@ class VCS(object):
        """
        if self.session.config.get('agent.force_git_ssh_protocol', None) and self.url:
            parsed_url = furl(self.url)
            # only apply to a specific domain (if requested)
            config_domain = \
                ENV_AGENT_GIT_HOST.get() or self.session.config.get("agent.git_host", None)
            if config_domain and config_domain != parsed_url.host:
                return
            if parsed_url.scheme == "https":
                new_url = self.replace_http_url(self.url)
                new_url = self.replace_http_url(
                    self.url, port=self.session.config.get('agent.force_git_ssh_port', None))
                if new_url != self.url:
                    print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format(
                        self.url, new_url))
@@ -276,11 +295,15 @@ class VCS(object):
        if not self.session.config.agent.translate_ssh:
            return

        ssh_agent_variable = "SSH_AUTH_SOCK"
        if not getenv(ssh_agent_variable) and (
            (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and
            (ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None))
        ):
        # if we have git_user / git_pass replace ssh credentials with https authentication
        if (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and \
                (ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None)):
            # only apply to a specific domain (if requested)
            config_domain = \
                ENV_AGENT_GIT_HOST.get() or self.session.config.get("git_host", None)
            if config_domain and config_domain != furl(self.url).host:
                return

            new_url = self.replace_ssh_url(self.url)
            if new_url != self.url:
                print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
@@ -358,9 +381,10 @@ class VCS(object):
        """
        Run command with `input_` as stdin
        """
        input_ = input_.encode("latin1")
        if not input_.endswith(b"\n"):
            input_ += b"\n"
        input_ = input_.encode("utf-8")
        # always add an extra empty line
        # (there is no downside, and it solves the issue of empty lines at the end of a patch corrupting the message)
        input_ += b"\n"
        process = self._call_subprocess(
            subprocess.Popen,
            argv,
@@ -430,10 +454,12 @@ class VCS(object):
            return parsed_url.url
        config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)
        config_pass = ENV_AGENT_GIT_PASS.get() or config.get("agent.{}_pass".format(cls.executable_name), None)
        config_domain = ENV_AGENT_GIT_HOST.get() or config.get("agent.{}_host".format(cls.executable_name), None)
        if (
            (not (parsed_url.username and parsed_url.password))
            and config_user
            and config_pass
            and (not config_domain or config_domain.lower() == parsed_url.host)
        ):
            parsed_url.set(username=config_user, password=config_pass)
        return parsed_url.url
@@ -508,7 +534,7 @@ class Git(VCS):
        root=Argv(executable_name, "rev-parse", "--show-toplevel"),
    )

    patch_base = ("apply",)
    patch_base = ("apply", "--unidiff-zero", )


class Hg(VCS):

trains_agent/helper/runtime_verification.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import re
from datetime import datetime, timedelta

from typing import List, Tuple, Optional

from trains_agent.backend_config.defs import UptimeConf

DAYS = ["SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"]
PATTERN = re.compile(r"^(?P<hours>[^\s]+)\s(?P<days>[^\s]+)")


def check_runtime(ranges_list, is_uptime=True):
    # type: (List[str], bool) -> bool
    for entry in ranges_list:

        days_list = get_days_list(entry)
        if not check_day(days_list):
            continue

        hours_list = get_hours_list(entry)
        if check_hour(hours_list):
            return is_uptime
    return not is_uptime


def check_hour(hours):
    # type: (List[str]) -> bool
    return datetime.now().hour in hours


def check_day(days):
    # type: (List[str]) -> bool
    return datetime.now().strftime("%a").upper() in days


def get_days_list(entry):
    # type: (str) -> List[str]
    days_intervals = PATTERN.match(entry)["days"].split(",")
    days_total = []
    for days in days_intervals:
        start, end = days.split("-") if "-" in days else (days, days)
        try:
            days_total.extend(
                [*range(DAYS.index(start.upper()), DAYS.index(end.upper()) + 1)]
            )
        except ValueError:
            print(
                "Warning: days interval '{}' is invalid, use intervals of the format <start>-<end>."
                " make sure to use the abbreviated format SUN-SAT".format(days)
            )
            continue
    return [DAYS[day] for day in days_total]


def get_hours_list(entry):
    # type: (str) -> List[str]
    hours_intervals = PATTERN.match(entry)["hours"].split(",")
    hours_total = []
    for hours in hours_intervals:
        start, end = get_start_end_hours(hours)
        if not (start and end):
            continue
        hours_total.extend([*range(start, end)])
    return hours_total


def get_start_end_hours(hours):
    # type: (str) -> Tuple[int, int]
    try:
        start, end = (
            tuple(map(int, hours.split("-"))) if "-" in hours else (int(hours), 24)
        )
    except Exception as ex:
        print(
            "Warning: hours interval '{}' is invalid, use intervals of the format <start>-<end>".format(
                hours, ex
            )
        )
        start, end = (None, None)
    if end == 0:
        end = 24
    return start, end
def print_uptime_properties(
|
||||
ranges_list, queues_info, runtime_properties, is_uptime=True
|
||||
):
|
||||
# type: (List[str], List[dict], List[dict], bool) -> None
|
||||
if ranges_list:
|
||||
uptime_string = ["Working hours {} configurations".format("uptime" if is_uptime else "downtime")]
|
||||
uptime_string.extend(get_uptime_string(entry) for entry in ranges_list)
|
||||
else:
|
||||
uptime_string = ["No uptime/downtime configurations found"]
|
||||
|
||||
is_server_forced, server_string = get_runtime_properties_string(runtime_properties)
|
||||
is_queue_forced, queues_string = get_queues_tags_string(queues_info)
|
||||
|
||||
res = list(
|
||||
filter(
|
||||
len,
|
||||
[
|
||||
"\n".join(uptime_string),
|
||||
"\nCurrently forced {}".format(is_queue_forced or is_server_forced)
|
||||
if is_queue_forced or is_server_forced
|
||||
else " ",
|
||||
server_string,
|
||||
queues_string,
|
||||
],
|
||||
)
|
||||
)
|
||||
print("\n".join(res))
|
||||
|
||||
|
||||
def get_uptime_string(entry):
|
||||
# type: (str) -> str
|
||||
res = []
|
||||
days_list = get_days_list(entry)
|
||||
hours_intervals = PATTERN.match(entry)["hours"].split(",")
|
||||
for hours in hours_intervals:
|
||||
start, end = get_start_end_hours(hours)
|
||||
if not (start and end):
|
||||
continue
|
||||
res.append(
|
||||
" - {}:00-{}:59 on {}".format(start, end - 1, ' and '.join(days_list))
|
||||
if not (start == end)
|
||||
else ""
|
||||
)
|
||||
return "\n".join(res)
|
||||
|
||||
|
||||
def get_runtime_properties_string(runtime_properties):
|
||||
# type: (List[dict]) -> Tuple[Optional[str], str]
|
||||
server_string = []
|
||||
force_flag = next(
|
||||
(prop for prop in runtime_properties if prop["key"] == UptimeConf.worker_key),
|
||||
None,
|
||||
)
|
||||
is_server_forced = None
|
||||
if force_flag:
|
||||
is_server_forced = force_flag["value"].upper()
|
||||
expiry_hour = (
|
||||
(datetime.now() + timedelta(seconds=force_flag["expiry"])).strftime("%H:%M")
|
||||
if force_flag["expiry"]
|
||||
else None
|
||||
)
|
||||
expires = " expires at {}".format(expiry_hour) if expiry_hour else ""
|
||||
server_string.append(
|
||||
" - Server runtime property '{}: {}'{}".format(force_flag['key'], force_flag['value'], expires)
|
||||
)
|
||||
return is_server_forced, "\n".join(server_string)
|
||||
|
||||
|
||||
def get_queues_tags_string(queues_info):
|
||||
# type: (List[dict]) -> Tuple[Optional[str], str]
|
||||
queues_string = []
|
||||
is_queue_forced = None
|
||||
for queue in queues_info:
|
||||
if "force_workers:off" in queue.get("tags", []):
|
||||
tag = "force_workers:off"
|
||||
is_queue_forced = is_queue_forced or "OFF"
|
||||
elif "force_workers:on" in queue.get("tags", []):
|
||||
tag = "force_workers:on"
|
||||
is_queue_forced = "ON"
|
||||
else:
|
||||
tag = None
|
||||
tagged = " (tagged '{}')'".format(tag) if tag else ""
|
||||
queues_string.append(
|
||||
" - Listening to queue '{}'{}".format(queue.get('name', ''), tagged)
|
||||
)
|
||||
return is_queue_forced, "\n".join(queues_string)
|
||||
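Taken together, these helpers turn a spec such as `"17-20 TUE,FRI"` into day and hour lists and report whether the current time falls inside them. A quick usage sketch (output of the last line depends on when it is run):

```python
from trains_agent.helper.runtime_verification import (
    check_runtime, get_days_list, get_hours_list,
)

entry = "17-20 TUE,FRI"
print(get_days_list(entry))   # ['TUE', 'FRI']
print(get_hours_list(entry))  # [17, 18, 19] - the end hour is exclusive
print(check_runtime([entry], is_uptime=True))  # True only Tue/Fri between 17:00-19:59
```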
@@ -4,6 +4,8 @@ from time import sleep
from glob import glob
from tempfile import gettempdir, NamedTemporaryFile

from typing import List, Tuple, Optional

from trains_agent.definitions import ENV_DOCKER_HOST_MOUNT
from trains_agent.helper.base import warning

@@ -37,6 +39,10 @@ class Singleton(object):
        except:
            pass

    @classmethod
    def get_lock_filename(cls):
        return os.path.join(cls._get_temp_folder(), cls._lock_file_name)

    @classmethod
    def register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
        """
@@ -47,7 +53,7 @@ class Singleton(object):
        :return: (str worker_id, int slot_number) Return None value on instance already running
        """
        # try to lock file
        lock_file = os.path.join(cls._get_temp_folder(), cls._lock_file_name)
        lock_file = cls.get_lock_filename()
        timeout = 0
        while os.path.exists(lock_file):
            if timeout > cls._lock_timeout:
@@ -79,30 +85,42 @@ class Singleton(object):
        return ret

    @classmethod
    def _register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
        if cls.worker_id:
            return cls.worker_id, cls.instance_slot
        # make sure we have a unique name
        instance_num = 0
    def get_running_pids(cls):
        # type: () -> List[Tuple[int, Optional[str], Optional[int], str]]
        temp_folder = cls._get_temp_folder()
        files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
        slots = {}
        pids = []
        for file in files:
            parts = file.split(cls.sep)
            parts = os.path.basename(file).split(cls.sep)
            # noinspection PyBroadException
            try:
                pid = int(parts[1])
                if not psutil.pid_exists(pid):
                    pid = -1
            except Exception:
                # something is wrong, use non existing pid and delete the file
                pid = -1

            uid, slot = None, None
            # noinspection PyBroadException
            try:
                with open(file, 'r') as f:
                    uid, slot = str(f.read()).split('\n')
                    slot = int(slot)
            except Exception:
                pass
            pids.append((pid, uid, slot, file))

        return pids

    @classmethod
    def _register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
        if cls.worker_id:
            return cls.worker_id, cls.instance_slot
        # make sure we have a unique name
        instance_num = 0
        slots = {}
        for pid, uid, slot, file in cls.get_running_pids():
            worker = None
            if api_client and ENV_DOCKER_HOST_MOUNT.get() and uid:
                try:
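`get_running_pids` discovers live workers by scanning the temp folder for lock files whose name embeds the pid (`<prefix><sep><pid><sep>...<ext>`) and whose body holds the worker uid and slot on two lines. A rough sketch of that file layout and how it parses (the folder, prefix, and separator values here are assumed for illustration, not the agent's actual constants):

```python
import os
import tempfile

prefix, sep, ext = "trainsagent", "_", ".tmp"  # assumed naming, for illustration only
pid = os.getpid()
path = os.path.join(tempfile.gettempdir(), "{}{}{}{}extra{}".format(prefix, sep, pid, sep, ext))

with open(path, "w") as f:
    f.write("worker-uid-123\n0")  # line 1: unique worker id, line 2: slot number

parts = os.path.basename(path).split(sep)
with open(path) as f:
    uid, slot = f.read().split("\n")
print(int(parts[1]), uid, int(slot))  # -> <pid> worker-uid-123 0
os.remove(path)
```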
@@ -111,7 +129,7 @@ class Singleton(object):
                worker = None

            # count active instances and delete dead files
            if not worker and not psutil.pid_exists(pid):
            if not worker and pid < 0:
                # delete the file
                try:
                    os.remove(os.path.join(file))
@@ -165,3 +183,9 @@ class Singleton(object):
    @classmethod
    def get_slot(cls):
        return cls.instance_slot or 0

    @classmethod
    def get_pid_file(cls):
        if not cls._pid_file:
            return None
        return cls._pid_file.name
@@ -68,6 +68,10 @@ DAEMON_ARGS = dict({
        'dest': 'queues',
        'type': foreign_object_id('queues'),
    },
    '--order-fairness': {
        'help': 'Pull from each queue in a round-robin order, instead of priority order.',
        'action': 'store_true',
    },
    '--standalone-mode': {
        'help': 'Do not use any network connections, assume everything is pre-installed',
        'action': 'store_true',
@@ -85,7 +89,28 @@ DAEMON_ARGS = dict({
        'action': 'store_true',
        'aliases': ['-d'],
    },
    '--stop': {
        'help': 'Stop the running agent (based on the same set of arguments)',
        'action': 'store_true',
    },
    '--uptime': {
        'help': 'Specify uptime for trains-agent in "<hours> <days>" format. For example, use "17-20 TUE" to set '
                'Tuesday\'s uptime to 17-20. '
                'Note: Make sure to have only one of uptime/downtime configuration and not both.',
        'nargs': '*',
        'default': None,
    },
    '--downtime': {
        'help': 'Specify downtime for trains-agent in "<hours> <days>" format. For example, use "09-13 TUE" to set '
                'Tuesday\'s downtime to 09-13. '
                'Note: Make sure to have only one of uptime/downtime configuration and not both.',
        'nargs': '*',
        'default': None,
    },
    '--status': {
        'help': 'Print the worker\'s schedule (uptime properties, server\'s runtime properties and listening queues)',
        'action': 'store_true',
    },
}, **WORKER_ARGS)

COMMANDS = {
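The new `--uptime`/`--downtime` flags accept multiple `"<hours> <days>"` entries, and downtime simply inverts the check. A hedged sketch of the decision the daemon presumably makes from these values using the `check_runtime` helper above (the wiring shown here is an assumption, not the daemon's exact code path):

```python
from trains_agent.helper.runtime_verification import check_runtime

# assumption: with --uptime the agent runs only inside the listed windows,
# with --downtime only outside them
uptime = ["17-20 TUE", "09-13 SAT,SUN"]  # e.g. from `--uptime "17-20 TUE" "09-13 SAT,SUN"`
downtime = []

should_run = check_runtime(uptime, is_uptime=True) if uptime \
    else (check_runtime(downtime, is_uptime=False) if downtime else True)
print(should_run)
```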
@@ -73,9 +73,11 @@ class Session(_Session):
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file
            if not Path(config_file).is_file():
                raise ValueError("Could not open configuration file: {}".format(config_file))

        cpu_only = kwargs.get('cpu_only')
        if cpu_only:
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none'

        if kwargs.get('gpus') and not os.environ.get('KUBERNETES_SERVICE_HOST') \
                and not os.environ.get('KUBERNETES_PORT'):
            # CUDA_VISIBLE_DEVICES does not support 'all'
@@ -84,6 +86,7 @@ class Session(_Session):
            os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')

        if kwargs.get('only_load_config'):
            from trains_agent.backend_api.config import load
            self.config = load()
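The net effect is a simple mapping from the agent's `cpu_only`/`gpus` arguments to the CUDA/NVIDIA visibility variables (outside Kubernetes, where the agent leaves them alone). A condensed sketch of that mapping; the `'all'` branch is inferred from the comment in the hunk, since the lines between the two hunks are not shown:

```python
import os

def apply_gpu_selection(cpu_only=False, gpus=None):
    # condensed restatement of the Session logic above, for illustration
    if cpu_only:
        os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none'
    elif gpus == 'all':
        # CUDA_VISIBLE_DEVICES does not support 'all', so only the NVIDIA variable is set
        os.environ.pop('CUDA_VISIBLE_DEVICES', None)
        os.environ['NVIDIA_VISIBLE_DEVICES'] = gpus
    elif gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = gpus

apply_gpu_selection(gpus='0,1')
print(os.environ['CUDA_VISIBLE_DEVICES'])  # 0,1
```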
@@ -1 +1 @@
__version__ = '0.15.2rc0'
__version__ = '0.16.2'