mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
090327234a | ||
|
|
3620c3a12d | ||
|
|
9a3f950ac6 | ||
|
|
0b36cb0f85 | ||
|
|
dd42423482 | ||
|
|
69eb25db1f | ||
|
|
a41ea52f87 | ||
|
|
259113c989 | ||
|
|
1afa3a3914 | ||
|
|
448e23825c | ||
|
|
b0c0f41f62 | ||
|
|
d2c5fb6512 | ||
|
|
b89cf4ec23 | ||
|
|
74b646af9e | ||
|
|
0cf485f7a9 | ||
|
|
ea63e4f66e | ||
|
|
58eb5fbd5f | ||
|
|
a8c543ef7b | ||
|
|
64e198a57a | ||
|
|
de332b9e6b | ||
|
|
60eeff292d | ||
|
|
52f30b306a | ||
|
|
6df0f81ca0 | ||
|
|
40b3c1502d | ||
|
|
a61265effe | ||
|
|
92efea6b76 | ||
|
|
216b3e2179 | ||
|
|
293a92f486 | ||
|
|
6bad2b5352 | ||
|
|
a09a638b9c | ||
|
|
24f57270ed | ||
|
|
1b7964ce98 | ||
|
|
5a510882b8 | ||
|
|
601ed03198 | ||
|
|
90fe4570b9 | ||
|
|
92fc8e838f | ||
|
|
89a3020c5e | ||
|
|
fc3e47b67e | ||
|
|
b2a80ca314 | ||
|
|
14655f19a0 | ||
|
|
47092c47db | ||
|
|
8e6fce8d63 | ||
|
|
3c514e3418 | ||
|
|
8a425b100b | ||
|
|
eb942cfedd | ||
|
|
0a7fc06108 | ||
|
|
0ae35afa76 | ||
|
|
a2156e73bf | ||
|
|
9fe77f3c28 | ||
|
|
6f078afafd | ||
|
|
15f4aa613e |
18
README.md
18
README.md
@@ -227,6 +227,14 @@ The **Trains Agent** will first try to pull jobs from the `important_jobs` queue
|
||||
|
||||
Adding queues, managing job order within a queue and moving jobs between queues, is available using the Web UI, see example on our [open server](https://demoapp.trains.allegro.ai/workers-and-queues/queues)
|
||||
|
||||
#### Stopping the Trains Agent
|
||||
|
||||
To stop a **Trains Agent** running in the background, run the same command line used to start the agent with `--stop` appended.
|
||||
For example, to stop the first of the above shown same machine, single gpu agents:
|
||||
```bash
|
||||
trains-agent daemon --detached --gpus 0 --queue default --docker nvidia/cuda --stop
|
||||
```
|
||||
|
||||
## How do I create an experiment on the Trains Server? <a name="from-scratch"></a>
|
||||
* Integrate [Trains](https://github.com/allegroai/trains) with your code
|
||||
* Execute the code on your machine (Manually / PyCharm / Jupyter Notebook)
|
||||
@@ -272,18 +280,18 @@ trains-agent daemon --services-mode --detached --queue services --create-queue -
|
||||
## AutoML and Orchestration Pipelines <a name="automl-pipes"></a>
|
||||
The Trains Agent can also be used to implement AutoML orchestration and Experiment Pipelines in conjunction with the Trains package.
|
||||
|
||||
Sample AutoML & Orchestration examples can be found in the Trains [example/automl](https://github.com/allegroai/trains/tree/master/examples/automl) folder.
|
||||
Sample AutoML & Orchestration examples can be found in the Trains [example/automation](https://github.com/allegroai/trains/tree/master/examples/automation) folder.
|
||||
|
||||
AutoML examples
|
||||
- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/automl/automl_base_template_keras_simple.py)
|
||||
- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/optimization/hyper-parameter-optimization/base_template_keras_simple.py)
|
||||
- In order to create an experiment-template in the system, this code must be executed once manually
|
||||
- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automl/automl_random_search_example.py)
|
||||
- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automation/manual_random_param_search_example.py)
|
||||
- This example will create multiple copies of the Keras experiment-template, with different hyper-parameter combinations
|
||||
|
||||
Experiment Pipeline examples
|
||||
- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/task_piping_example.py)
|
||||
- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/task_piping_example.py)
|
||||
- This example will "process data", and once done, will launch a copy of the 'second step' experiment-template
|
||||
- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/toy_base_task.py)
|
||||
- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/toy_base_task.py)
|
||||
- In order to create an experiment-template in the system, this code must be executed once manually
|
||||
|
||||
## License
|
||||
|
||||
@@ -62,6 +62,8 @@ agent {
|
||||
|
||||
# additional conda channels to use when installing with conda package manager
|
||||
conda_channels: ["pytorch", "conda-forge", ]
|
||||
# conda_full_env_update: false
|
||||
# conda_env_as_base_docker: false
|
||||
|
||||
# set the priority packages to be installed before the rest of the required packages
|
||||
# priority_packages: ["cython", "numpy", "setuptools", ]
|
||||
|
||||
75
examples/k8s_glue_example.py
Normal file
75
examples/k8s_glue_example.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
This example assumes you have preconfigured services with selectors in the form of
|
||||
"ai.allegro.agent.serial=pod-<number>" and a targetPort of 10022.
|
||||
The K8sIntegration component will label each pod accordingly.
|
||||
"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from trains_agent.glue.k8s import K8sIntegration
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--queue", type=str, help="Queue to pull tasks from"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ports-mode", action='store_true', default=False,
|
||||
help="Ports-Mode will add a label to the pod which can be used as service, in order to expose ports"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-of-services", type=int, default=20,
|
||||
help="Specify the number of k8s services to be used. Use only with ports-mode."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-port", type=int,
|
||||
help="Used in conjunction with ports-mode, specifies the base port exposed by the services. "
|
||||
"For pod #X, the port will be <base-port>+X"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gateway-address", type=str, default=None,
|
||||
help="Used in conjunction with ports-mode, specify the external address of the k8s ingress / ELB"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pod-trains-conf", type=str,
|
||||
help="Configuration file to be used by the pod itself (if not provided, current configuration is used)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overrides-yaml", type=str,
|
||||
help="YAML file containing pod overrides to be used when launching a new pod"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--template-yaml", type=str,
|
||||
help="YAML file containing pod template. If provided pod will be scheduled with kubectl apply "
|
||||
"and overrides are ignored, otherwise it will be scheduled with kubectl run"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssh-server-port", type=int, default=0,
|
||||
help="If non-zero, every pod will also start an SSH server on the selected port (default: zero, not active)"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
user_props_cb = None
|
||||
if args.ports_mode and args.base_port:
|
||||
def k8s_user_props_cb(pod_number=0):
|
||||
user_prop = {"k8s-pod-port": args.base_port + pod_number}
|
||||
if args.gateway_address:
|
||||
user_prop["k8s-gateway-address"] = args.gateway_address
|
||||
return user_prop
|
||||
user_props_cb = k8s_user_props_cb
|
||||
|
||||
k8s = K8sIntegration(
|
||||
ports_mode=args.ports_mode, num_of_services=args.num_of_services, user_props_cb=user_props_cb,
|
||||
overrides_yaml=args.overrides_yaml, trains_conf_file=args.pod_trains_conf, template_yaml=args.template_yaml,
|
||||
extra_bash_init_script=K8sIntegration.get_ssh_server_bash(
|
||||
ssh_port_number=args.ssh_server_port) if args.ssh_server_port else None
|
||||
)
|
||||
k8s.k8s_daemon(args.queue)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,20 +1,20 @@
|
||||
attrs>=18.0
|
||||
enum34>=0.9 ; python_version < '3.6'
|
||||
furl>=2.0.0
|
||||
future>=0.16.0
|
||||
humanfriendly>=2.1
|
||||
jsonschema>=2.6.0
|
||||
pathlib2>=2.3.0
|
||||
psutil>=3.4.2
|
||||
pyhocon>=0.3.38
|
||||
pyparsing>=2.0.3
|
||||
python-dateutil>=2.4.2
|
||||
pyjwt>=1.6.4
|
||||
PyYAML>=3.12
|
||||
requests-file>=1.4.2
|
||||
requests>=2.20.0
|
||||
six>=1.11.0
|
||||
tqdm>=4.19.5
|
||||
typing>=3.6.4
|
||||
urllib3>=1.21.1
|
||||
attrs>=18.0,<20.4.0
|
||||
enum34>=0.9,<1.2.0 ; python_version < '3.6'
|
||||
furl>=2.0.0,<2.2.0
|
||||
future>=0.16.0,<0.19.0
|
||||
humanfriendly>=2.1,<9.2
|
||||
jsonschema>=2.6.0,<3.3.0
|
||||
pathlib2>=2.3.0,<2.4.0
|
||||
psutil>=3.4.2,<5.9.0
|
||||
pyhocon>=0.3.38,<0.4.0
|
||||
pyparsing>=2.0.3,<2.5.0
|
||||
python-dateutil>=2.4.2,<2.9.0
|
||||
pyjwt>=1.6.4,<2.0.0
|
||||
PyYAML>=3.12,<5.4.0
|
||||
requests-file>=1.4.2,<1.6.0
|
||||
requests>=2.20.0,<2.26.0
|
||||
six>=1.11.0,<1.16.0
|
||||
tqdm>=4.19.5,<4.55.0
|
||||
typing>=3.6.4,<3.8.0
|
||||
urllib3>=1.21.1,<1.27.0
|
||||
virtualenv>=16,<20
|
||||
|
||||
@@ -47,6 +47,10 @@
|
||||
# additional conda channels to use when installing with conda package manager
|
||||
conda_channels: ["defaults", "conda-forge", "pytorch", ]
|
||||
|
||||
# If set to true, Task's "installed packages" are ignored,
|
||||
# and the repository's "requirements.txt" is used instead
|
||||
# force_repo_requirements_txt: false
|
||||
|
||||
# set the priority packages to be installed before the rest of the required packages
|
||||
# priority_packages: ["cython", "numpy", "setuptools", ]
|
||||
|
||||
@@ -144,6 +148,16 @@
|
||||
# "(which {python_single_digit} && {python_single_digit} -m pip --version) || apt-get install -y {python_single_digit}-pip",
|
||||
# ]
|
||||
|
||||
# set the preprocessing bash script to execute at the startup of any docker.
|
||||
# all lines will be executed regardless of their exit code.
|
||||
# docker_preprocess_bash_script = [
|
||||
# "echo \"starting docker\"",
|
||||
#]
|
||||
|
||||
# If False replace \r with \n and display full console output
|
||||
# default is True, report a single \r line in a sequence of consecutive lines, per 5 seconds.
|
||||
# suppress_carriage_return: true
|
||||
|
||||
# cuda versions used for solving pytorch wheel packages
|
||||
# should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
|
||||
# cuda_version: 10.1
|
||||
|
||||
@@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
|
||||
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
|
||||
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
|
||||
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", type=bool, default=True)
|
||||
ENV_CONDA_ENV_PACKAGE = EnvEntry("TRAINS_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
|
||||
|
||||
@@ -12,7 +12,7 @@ import sys
|
||||
import shutil
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from distutils.spawn import find_executable
|
||||
from functools import partial
|
||||
@@ -30,14 +30,11 @@ from pathlib2 import Path
|
||||
from pyhocon import ConfigTree, ConfigFactory
|
||||
from six.moves.urllib.parse import quote
|
||||
|
||||
from trains_agent.backend_api.services import queues
|
||||
from trains_agent.backend_config.defs import UptimeConf
|
||||
from trains_agent.helper.check_update import start_check_update_daemon
|
||||
from trains_agent.commands.base import resolve_names, ServiceCommandSection
|
||||
from trains_agent.definitions import (
|
||||
WORKER_ALREADY_REGISTERED,
|
||||
ENVIRONMENT_SDK_PARAMS,
|
||||
INVALID_WORKER_ID,
|
||||
PROGRAM_NAME,
|
||||
DEFAULT_VENV_UPDATE_URL,
|
||||
ENV_TASK_EXECUTE_AS_USER,
|
||||
@@ -93,7 +90,7 @@ from trains_agent.helper.process import (
|
||||
get_docker_id,
|
||||
commit_docker, terminate_process,
|
||||
)
|
||||
from trains_agent.helper.package.priority_req import PriorityPackageRequirement
|
||||
from trains_agent.helper.package.priority_req import PriorityPackageRequirement, PackageCollectorRequirement
|
||||
from trains_agent.helper.repo import clone_repository_cached, RepoInfo, VCS
|
||||
from trains_agent.helper.resource_monitor import ResourceMonitor
|
||||
from trains_agent.helper.runtime_verification import check_runtime, print_uptime_properties
|
||||
@@ -124,7 +121,24 @@ class LiteralScriptManager(object):
|
||||
if not script:
|
||||
return False
|
||||
diff = script.diff
|
||||
return diff and not diff.strip().lower().startswith("diff ")
|
||||
if not diff:
|
||||
return False
|
||||
|
||||
# test git diff prefix
|
||||
if diff.lstrip().lower().startswith("diff "):
|
||||
return False
|
||||
|
||||
# test git submodule prefix
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if diff.lstrip().lower().startswith("submodule ") and \
|
||||
diff.splitlines()[1].lstrip().lower().startswith("diff "):
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# none of the above
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def write(task, directory, entry_point=None):
|
||||
@@ -267,7 +281,7 @@ class TaskStopSignal(object):
|
||||
)
|
||||
return TaskStopReason.stopped
|
||||
|
||||
if status in self.unexpected_statuses: ## and "worker" not in message:
|
||||
if status in self.unexpected_statuses: # ## and "worker" not in message:
|
||||
self.command.log("unexpected status change, task will terminate")
|
||||
return TaskStopReason.status_changed
|
||||
|
||||
@@ -308,6 +322,8 @@ class Worker(ServiceCommandSection):
|
||||
PriorityPackageRequirement,
|
||||
PostRequirement,
|
||||
ExternalRequirements,
|
||||
partial(PackageCollectorRequirement, collect_package=['trains']),
|
||||
partial(PackageCollectorRequirement, collect_package=['clearml']),
|
||||
)
|
||||
|
||||
# poll queues every _polling_interval seconds
|
||||
@@ -379,7 +395,7 @@ class Worker(ServiceCommandSection):
|
||||
self.temp_config_path = None
|
||||
self.queues = ()
|
||||
self.venv_folder = None # type: Optional[Text]
|
||||
self.package_api = None # type: PackageManager
|
||||
self.package_api = None # type: Optional[PackageManager]
|
||||
self.global_package_api = None
|
||||
|
||||
self.is_venv_update = self._session.config.agent.venv_update.enabled
|
||||
@@ -398,6 +414,7 @@ class Worker(ServiceCommandSection):
|
||||
self._redirected_stdout_file_no = None
|
||||
self._uptime_config = self._session.config.get("agent.uptime", None)
|
||||
self._downtime_config = self._session.config.get("agent.downtime", None)
|
||||
self._suppress_cr = self._session.config.get("agent.suppress_carriage_return", True)
|
||||
|
||||
# True - supported
|
||||
# None - not initialized
|
||||
@@ -512,7 +529,7 @@ class Worker(ServiceCommandSection):
|
||||
full_docker_cmd = self.docker_image_func(docker_image=docker_image, docker_arguments=docker_arguments)
|
||||
try:
|
||||
self._session.send_api(
|
||||
tasks_api.EditRequest(task_id, force=True, execution=dict(
|
||||
tasks_api.EditRequest(task_id, force=True, execution=dict( # noqa
|
||||
docker_cmd=' '.join([docker_image] + docker_arguments) if docker_arguments else docker_image)))
|
||||
except Exception:
|
||||
pass
|
||||
@@ -524,6 +541,10 @@ class Worker(ServiceCommandSection):
|
||||
'--full-monitoring' if self._services_mode else '--disable-monitoring',
|
||||
'--standalone-mode' if self._standalone_mode else '',
|
||||
task_id)
|
||||
|
||||
# send the actual used command line to the backend
|
||||
self.send_logs(task_id=task_id, lines=['Executing: {}\n'.format(full_docker_cmd)], level="INFO")
|
||||
|
||||
cmd = Argv(*full_docker_cmd)
|
||||
print('Running Docker:\n{}\n'.format(str(cmd)))
|
||||
else:
|
||||
@@ -983,22 +1004,27 @@ class Worker(ServiceCommandSection):
|
||||
stop_signal=None, # type: Optional[TaskStopSignal]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> Tuple[Optional[int], TaskStopReason]
|
||||
def _print_file(file_path, prev_line_count):
|
||||
# type: (...) -> Tuple[Optional[int], Optional[TaskStopReason]]
|
||||
def _print_file(file_path, prev_pos=0):
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(prev_pos)
|
||||
binary_text = f.read()
|
||||
if not binary_text:
|
||||
return []
|
||||
pos = f.tell()
|
||||
# skip the previously printed lines,
|
||||
blines = binary_text.split(b'\n')[prev_line_count:]
|
||||
blines = binary_text.split(b'\n') if binary_text else []
|
||||
if not blines:
|
||||
return blines
|
||||
return decode_binary_lines(blines if blines[-1] else blines[:-1])
|
||||
return blines, pos
|
||||
return (
|
||||
decode_binary_lines(blines if blines[-1] else blines[:-1],
|
||||
replace_cr=not self._suppress_cr,
|
||||
overwrite_cr=self._suppress_cr),
|
||||
pos
|
||||
)
|
||||
|
||||
stdout = open(stdout_path, "wt")
|
||||
stderr = open(stderr_path, "wt") if stderr_path else stdout
|
||||
stdout_line_count, stdout_last_lines = 0, []
|
||||
stderr_line_count, stderr_last_lines = 0, []
|
||||
stdout_line_count, stdout_pos_count, stdout_last_lines = 0, 0, []
|
||||
stderr_line_count, stderr_pos_count, stderr_last_lines = 0, 0, []
|
||||
service_mode_internal_agent_started = None
|
||||
stopping = False
|
||||
status = None
|
||||
@@ -1037,7 +1063,7 @@ class Worker(ServiceCommandSection):
|
||||
stderr.flush()
|
||||
|
||||
# get diff from previous poll
|
||||
printed_lines = _print_file(stdout_path, stdout_line_count)
|
||||
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
|
||||
if self._services_mode and not stopping and not status:
|
||||
# if the internal agent started, we stop logging, it will take over logging.
|
||||
# if the internal agent started running the task itself, it will return status==0,
|
||||
@@ -1047,13 +1073,10 @@ class Worker(ServiceCommandSection):
|
||||
if status is not None:
|
||||
stop_reason = 'Service started'
|
||||
|
||||
stdout_line_count += self.send_logs(
|
||||
task_id, printed_lines
|
||||
)
|
||||
stdout_line_count += self.send_logs(task_id, printed_lines)
|
||||
if stderr_path:
|
||||
stderr_line_count += self.send_logs(
|
||||
task_id, _print_file(stderr_path, stderr_line_count)
|
||||
)
|
||||
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
|
||||
stderr_line_count += self.send_logs(task_id, printed_lines)
|
||||
|
||||
except subprocess.CalledProcessError as ex:
|
||||
# non zero return code
|
||||
@@ -1064,9 +1087,11 @@ class Worker(ServiceCommandSection):
|
||||
raise
|
||||
except Exception:
|
||||
# we should not get here, but better safe than sorry
|
||||
stdout_line_count += self.send_logs(task_id, _print_file(stdout_path, stdout_line_count))
|
||||
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
|
||||
stdout_line_count += self.send_logs(task_id, printed_lines)
|
||||
if stderr_path:
|
||||
stderr_line_count += self.send_logs(task_id, _print_file(stderr_path, stderr_line_count))
|
||||
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
|
||||
stderr_line_count += self.send_logs(task_id, printed_lines)
|
||||
stop_reason = 'Exception occurred'
|
||||
status = -1
|
||||
|
||||
@@ -1080,13 +1105,11 @@ class Worker(ServiceCommandSection):
|
||||
stderr.close()
|
||||
|
||||
# Send last lines
|
||||
stdout_line_count += self.send_logs(
|
||||
task_id, _print_file(stdout_path, stdout_line_count)
|
||||
)
|
||||
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
|
||||
stdout_line_count += self.send_logs(task_id, printed_lines)
|
||||
if stderr_path:
|
||||
stderr_line_count += self.send_logs(
|
||||
task_id, _print_file(stderr_path, stderr_line_count)
|
||||
)
|
||||
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
|
||||
stderr_line_count += self.send_logs(task_id, printed_lines)
|
||||
|
||||
return status, stop_reason
|
||||
|
||||
@@ -1165,7 +1188,8 @@ class Worker(ServiceCommandSection):
|
||||
success = False
|
||||
|
||||
if not success:
|
||||
raise ValueError("Failed applying git diff:\n{}\n\nERROR! Failed applying git diff, see diff above.".format(diff))
|
||||
raise ValueError("Failed applying git diff:\n{}\n\n"
|
||||
"ERROR! Failed applying git diff, see diff above.".format(diff))
|
||||
|
||||
@resolve_names
|
||||
def build(
|
||||
@@ -1190,10 +1214,15 @@ class Worker(ServiceCommandSection):
|
||||
|
||||
execution = self.get_execution_info(current_task)
|
||||
|
||||
try:
|
||||
requirements = current_task.script.requirements
|
||||
except AttributeError:
|
||||
if self._session.config.get("agent.package_manager.force_repo_requirements_txt", False):
|
||||
requirements = None
|
||||
print("[package_manager.force_repo_requirements_txt=true] "
|
||||
"Skipping requirements, using repository \"requirements.txt\" ")
|
||||
else:
|
||||
try:
|
||||
requirements = current_task.script.requirements
|
||||
except AttributeError:
|
||||
requirements = None
|
||||
|
||||
if not python_version:
|
||||
try:
|
||||
@@ -1204,8 +1233,8 @@ class Worker(ServiceCommandSection):
|
||||
except:
|
||||
python_version = None
|
||||
|
||||
venv_folder, requirements_manager = self.install_virtualenv(venv_dir=target,
|
||||
requested_python_version=python_version)
|
||||
venv_folder, requirements_manager = self.install_virtualenv(
|
||||
venv_dir=target, requested_python_version=python_version, execution_info=execution)
|
||||
|
||||
if self._default_pip:
|
||||
if install_globally and self.global_package_api:
|
||||
@@ -1360,7 +1389,8 @@ class Worker(ServiceCommandSection):
|
||||
print("Cloning task id={}".format(task_id))
|
||||
current_task = self._session.api_client.tasks.get_by_id(
|
||||
self._session.send_api(
|
||||
tasks_api.CloneRequest(task=current_task.id, new_task_name='Clone of {}'.format(current_task.name))
|
||||
tasks_api.CloneRequest(task=current_task.id,
|
||||
new_task_name='Clone of {}'.format(current_task.name))
|
||||
).id
|
||||
)
|
||||
print("Task cloned, new task id={}".format(current_task.id))
|
||||
@@ -1410,10 +1440,15 @@ class Worker(ServiceCommandSection):
|
||||
|
||||
execution = self.get_execution_info(current_task)
|
||||
|
||||
try:
|
||||
requirements = current_task.script.requirements
|
||||
except AttributeError:
|
||||
if self._session.config.get("agent.package_manager.force_repo_requirements_txt", False):
|
||||
requirements = None
|
||||
print("[package_manager.force_repo_requirements_txt=true] "
|
||||
"Skipping requirements, using repository \"requirements.txt\" ")
|
||||
else:
|
||||
try:
|
||||
requirements = current_task.script.requirements
|
||||
except AttributeError:
|
||||
requirements = None
|
||||
|
||||
try:
|
||||
python_ver = current_task.script.binary
|
||||
@@ -1423,8 +1458,8 @@ class Worker(ServiceCommandSection):
|
||||
except:
|
||||
python_ver = None
|
||||
|
||||
venv_folder, requirements_manager = self.install_virtualenv(standalone_mode=standalone_mode,
|
||||
requested_python_version=python_ver)
|
||||
venv_folder, requirements_manager = self.install_virtualenv(
|
||||
standalone_mode=standalone_mode, requested_python_version=python_ver, execution_info=execution)
|
||||
|
||||
if not standalone_mode:
|
||||
if self._default_pip:
|
||||
@@ -1449,7 +1484,10 @@ class Worker(ServiceCommandSection):
|
||||
|
||||
# do not update the task packages if we are using conda,
|
||||
# it will most likely make the task environment unreproducible
|
||||
freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None,
|
||||
skip_freeze_update = self.is_conda and not self._session.config.get(
|
||||
"agent.package_manager.conda_full_env_update", False)
|
||||
|
||||
freeze = self.freeze_task_environment(current_task.id if not skip_freeze_update else None,
|
||||
requirements_manager=requirements_manager)
|
||||
script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix()
|
||||
|
||||
@@ -1503,8 +1541,7 @@ class Worker(ServiceCommandSection):
|
||||
self._update_commit_id(current_task.id, execution, repo_info)
|
||||
|
||||
# Add the script CWD to the python path
|
||||
python_path = get_python_path(script_dir, execution.entry_point, self.package_api) \
|
||||
if not self.is_conda else None
|
||||
python_path = get_python_path(script_dir, execution.entry_point, self.package_api, is_conda_env=self.is_conda)
|
||||
if os.environ.get(ENV_TASK_EXTRA_PYTHON_PATH):
|
||||
python_path = add_python_path(python_path, os.environ.get(ENV_TASK_EXTRA_PYTHON_PATH))
|
||||
if python_path:
|
||||
@@ -1831,7 +1868,8 @@ class Worker(ServiceCommandSection):
|
||||
package_api.out_of_scope_install_package('Cython')
|
||||
|
||||
cached_requirements_failed = False
|
||||
if cached_requirements and ('pip' in cached_requirements or 'conda' in cached_requirements):
|
||||
if cached_requirements and (cached_requirements.get('pip') is not None or
|
||||
cached_requirements.get('conda') is not None):
|
||||
self.log("Found task requirements section, trying to install")
|
||||
try:
|
||||
package_api.load_requirements(cached_requirements)
|
||||
@@ -2001,8 +2039,9 @@ class Worker(ServiceCommandSection):
|
||||
)
|
||||
)
|
||||
|
||||
def install_virtualenv(self, venv_dir=None, requested_python_version=None, standalone_mode=False):
|
||||
# type: (str, str, bool) -> Tuple[Path, RequirementsManager]
|
||||
def install_virtualenv(
|
||||
self, venv_dir=None, requested_python_version=None, standalone_mode=False, execution_info=None):
|
||||
# type: (str, str, bool, ExecutionInfo) -> Tuple[Path, RequirementsManager]
|
||||
"""
|
||||
Install a new python virtual environment, removing the old one if exists
|
||||
:return: virtualenv directory and requirements manager to use with task
|
||||
@@ -2049,6 +2088,7 @@ class Worker(ServiceCommandSection):
|
||||
python=executable_version_suffix if self.is_conda else executable_name,
|
||||
path=venv_dir,
|
||||
requirements_manager=requirements_manager,
|
||||
execution_info=execution_info,
|
||||
)
|
||||
|
||||
global_package_manager_params = dict(
|
||||
@@ -2147,20 +2187,26 @@ class Worker(ServiceCommandSection):
|
||||
mounted_pip_dl_dir = '/root/.trains/pip-download-cache'
|
||||
mounted_vcs_cache = '/root/.trains/vcs-cache'
|
||||
mounted_venv_dir = '/root/.trains/venvs-builds'
|
||||
host_cache = Path(os.path.expandvars(self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
|
||||
host_pip_dl = Path(os.path.expandvars(self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
|
||||
host_vcs_cache = Path(os.path.expandvars(self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
|
||||
host_cache = Path(os.path.expandvars(
|
||||
self._session.config["sdk.storage.cache.default_base_dir"])).expanduser().as_posix()
|
||||
host_pip_dl = Path(os.path.expandvars(
|
||||
self._session.config["agent.pip_download_cache.path"])).expanduser().as_posix()
|
||||
host_vcs_cache = Path(os.path.expandvars(
|
||||
self._session.config["agent.vcs_cache.path"])).expanduser().as_posix()
|
||||
temp_config.put("sdk.storage.cache.default_base_dir", mounted_cache_dir)
|
||||
temp_config.put("agent.pip_download_cache.path", mounted_pip_dl_dir)
|
||||
temp_config.put("agent.vcs_cache.path", mounted_vcs_cache)
|
||||
temp_config.put("agent.package_manager.system_site_packages", True)
|
||||
temp_config.put("agent.package_manager.conda_env_as_base_docker", False)
|
||||
temp_config.put("agent.default_python", "")
|
||||
temp_config.put("agent.python_binary", "")
|
||||
temp_config.put("agent.cuda_version", "")
|
||||
temp_config.put("agent.cudnn_version", "")
|
||||
temp_config.put("agent.venvs_dir", mounted_venv_dir)
|
||||
temp_config.put("agent.git_user", (ENV_AGENT_GIT_USER.get() or self._session.config.get("agent.git_user", None)))
|
||||
temp_config.put("agent.git_pass", (ENV_AGENT_GIT_PASS.get() or self._session.config.get("agent.git_pass", None)))
|
||||
temp_config.put("agent.git_user", (ENV_AGENT_GIT_USER.get() or
|
||||
self._session.config.get("agent.git_user", None)))
|
||||
temp_config.put("agent.git_pass", (ENV_AGENT_GIT_PASS.get() or
|
||||
self._session.config.get("agent.git_pass", None)))
|
||||
|
||||
host_apt_cache = Path(os.path.expandvars(self._session.config.get(
|
||||
"agent.docker_apt_cache", '~/.trains/apt-cache'))).expanduser().as_posix()
|
||||
@@ -2208,6 +2254,7 @@ class Worker(ServiceCommandSection):
|
||||
extra_shell_script_str = " ; ".join(map(str, cmds)) + " ; "
|
||||
|
||||
bash_script = self._session.config.get("agent.docker_init_bash_script", None)
|
||||
preprocess_bash_script = self._session.config.get("agent.docker_preprocess_bash_script", None)
|
||||
|
||||
self.temp_config_path = self.temp_config_path or safe_mkstemp(
|
||||
suffix=".cfg", prefix=".trains_agent.", text=True, name_only=True
|
||||
@@ -2228,6 +2275,7 @@ class Worker(ServiceCommandSection):
|
||||
standalone_mode=self._standalone_mode,
|
||||
force_current_version=self._force_current_version,
|
||||
bash_script=bash_script,
|
||||
preprocess_bash_script=preprocess_bash_script,
|
||||
)
|
||||
return temp_config, partial(docker_cmd_functor, docker_cmd, temp_config)
|
||||
|
||||
@@ -2241,7 +2289,8 @@ class Worker(ServiceCommandSection):
|
||||
host_pip_dl, mounted_pip_dl,
|
||||
host_vcs_cache, mounted_vcs_cache,
|
||||
standalone_mode=False, extra_docker_arguments=None, extra_shell_script=None,
|
||||
force_current_version=None, host_git_credentials=None, bash_script=None):
|
||||
force_current_version=None, host_git_credentials=None,
|
||||
bash_script=None, preprocess_bash_script=None):
|
||||
docker = 'docker'
|
||||
|
||||
base_cmd = [docker, 'run', '-t']
|
||||
@@ -2258,7 +2307,7 @@ class Worker(ServiceCommandSection):
|
||||
if os.environ.get('TRAINS_DOCKER_SKIP_GPUS_FLAG', None):
|
||||
dockers_nvidia_visible_devices = gpu_devices
|
||||
else:
|
||||
base_cmd += ['--gpus', 'device='+gpu_devices, ]
|
||||
base_cmd += ['--gpus', '\"device={}\"'.format(gpu_devices), ]
|
||||
# We are using --gpu, so we should not pass NVIDIA_VISIBLE_DEVICES, I think.
|
||||
# base_cmd += ['-e', 'NVIDIA_VISIBLE_DEVICES=' + gpu_devices, ]
|
||||
elif gpu_devices.strip() == 'none':
|
||||
@@ -2275,7 +2324,8 @@ class Worker(ServiceCommandSection):
|
||||
base_cmd += [str(a) for a in extra_docker_arguments if a]
|
||||
|
||||
# check if running inside a kubernetes
|
||||
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and os.environ.get('KUBERNETES_PORT')):
|
||||
if ENV_DOCKER_HOST_MOUNT.get() or (os.environ.get('KUBERNETES_SERVICE_HOST') and
|
||||
os.environ.get('KUBERNETES_PORT')):
|
||||
# map network to sibling docker, unless we have other network argument
|
||||
if not any(a.strip().startswith('--network') for a in base_cmd):
|
||||
try:
|
||||
@@ -2297,7 +2347,7 @@ class Worker(ServiceCommandSection):
|
||||
print('Warning: K8S mount missing, ignoring cached folder {}'.format(m))
|
||||
host_mounts[i] = None
|
||||
else:
|
||||
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt)
|
||||
host_mounts[i] = m.replace(k8s_pod_mnt, k8s_node_mnt, 1)
|
||||
host_apt_cache, host_pip_cache, host_pip_dl, host_cache, host_vcs_cache = host_mounts
|
||||
|
||||
# copy the configuration file into the mounted folder
|
||||
@@ -2320,6 +2370,8 @@ class Worker(ServiceCommandSection):
|
||||
raise ValueError('Error: could not copy .ssh directory into: {}'.format(new_ssh_cache))
|
||||
|
||||
base_cmd += ['-e', 'TRAINS_WORKER_ID='+worker_id, ]
|
||||
# update the docker image, so the system knows where it runs
|
||||
base_cmd += ['-e', 'TRAINS_DOCKER_IMAGE={} {}'.format(docker_image, ' '.join(docker_arguments)).strip()]
|
||||
|
||||
# if we are running a RC version, install the same version in the docker
|
||||
# because the default latest, will be a release version (not RC)
|
||||
@@ -2343,21 +2395,33 @@ class Worker(ServiceCommandSection):
|
||||
|
||||
if not standalone_mode:
|
||||
if not bash_script:
|
||||
# Find the highest python version installed, or install from apt-get
|
||||
# python+pip is the requirement to match
|
||||
bash_script = [
|
||||
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
|
||||
"chown -R root /root/.cache/pip",
|
||||
"export DEBIAN_FRONTEND=noninteractive",
|
||||
"apt-get update",
|
||||
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
|
||||
"(which {python_single_digit} && {python_single_digit} -m pip --version) || " +
|
||||
"apt-get install -y {python_single_digit}-pip",
|
||||
"declare LOCAL_PYTHON",
|
||||
"for i in {{10..5}}; do which {python_single_digit}.$i && " +
|
||||
"{python_single_digit}.$i -m pip --version && " +
|
||||
"export LOCAL_PYTHON=$(which {python_single_digit}.$i) && break ; done",
|
||||
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y {python_single_digit}-pip",
|
||||
]
|
||||
|
||||
if preprocess_bash_script:
|
||||
bash_script = preprocess_bash_script + bash_script
|
||||
|
||||
docker_bash_script = " ; ".join(bash_script) if not isinstance(bash_script, str) else bash_script
|
||||
|
||||
# make sure that if we do not have $LOCAL_PYTHON defined
|
||||
# we set it to python3
|
||||
update_scheme += (
|
||||
docker_bash_script + " ; " +
|
||||
"{python} -m pip install -U \"pip{pip_version}\" ; " +
|
||||
"{python} -m pip install -U {trains_agent_wheel} ; ").format(
|
||||
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON={python} ; " +
|
||||
"$LOCAL_PYTHON -m pip install -U \"pip{pip_version}\" ; " +
|
||||
"$LOCAL_PYTHON -m pip install -U {trains_agent_wheel} ; ").format(
|
||||
python_single_digit=python_version.split('.')[0],
|
||||
python=python_version, pip_version=PackageManager.get_pip_version(),
|
||||
trains_agent_wheel=trains_agent_wheel)
|
||||
@@ -2378,7 +2442,7 @@ class Worker(ServiceCommandSection):
|
||||
update_scheme +
|
||||
extra_shell_script +
|
||||
"cp {} {} ; ".format(DOCKER_ROOT_CONF_FILE, DOCKER_DEFAULT_CONF_FILE) +
|
||||
"NVIDIA_VISIBLE_DEVICES={nv_visible} {python} -u -m trains_agent ".format(
|
||||
"NVIDIA_VISIBLE_DEVICES={nv_visible} $LOCAL_PYTHON -u -m trains_agent ".format(
|
||||
nv_visible=dockers_nvidia_visible_devices, python=python_version)
|
||||
])
|
||||
|
||||
@@ -2409,7 +2473,8 @@ class Worker(ServiceCommandSection):
|
||||
os.setuid(self.uid)
|
||||
|
||||
# create a home folder for our user
|
||||
trains_agent_home = self._run_as_user_home + '{}'.format('.'+str(Singleton.get_slot()) if Singleton.get_slot() else '')
|
||||
trains_agent_home = self._run_as_user_home + '{}'.format(
|
||||
'.'+str(Singleton.get_slot()) if Singleton.get_slot() else '')
|
||||
try:
|
||||
home_folder = self._run_as_user_home
|
||||
rm_tree(home_folder)
|
||||
@@ -2470,7 +2535,7 @@ class Worker(ServiceCommandSection):
|
||||
# Iterate over all running process
|
||||
for pid, uid, slot, file in sorted(Singleton.get_running_pids(), key=lambda x: x[1] or ''):
|
||||
# wither we have a match for the worker_id or we just pick the first one
|
||||
if pid >= 0 and (
|
||||
if pid >= 0 and uid is not None and (
|
||||
(worker_id and uid == worker_id) or
|
||||
(not worker_id and uid.startswith('{}:'.format(worker_name)))):
|
||||
# this is us kill it
|
||||
|
||||
@@ -1,45 +1,115 @@
|
||||
from __future__ import print_function, division, unicode_literals
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
|
||||
import yaml
|
||||
import json
|
||||
from time import sleep
|
||||
from typing import Text, List
|
||||
|
||||
from pyhocon import HOCONConverter
|
||||
|
||||
from trains_agent.commands.events import Events
|
||||
from trains_agent.commands.worker import Worker
|
||||
from trains_agent.errors import APIError
|
||||
from trains_agent.helper.base import safe_remove_file
|
||||
from trains_agent.helper.dicts import merge_dicts
|
||||
from trains_agent.helper.process import get_bash_output
|
||||
from trains_agent.helper.resource_monitor import ResourceMonitor
|
||||
from trains_agent.interface.base import ObjectID
|
||||
|
||||
|
||||
class K8sIntegration(Worker):
|
||||
K8S_PENDING_QUEUE = "k8s_scheduler"
|
||||
|
||||
KUBECTL_RUN_CMD = "kubectl run trains_id_{task_id} " \
|
||||
KUBECTL_APPLY_CMD = "kubectl apply -f"
|
||||
|
||||
KUBECTL_RUN_CMD = "kubectl run trains-{queue_name}-id-{task_id} " \
|
||||
"--image {docker_image} " \
|
||||
"--restart=Never --replicas=1 " \
|
||||
"--generator=run-pod/v1"
|
||||
"--generator=run-pod/v1 " \
|
||||
"--namespace=trains"
|
||||
|
||||
KUBECTL_DELETE_CMD = "kubectl delete pods " \
|
||||
"--selector=TRAINS=agent " \
|
||||
"--field-selector=status.phase!=Pending,status.phase!=Running"
|
||||
"--field-selector=status.phase!=Pending,status.phase!=Running " \
|
||||
"--namespace=trains"
|
||||
|
||||
CONTAINER_BASH_SCRIPT = "apt-get install -y git python-pip && " \
|
||||
"pip install trains-agent && " \
|
||||
"python -u -m trains_agent execute --full-monitoring --require-queue --id {}"
|
||||
BASH_INSTALL_SSH_CMD = [
|
||||
"apt-get install -y openssh-server",
|
||||
"mkdir -p /var/run/sshd",
|
||||
"echo 'root:training' | chpasswd",
|
||||
"echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
|
||||
"sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
|
||||
r"sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd",
|
||||
"echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY' >> /etc/ssh/sshd_config",
|
||||
'echo "export VISIBLE=now" >> /etc/profile',
|
||||
'echo "export PATH=$PATH" >> /etc/profile',
|
||||
'echo "ldconfig" >> /etc/profile',
|
||||
"/usr/sbin/sshd -p {port}"]
|
||||
|
||||
def __init__(self, k8s_pending_queue_name=None, kubectl_cmd=None, container_bash_script=None, debug=False):
|
||||
CONTAINER_BASH_SCRIPT = [
|
||||
"export DEBIAN_FRONTEND='noninteractive'",
|
||||
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
|
||||
"chown -R root /root/.cache/pip",
|
||||
"apt-get update",
|
||||
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
|
||||
"declare LOCAL_PYTHON",
|
||||
"for i in {{10..5}}; do which python3.$i && python3.$i -m pip --version && "
|
||||
"export LOCAL_PYTHON=$(which python3.$i) && break ; done",
|
||||
"[ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip",
|
||||
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3",
|
||||
"$LOCAL_PYTHON -m pip install trains-agent",
|
||||
"{extra_bash_init_cmd}",
|
||||
"$LOCAL_PYTHON -m trains_agent execute --full-monitoring --require-queue --id {task_id}"
|
||||
]
|
||||
|
||||
AGENT_LABEL = "TRAINS=agent"
|
||||
LIMIT_POD_LABEL = "ai.allegro.agent.serial=pod-{pod_number}"
|
||||
|
||||
_edit_hyperparams_version = "2.9"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k8s_pending_queue_name=None,
|
||||
kubectl_cmd=None,
|
||||
container_bash_script=None,
|
||||
debug=False,
|
||||
ports_mode=False,
|
||||
num_of_services=20,
|
||||
user_props_cb=None,
|
||||
overrides_yaml=None,
|
||||
template_yaml=None,
|
||||
trains_conf_file=None,
|
||||
extra_bash_init_script=None,
|
||||
):
|
||||
"""
|
||||
Initialize the k8s integration glue layer daemon
|
||||
|
||||
:param str k8s_pending_queue_name: queue name to use when task is pending in the k8s scheduler
|
||||
:param str|callable kubectl_cmd: kubectl command line str, supports formating (default: KUBECTL_RUN_CMD)
|
||||
:param str|callable kubectl_cmd: kubectl command line str, supports formatting (default: KUBECTL_RUN_CMD)
|
||||
example: "task={task_id} image={docker_image} queue_id={queue_id}"
|
||||
or a callable function: kubectl_cmd(task_id, docker_image, queue_id, task_data)
|
||||
:param str container_bash_script: container bash script to be executed in k8s (default: CONTAINER_BASH_SCRIPT)
|
||||
Notice this string will use format() call, if you have curly brackets they should be doubled { -> {{
|
||||
Format arguments passed: {task_id} and {extra_bash_init_cmd}
|
||||
:param bool debug: Switch logging on
|
||||
:param bool ports_mode: Adds a label to each pod which can be used in services in order to expose ports.
|
||||
Requires the `num_of_services` parameter.
|
||||
:param int num_of_services: Number of k8s services configured in the cluster. Required if `port_mode` is True.
|
||||
(default: 20)
|
||||
:param callable user_props_cb: An Optional callable allowing additional user properties to be specified
|
||||
when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
|
||||
a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
|
||||
:param str overrides_yaml: YAML file containing the overrides for the pod (optional)
|
||||
:param str template_yaml: YAML file containing the template for the pod (optional).
|
||||
If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
|
||||
:param str trains_conf_file: trains.conf file to be use by the pod itself (optional)
|
||||
:param str extra_bash_init_script: Additional bash script to run before starting the Task inside the container
|
||||
"""
|
||||
super(K8sIntegration, self).__init__()
|
||||
self.k8s_pending_queue_name = k8s_pending_queue_name or self.K8S_PENDING_QUEUE
|
||||
@@ -51,12 +121,88 @@ class K8sIntegration(Worker):
|
||||
if debug:
|
||||
self.log.logger.disabled = False
|
||||
self.log.logger.setLevel(logging.INFO)
|
||||
self.ports_mode = ports_mode
|
||||
self.num_of_services = num_of_services
|
||||
self._edit_hyperparams_support = None
|
||||
self._user_props_cb = user_props_cb
|
||||
self.trains_conf_file = None
|
||||
self.overrides_json_string = None
|
||||
self.template_dict = None
|
||||
self.extra_bash_init_script = extra_bash_init_script or None
|
||||
if self.extra_bash_init_script and not isinstance(self.extra_bash_init_script, str):
|
||||
self.extra_bash_init_script = ' ; '.join(self.extra_bash_init_script) # noqa
|
||||
self.pod_limits = []
|
||||
self.pod_requests = []
|
||||
if overrides_yaml:
|
||||
with open(os.path.expandvars(os.path.expanduser(str(overrides_yaml))), 'rt') as f:
|
||||
overrides = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))
|
||||
if overrides:
|
||||
containers = overrides.get('spec', {}).get('containers', [])
|
||||
for c in containers:
|
||||
resources = {str(k).lower(): v for k, v in c.get('resources', {}).items()}
|
||||
if not resources:
|
||||
continue
|
||||
if resources.get('limits'):
|
||||
self.pod_limits += ['{}={}'.format(k, v) for k, v in resources['limits'].items()]
|
||||
if resources.get('requests'):
|
||||
self.pod_requests += ['{}={}'.format(k, v) for k, v in resources['requests'].items()]
|
||||
# remove double entries
|
||||
self.pod_limits = list(set(self.pod_limits))
|
||||
self.pod_requests = list(set(self.pod_requests))
|
||||
if self.pod_limits or self.pod_requests:
|
||||
self.log.warning('Found pod container requests={} limits={}'.format(
|
||||
self.pod_limits, self.pod_requests))
|
||||
if containers:
|
||||
self.log.warning('Removing containers section: {}'.format(overrides['spec'].pop('containers')))
|
||||
self.overrides_json_string = json.dumps(overrides)
|
||||
if template_yaml:
|
||||
with open(os.path.expandvars(os.path.expanduser(str(template_yaml))), 'rt') as f:
|
||||
self.template_dict = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))
|
||||
|
||||
def run_one_task(self, queue: Text, task_id: Text, worker_args=None):
|
||||
if trains_conf_file:
|
||||
with open(os.path.expandvars(os.path.expanduser(str(trains_conf_file))), 'rt') as f:
|
||||
self.trains_conf_file = f.read()
|
||||
# make sure we use system packages!
|
||||
self.trains_conf_file += '\nagent.package_manager.system_site_packages=true\n'
|
||||
|
||||
def _set_task_user_properties(self, task_id: str, **properties: str):
|
||||
if self._edit_hyperparams_support is not True:
|
||||
# either not supported or never tested
|
||||
if self._edit_hyperparams_support == self._session.api_version:
|
||||
# tested against latest api_version, not supported
|
||||
return
|
||||
if not self._session.check_min_api_version(self._edit_hyperparams_version):
|
||||
# not supported due to insufficient api_version
|
||||
self._edit_hyperparams_support = self._session.api_version
|
||||
return
|
||||
try:
|
||||
self._session.get(
|
||||
service="tasks",
|
||||
action="edit_hyper_params",
|
||||
task=task_id,
|
||||
hyperparams=[
|
||||
{
|
||||
"section": "properties",
|
||||
"name": k,
|
||||
"value": str(v),
|
||||
}
|
||||
for k, v in properties.items()
|
||||
],
|
||||
)
|
||||
# definitely supported
|
||||
self._runtime_props_support = True
|
||||
except APIError as error:
|
||||
if error.code == 404:
|
||||
self._edit_hyperparams_support = self._session.api_version
|
||||
|
||||
def run_one_task(self, queue: Text, task_id: Text, worker_args=None, **_):
|
||||
print('Pulling task {} launching on kubernetes cluster'.format(task_id))
|
||||
task_data = self._session.api_client.tasks.get_all(id=[task_id])[0]
|
||||
|
||||
# push task into the k8s queue, so we have visibility on pending tasks in the k8s scheduler
|
||||
try:
|
||||
print('Pushing task {} into temporary pending queue'.format(task_id))
|
||||
self._session.api_client.tasks.reset(task_id)
|
||||
self._session.api_client.tasks.enqueue(task_id, queue=self.k8s_pending_queue_name,
|
||||
status_reason='k8s pending scheduler')
|
||||
except Exception as e:
|
||||
@@ -65,36 +211,231 @@ class K8sIntegration(Worker):
|
||||
return
|
||||
|
||||
if task_data.execution.docker_cmd:
|
||||
docker_image = task_data.execution.docker_cmd
|
||||
docker_parts = task_data.execution.docker_cmd
|
||||
else:
|
||||
docker_image = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
|
||||
docker_parts = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
|
||||
self._session.config.get("agent.default_docker.image", "nvidia/cuda"))
|
||||
|
||||
# take the first part, this is the docker image name (not arguments)
|
||||
docker_image = docker_image.split()[0]
|
||||
docker_parts = docker_parts.split()
|
||||
docker_image = docker_parts[0]
|
||||
docker_args = docker_parts[1:] if len(docker_parts) > 1 else []
|
||||
|
||||
create_trains_conf = "echo '{}' >> ~/trains.conf && ".format(
|
||||
HOCONConverter.to_hocon(self._session.config._config))
|
||||
# get the trains.conf encoded file
|
||||
# noinspection PyProtectedMember
|
||||
hocon_config_encoded = (self.trains_conf_file or self._session._config_file).encode('ascii')
|
||||
create_trains_conf = "echo '{}' | base64 --decode >> ~/trains.conf".format(
|
||||
base64.b64encode(
|
||||
hocon_config_encoded
|
||||
).decode('ascii')
|
||||
)
|
||||
|
||||
if callable(self.kubectl_cmd):
|
||||
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data)
|
||||
if self.ports_mode:
|
||||
print("Kubernetes looking for available pod to use")
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
queue_name = self._session.api_client.queues.get_by_id(queue=queue).name
|
||||
except Exception:
|
||||
queue_name = 'k8s'
|
||||
|
||||
# conform queue name to k8s standards
|
||||
safe_queue_name = queue_name.lower().strip()
|
||||
safe_queue_name = re.sub(r'\W+', '', safe_queue_name).replace('_', '').replace('-', '')
|
||||
|
||||
# Search for a free pod number
|
||||
pod_number = 1
|
||||
while self.ports_mode:
|
||||
kubectl_cmd_new = "kubectl get pods -l {pod_label},{agent_label} -n trains".format(
|
||||
pod_label=self.LIMIT_POD_LABEL.format(pod_number=pod_number),
|
||||
agent_label=self.AGENT_LABEL
|
||||
)
|
||||
process = subprocess.Popen(kubectl_cmd_new.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
output = '' if not output else output if isinstance(output, str) else output.decode('utf-8')
|
||||
error = '' if not error else error if isinstance(error, str) else error.decode('utf-8')
|
||||
|
||||
if not output:
|
||||
# No such pod exist so we can use the pod_number we found
|
||||
break
|
||||
if pod_number >= self.num_of_services:
|
||||
# All pod numbers are taken, exit
|
||||
self.log.warning(
|
||||
"kubectl last result: {}\n{}\nAll k8s services are in use, task '{}' "
|
||||
"will be enqueued back to queue '{}'".format(
|
||||
error, output, task_id, queue
|
||||
)
|
||||
)
|
||||
self._session.api_client.tasks.reset(task_id)
|
||||
self._session.api_client.tasks.enqueue(
|
||||
task_id, queue=queue, status_reason='k8s max pod limit (no free k8s service)')
|
||||
return
|
||||
pod_number += 1
|
||||
|
||||
labels = ([self.LIMIT_POD_LABEL.format(pod_number=pod_number)] if self.ports_mode else []) + [self.AGENT_LABEL]
|
||||
|
||||
if self.ports_mode:
|
||||
print("Kubernetes scheduling task id={} on pod={}".format(task_id, pod_number))
|
||||
else:
|
||||
kubectl_cmd = self.kubectl_cmd.format(task_id=task_id, docker_image=docker_image, queue_id=queue)
|
||||
print("Kubernetes scheduling task id={}".format(task_id))
|
||||
|
||||
# make sure we gave a list
|
||||
if self.template_dict:
|
||||
output, error = self._kubectl_apply(
|
||||
create_trains_conf=create_trains_conf,
|
||||
labels=labels, docker_image=docker_image, docker_args=docker_args,
|
||||
task_id=task_id, queue=queue, queue_name=safe_queue_name)
|
||||
else:
|
||||
output, error = self._kubectl_run(
|
||||
create_trains_conf=create_trains_conf,
|
||||
labels=labels, docker_image=docker_image,
|
||||
task_data=task_data,
|
||||
task_id=task_id, queue=queue, queue_name=safe_queue_name)
|
||||
|
||||
error = '' if not error else (error if isinstance(error, str) else error.decode('utf-8'))
|
||||
output = '' if not output else (output if isinstance(output, str) else output.decode('utf-8'))
|
||||
print('kubectl output:\n{}\n{}'.format(error, output))
|
||||
if error:
|
||||
self.log.error("Running kubectl encountered an error: {}".format(error))
|
||||
|
||||
user_props = {"k8s-queue": str(queue_name)}
|
||||
if self.ports_mode:
|
||||
user_props.update({"k8s-pod-number": pod_number, "k8s-pod-label": labels[0]})
|
||||
|
||||
if self._user_props_cb:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
custom_props = self._user_props_cb(pod_number) if self.ports_mode else self._user_props_cb()
|
||||
user_props.update(custom_props)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if user_props:
|
||||
self._set_task_user_properties(
|
||||
task_id=task_id,
|
||||
**user_props
|
||||
)
|
||||
|
||||
def _parse_docker_args(self, docker_args):
|
||||
# type: (list) -> dict
|
||||
kube_args = {'env': []}
|
||||
while docker_args:
|
||||
cmd = docker_args.pop().strip()
|
||||
if cmd in ('-e', '--env',):
|
||||
env = docker_args.pop().strip()
|
||||
key, value = env.split('=', 1)
|
||||
kube_args[key] += {key: value}
|
||||
else:
|
||||
self.log.warning('skipping docker argument {} (only -e --env supported)'.format(cmd))
|
||||
return kube_args
|
||||
|
||||
def _kubectl_apply(self, create_trains_conf, docker_image, docker_args, labels, queue, task_id, queue_name):
|
||||
template = deepcopy(self.template_dict)
|
||||
template.setdefault('apiVersion', 'v1')
|
||||
template['kind'] = 'Pod'
|
||||
template.setdefault('metadata', {})
|
||||
name = 'trains-{queue}-id-{task_id}'.format(queue=queue_name, task_id=task_id)
|
||||
template['metadata']['name'] = name
|
||||
template.setdefault('spec', {})
|
||||
template['spec'].setdefault('containers', [])
|
||||
if labels:
|
||||
labels_dict = dict(pair.split('=', 1) for pair in labels)
|
||||
template['metadata'].setdefault('labels', {})
|
||||
template['metadata']['labels'].update(labels_dict)
|
||||
container = self._parse_docker_args(docker_args)
|
||||
|
||||
container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
|
||||
else self.container_bash_script
|
||||
|
||||
script_encoded = '\n'.join(
|
||||
['#!/bin/bash', ] +
|
||||
[line.format(extra_bash_init_cmd=self.extra_bash_init_script or '', task_id=task_id)
|
||||
for line in container_bash_script])
|
||||
|
||||
create_init_script = \
|
||||
"echo '{}' | base64 --decode >> ~/__start_agent__.sh ; " \
|
||||
"/bin/bash ~/__start_agent__.sh".format(
|
||||
base64.b64encode(
|
||||
script_encoded.encode('ascii')
|
||||
).decode('ascii'))
|
||||
|
||||
container = merge_dicts(
|
||||
container,
|
||||
dict(name=name, image=docker_image,
|
||||
command=['/bin/bash'],
|
||||
args=['-c', '{} ; {}'.format(create_trains_conf, create_init_script)])
|
||||
)
|
||||
|
||||
if template['spec']['containers']:
|
||||
template['spec']['containers'][0] = merge_dicts(template['spec']['containers'][0], container)
|
||||
else:
|
||||
template['spec']['containers'].append(container)
|
||||
|
||||
fp, yaml_file = tempfile.mkstemp(prefix='trains_k8stmpl_', suffix='.yml')
|
||||
os.close(fp)
|
||||
with open(yaml_file, 'wt') as f:
|
||||
yaml.dump(template, f)
|
||||
|
||||
kubectl_cmd = self.KUBECTL_APPLY_CMD.format(
|
||||
task_id=task_id,
|
||||
docker_image=docker_image,
|
||||
queue_id=queue,
|
||||
)
|
||||
# make sure we provide a list
|
||||
if isinstance(kubectl_cmd, str):
|
||||
kubectl_cmd = kubectl_cmd.split()
|
||||
|
||||
kubectl_cmd += ["--labels=TRAINS=agent", "--command", "--", "/bin/sh", "-c",
|
||||
create_trains_conf + self.container_bash_script.format(task_id)]
|
||||
# add the template file at the end
|
||||
kubectl_cmd += [yaml_file]
|
||||
try:
|
||||
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
except Exception as ex:
|
||||
return None, str(ex)
|
||||
finally:
|
||||
safe_remove_file(yaml_file)
|
||||
|
||||
return output, error
|
||||
|
||||
def _kubectl_run(self, create_trains_conf, docker_image, labels, queue, task_data, task_id, queue_name):
|
||||
if callable(self.kubectl_cmd):
|
||||
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data, queue_name)
|
||||
else:
|
||||
kubectl_cmd = self.kubectl_cmd.format(
|
||||
queue_name=queue_name,
|
||||
task_id=task_id,
|
||||
docker_image=docker_image,
|
||||
queue_id=queue
|
||||
)
|
||||
# make sure we provide a list
|
||||
if isinstance(kubectl_cmd, str):
|
||||
kubectl_cmd = kubectl_cmd.split()
|
||||
|
||||
if self.overrides_json_string:
|
||||
kubectl_cmd += ['--overrides=' + self.overrides_json_string]
|
||||
|
||||
if self.pod_limits:
|
||||
kubectl_cmd += ['--limits', ",".join(self.pod_limits)]
|
||||
if self.pod_requests:
|
||||
kubectl_cmd += ['--requests', ",".join(self.pod_requests)]
|
||||
|
||||
container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
|
||||
else self.container_bash_script
|
||||
container_bash_script = ' ; '.join(container_bash_script)
|
||||
|
||||
kubectl_cmd += [
|
||||
"--labels=" + ",".join(labels),
|
||||
"--command",
|
||||
"--",
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"{} ; {}".format(create_trains_conf, container_bash_script.format(
|
||||
extra_bash_init_cmd=self.extra_bash_init_script, task_id=task_id)),
|
||||
]
|
||||
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
self.log.info("K8s scheduling experiment task id={}".format(task_id))
|
||||
if error:
|
||||
self.log.error("Running kubectl encountered an error: {}".format(
|
||||
error if isinstance(error, str) else error.decode()))
|
||||
return output, error
|
||||
|
||||
def run_tasks_loop(self, queues: List[Text], worker_params):
|
||||
def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):
|
||||
"""
|
||||
:summary: Pull and run tasks from queues.
|
||||
:description: 1. Go through ``queues`` by order.
|
||||
@@ -108,6 +449,7 @@ class K8sIntegration(Worker):
|
||||
events_service = self.get_service(Events)
|
||||
|
||||
# make sure we have a k8s pending queue
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._session.api_client.queues.create(self.k8s_pending_queue_name)
|
||||
except Exception:
|
||||
@@ -119,7 +461,7 @@ class K8sIntegration(Worker):
|
||||
while True:
|
||||
# iterate over queues (priority style, queues[0] is highest)
|
||||
for queue in queues:
|
||||
# delete old completed /failed pods
|
||||
# delete old completed / failed pods
|
||||
get_bash_output(self.KUBECTL_DELETE_CMD)
|
||||
|
||||
# get next task in queue
|
||||
@@ -155,15 +497,20 @@ class K8sIntegration(Worker):
|
||||
if self._session.config["agent.reload_config"]:
|
||||
self.reload_config()
|
||||
|
||||
def k8s_daemon(self, queues):
|
||||
def k8s_daemon(self, queue):
|
||||
"""
|
||||
Start the k8s Glue service.
|
||||
This service will be pulling tasks from *queues* and scheduling them for execution using kubectl.
|
||||
This service will be pulling tasks from *queue* and scheduling them for execution using kubectl.
|
||||
Notice all scheduled tasks are pushed back into K8S_PENDING_QUEUE,
|
||||
and popped when execution actually starts. This creates full visibility into the k8s scheduler.
|
||||
Manually popping a task from the K8S_PENDING_QUEUE,
|
||||
will cause the k8s scheduler to skip the execution once the scheduled tasks needs to be executed
|
||||
|
||||
:param list(str) queues: List of queue names to pull from
|
||||
:param list(str) queue: queue name to pull from
|
||||
"""
|
||||
return self.daemon(queues=queues, log_level=logging.INFO, foreground=True, docker=False)
|
||||
return self.daemon(queues=[ObjectID(name=queue)] if queue else None,
|
||||
log_level=logging.INFO, foreground=True, docker=False)
|
||||
|
||||
@classmethod
|
||||
def get_ssh_server_bash(cls, ssh_port_number):
|
||||
return ' ; '.join(line.format(port=ssh_port_number) for line in cls.BASH_INSTALL_SSH_CMD)
|
||||
|
||||
@@ -197,7 +197,8 @@ def safe_remove_tree(filename):
|
||||
pass
|
||||
|
||||
|
||||
def get_python_path(script_dir, entry_point, package_api):
|
||||
def get_python_path(script_dir, entry_point, package_api, is_conda_env=False):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
python_path_sep = ';' if is_windows_platform() else ':'
|
||||
python_path_cmd = package_api.get_python_command(
|
||||
@@ -209,9 +210,9 @@ def get_python_path(script_dir, entry_point, package_api):
|
||||
(Path(script_dir) / Path(entry_point)).parent.absolute().as_posix(),
|
||||
python_path_sep=python_path_sep)
|
||||
if is_windows_platform():
|
||||
return python_path.replace('/', '\\') + org_python_path
|
||||
python_path = python_path.replace('/', '\\')
|
||||
|
||||
return python_path + org_python_path
|
||||
return python_path if is_conda_env else (python_path + org_python_path)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -554,6 +555,7 @@ class ExecutionInfo(NonStrictAttrs):
|
||||
branch = nullable_string
|
||||
version_num = nullable_string
|
||||
tag = nullable_string
|
||||
docker_cmd = nullable_string
|
||||
|
||||
@classmethod
|
||||
def from_task(cls, task_info):
|
||||
@@ -571,6 +573,12 @@ class ExecutionInfo(NonStrictAttrs):
|
||||
execution.entry_point = entry_point
|
||||
execution.working_dir = working_dir or ""
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
execution.docker_cmd = task_info.execution.docker_cmd
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from attr import attrs, attrib
|
||||
|
||||
import six
|
||||
from six import binary_type, text_type
|
||||
from trains_agent.helper.base import nonstrict_in_place_sort, create_tree
|
||||
from trains_agent.helper.base import nonstrict_in_place_sort
|
||||
|
||||
|
||||
def print_text(text, newline=True):
|
||||
@@ -22,15 +22,21 @@ def print_text(text, newline=True):
|
||||
sys.stdout.write(data)
|
||||
|
||||
|
||||
def decode_binary_lines(binary_lines, encoding='utf-8'):
|
||||
def decode_binary_lines(binary_lines, encoding='utf-8', replace_cr=False, overwrite_cr=False):
|
||||
# decode per line, if we failed decoding skip the line
|
||||
lines = []
|
||||
for b in binary_lines:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
l = b.decode(encoding=encoding, errors='replace').replace('\r', '\n')
|
||||
except:
|
||||
l = ''
|
||||
lines.append(l + '\n' if l and l[-1] != '\n' else l)
|
||||
line = b.decode(encoding=encoding, errors='replace')
|
||||
if replace_cr:
|
||||
line = line.replace('\r', '\n')
|
||||
elif overwrite_cr:
|
||||
cr_lines = line.split('\r')
|
||||
line = cr_lines[-1] if cr_lines[-1] or len(cr_lines) < 2 else cr_lines[-2]
|
||||
except Exception:
|
||||
line = ''
|
||||
lines.append(line + '\n' if not line or line[-1] != '\n' else line)
|
||||
return lines
|
||||
|
||||
|
||||
|
||||
@@ -3,3 +3,15 @@ from typing import Callable, Dict, Any
|
||||
|
||||
def filter_keys(filter_, dct): # type: (Callable[[Any], bool], Dict) -> Dict
|
||||
return {key: value for key, value in dct.items() if filter_(key)}
|
||||
|
||||
|
||||
def merge_dicts(dict1, dict2):
|
||||
""" Recursively merges dict2 into dict1 """
|
||||
if not isinstance(dict1, dict) or not isinstance(dict2, dict):
|
||||
return dict2
|
||||
for k in dict2:
|
||||
if k in dict1:
|
||||
dict1[k] = merge_dicts(dict1[k], dict2[k])
|
||||
else:
|
||||
dict1[k] = dict2[k]
|
||||
return dict1
|
||||
|
||||
@@ -5,7 +5,7 @@ from contextlib import contextmanager
|
||||
from typing import Text, Iterable, Union
|
||||
|
||||
import six
|
||||
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines
|
||||
from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines, select_for_platform
|
||||
from trains_agent.helper.process import Executable, Argv, PathLike
|
||||
|
||||
|
||||
@@ -66,7 +66,8 @@ class PackageManager(object):
|
||||
pass
|
||||
|
||||
def upgrade_pip(self):
|
||||
result = self._install("pip"+self.get_pip_version(), "--upgrade")
|
||||
result = self._install(
|
||||
select_for_platform(windows='"pip{}"', linux='pip{}').format(self.get_pip_version()), "--upgrade")
|
||||
packages = self.run_with_env(('list',), output=True).splitlines()
|
||||
# p.split is ('pip', 'x.y.z')
|
||||
pip = [p.split() for p in packages if len(p.split()) == 2 and p.split()[0] == 'pip']
|
||||
|
||||
@@ -2,8 +2,9 @@ from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import os
|
||||
import subprocess
|
||||
from collections import OrderedDict
|
||||
from distutils.spawn import find_executable
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
@@ -18,13 +19,14 @@ from trains_agent.external.requirements_parser import parse
|
||||
from trains_agent.external.requirements_parser.requirement import Requirement
|
||||
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform
|
||||
from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform, ExecutionInfo
|
||||
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
|
||||
from trains_agent.helper.package.requirements import SimpleVersion
|
||||
from trains_agent.session import Session
|
||||
from .base import PackageManager
|
||||
from .pip_api.venv import VirtualenvPip
|
||||
from .requirements import RequirementsManager, MarkerRequirement
|
||||
from ...backend_api.session.defs import ENV_CONDA_ENV_PACKAGE
|
||||
|
||||
package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")
|
||||
|
||||
@@ -40,8 +42,8 @@ def _package_diff(path, packages):
|
||||
|
||||
class CondaPip(VirtualenvPip):
|
||||
def __init__(self, source=None, *args, **kwargs):
|
||||
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe") \
|
||||
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
|
||||
super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe")
|
||||
if is_windows_platform() and kwargs.get('path') else None, **kwargs)
|
||||
self.source = source
|
||||
|
||||
def run_with_env(self, command, output=False, **kwargs):
|
||||
@@ -61,8 +63,8 @@ class CondaAPI(PackageManager):
|
||||
|
||||
MINIMUM_VERSION = "4.3.30"
|
||||
|
||||
def __init__(self, session, path, python, requirements_manager):
|
||||
# type: (Session, PathLike, float, RequirementsManager) -> None
|
||||
def __init__(self, session, path, python, requirements_manager, execution_info=None, **kwargs):
|
||||
# type: (Session, PathLike, float, RequirementsManager, ExecutionInfo, Any) -> None
|
||||
"""
|
||||
:param python: base python version to use (e.g python3.6)
|
||||
:param path: path of env
|
||||
@@ -72,7 +74,15 @@ class CondaAPI(PackageManager):
|
||||
self.source = None
|
||||
self.requirements_manager = requirements_manager
|
||||
self.path = path
|
||||
self.env_read_only = False
|
||||
self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
|
||||
self.conda_env_as_base_docker = \
|
||||
self.session.config.get('agent.package_manager.conda_env_as_base_docker', None) or \
|
||||
bool(ENV_CONDA_ENV_PACKAGE.get())
|
||||
if ENV_CONDA_ENV_PACKAGE.get():
|
||||
self.conda_pre_build_env_path = ENV_CONDA_ENV_PACKAGE.get()
|
||||
else:
|
||||
self.conda_pre_build_env_path = execution_info.docker_cmd if execution_info else None
|
||||
self.pip = CondaPip(
|
||||
session=self.session,
|
||||
source=self.source,
|
||||
@@ -80,10 +90,15 @@ class CondaAPI(PackageManager):
|
||||
requirements_manager=self.requirements_manager,
|
||||
path=self.path,
|
||||
)
|
||||
self.conda = (
|
||||
find_executable("conda")
|
||||
or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
|
||||
)
|
||||
try:
|
||||
self.conda = (
|
||||
find_executable("conda") or
|
||||
Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(
|
||||
shell=select_for_platform(windows=True, linux=False)).strip()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("ERROR: package manager \"conda\" selected, "
|
||||
"but \'conda\' executable could not be located")
|
||||
try:
|
||||
output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
|
||||
except subprocess.CalledProcessError as ex:
|
||||
@@ -111,13 +126,58 @@ class CondaAPI(PackageManager):
|
||||
def bin(self):
|
||||
return self.pip.bin
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
def upgrade_pip(self):
|
||||
return self._install("pip" + self.pip.get_pip_version())
|
||||
# do not change pip version if pre built environement is used
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping pip upgrade.')
|
||||
return ''
|
||||
return self._install(select_for_platform(windows='"pip{}"', linux='pip{}').format(self.pip.get_pip_version()))
|
||||
|
||||
def create(self):
|
||||
"""
|
||||
Create a new environment
|
||||
"""
|
||||
if self.conda_env_as_base_docker and self.conda_pre_build_env_path:
|
||||
if Path(self.conda_pre_build_env_path).is_dir():
|
||||
print("Using pre-existing Conda environment from {}".format(self.conda_pre_build_env_path))
|
||||
self.path = Path(self.conda_pre_build_env_path)
|
||||
self.source = ("conda", "activate", self.path.as_posix())
|
||||
self.pip = CondaPip(
|
||||
session=self.session,
|
||||
source=self.source,
|
||||
python=self.python,
|
||||
requirements_manager=self.requirements_manager,
|
||||
path=self.path,
|
||||
)
|
||||
conda_env = self._get_conda_sh()
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
self.env_read_only = True
|
||||
return self
|
||||
elif Path(self.conda_pre_build_env_path).is_file():
|
||||
print("Restoring Conda environment from {}".format(self.conda_pre_build_env_path))
|
||||
tar_path = find_executable("tar")
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
output = Argv(
|
||||
tar_path,
|
||||
"-xzf",
|
||||
self.conda_pre_build_env_path,
|
||||
"-C",
|
||||
self.path,
|
||||
).get_output()
|
||||
|
||||
self.source = self.pip.source = ("conda", "activate", self.path.as_posix())
|
||||
conda_env = self._get_conda_sh()
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
# unpack cleanup
|
||||
print("Fixing prefix in Conda environment {}".format(self.path))
|
||||
CommandSequence(('source', conda_env.as_posix()),
|
||||
((self.path / 'bin' / 'conda-unpack').as_posix(), )).get_output()
|
||||
return self
|
||||
else:
|
||||
raise ValueError("Could not restore Conda environment, cannot find {}".format(
|
||||
self.conda_pre_build_env_path))
|
||||
|
||||
output = Argv(
|
||||
self.conda,
|
||||
"create",
|
||||
@@ -133,13 +193,15 @@ class CondaAPI(PackageManager):
|
||||
self.source = self.pip.source = (
|
||||
tuple(match.group(1).split()) + (match.group(2),)
|
||||
if match
|
||||
else ("activate", self.path)
|
||||
else ("conda", "activate", self.path.as_posix())
|
||||
)
|
||||
conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
|
||||
|
||||
conda_env = self._get_conda_sh()
|
||||
if conda_env.is_file() and not is_windows_platform():
|
||||
self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
|
||||
|
||||
# install cuda toolkit
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
|
||||
if cuda_version > 0:
|
||||
@@ -181,6 +243,10 @@ class CondaAPI(PackageManager):
|
||||
|
||||
def _install(self, *args):
|
||||
# type: (*PathLike) -> ()
|
||||
# if we are in read only mode, do not install anything
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping package installing: {}'.format(args))
|
||||
return
|
||||
channels_args = tuple(
|
||||
chain.from_iterable(("-c", channel) for channel in self.extra_channels)
|
||||
)
|
||||
@@ -208,6 +274,10 @@ class CondaAPI(PackageManager):
|
||||
return self._install(*packages)
|
||||
|
||||
def uninstall_packages(self, *packages):
|
||||
# if we are in read only mode, do not uninstall anything
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping package uninstalling: {}'.format(packages))
|
||||
return ''
|
||||
return self._run_command(("uninstall", "-p", self.path))
|
||||
|
||||
def install_from_file(self, path):
|
||||
@@ -226,23 +296,158 @@ class CondaAPI(PackageManager):
|
||||
with self.temp_file("pip_reqs", pip_packages) as reqs:
|
||||
self.pip.install_from_file(reqs)
|
||||
|
||||
def freeze(self):
|
||||
def freeze(self, freeze_full_environment=False):
|
||||
requirements = self.pip.freeze()
|
||||
req_lines = []
|
||||
conda_lines = []
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
|
||||
conda_packages_txt = []
|
||||
requirements_pip = [r.split('==')[0].strip().lower() for r in requirements['pip']]
|
||||
for pkg in conda_packages:
|
||||
# skip if this is a pypi package or it is not a python package at all
|
||||
if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip:
|
||||
pip_lines = requirements['pip']
|
||||
conda_packages_json = json.loads(
|
||||
self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
|
||||
for r in conda_packages_json:
|
||||
# check if this is a pypi package, if it is, leave it outside
|
||||
if not r.get('channel') or r.get('channel') == 'pypi':
|
||||
name = (r['name'].replace('-', '_'), r['name'])
|
||||
pip_req_line = [l for l in pip_lines
|
||||
if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
|
||||
if pip_req_line and \
|
||||
('@' not in pip_req_line[0] or
|
||||
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
|
||||
req_lines.append(pip_req_line[0])
|
||||
continue
|
||||
|
||||
req_lines.append(
|
||||
'{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
|
||||
continue
|
||||
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
|
||||
requirements['conda'] = conda_packages_txt
|
||||
except:
|
||||
|
||||
# check if we have it in our required packages
|
||||
name = r['name']
|
||||
# hack support pytorch/torch different naming convention
|
||||
if name == 'pytorch':
|
||||
name = 'torch'
|
||||
# skip over packages with _
|
||||
if name.startswith('_'):
|
||||
continue
|
||||
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
|
||||
# make sure we see the conda packages, put them into the pip as well
|
||||
if conda_lines:
|
||||
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
|
||||
|
||||
requirements['pip'] = req_lines
|
||||
requirements['conda'] = conda_lines
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if freeze_full_environment:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
conda_env_json = json.loads(
|
||||
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
|
||||
conda_env_json.pop('name', None)
|
||||
conda_env_json.pop('prefix', None)
|
||||
conda_env_json.pop('channels', None)
|
||||
requirements['conda_env_json'] = json.dumps(conda_env_json)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return requirements
|
||||
|
||||
def _load_conda_full_env(self, conda_env_dict, requirements):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cuda_version = int(self.session.config.get('agent.cuda_version', 0))
|
||||
except Exception:
|
||||
cuda_version = 0
|
||||
|
||||
conda_env_dict['channels'] = self.extra_channels
|
||||
if 'dependencies' not in conda_env_dict:
|
||||
conda_env_dict['dependencies'] = []
|
||||
new_dependencies = OrderedDict()
|
||||
pip_requirements = None
|
||||
for line in conda_env_dict['dependencies']:
|
||||
if isinstance(line, dict):
|
||||
pip_requirements = line.pop('pip', None)
|
||||
continue
|
||||
name = line.strip().split('=', 1)[0].lower()
|
||||
if name == 'pip':
|
||||
continue
|
||||
elif name == 'python':
|
||||
line = 'python={}'.format('.'.join(line.split('=')[1].split('.')[:2]))
|
||||
elif name == 'tensorflow-gpu' and cuda_version == 0:
|
||||
line = 'tensorflow={}'.format(line.split('=')[1])
|
||||
elif name == 'tensorflow' and cuda_version > 0:
|
||||
line = 'tensorflow-gpu={}'.format(line.split('=')[1])
|
||||
elif name in ('cupti', 'cudnn'):
|
||||
# cudatoolkit should pull them based on the cudatoolkit version
|
||||
continue
|
||||
elif name.startswith('_'):
|
||||
continue
|
||||
new_dependencies[line.split('=', 1)[0].strip()] = line
|
||||
|
||||
# fix packages:
|
||||
conda_env_dict['dependencies'] = list(new_dependencies.values())
|
||||
|
||||
with self.temp_file("conda_env", yaml.dump(conda_env_dict), suffix=".yml") as name:
|
||||
print('Conda: Trying to install requirements:\n{}'.format(conda_env_dict['dependencies']))
|
||||
result = self._run_command(
|
||||
("env", "update", "-p", self.path, "--file", name)
|
||||
)
|
||||
|
||||
# check if we need to remove specific packages
|
||||
bad_req = self._parse_conda_result_bad_packges(result)
|
||||
if bad_req:
|
||||
print('failed installing the following conda packages: {}'.format(bad_req))
|
||||
return False
|
||||
|
||||
if pip_requirements:
|
||||
# create a list of vcs packages that we need to replace in the pip section
|
||||
vcs_reqs = {}
|
||||
if 'pip' in requirements:
|
||||
pip_lines = requirements['pip'].splitlines() \
|
||||
if isinstance(requirements['pip'], six.string_types) else requirements['pip']
|
||||
for line in pip_lines:
|
||||
try:
|
||||
marker = list(parse(line))
|
||||
except Exception:
|
||||
marker = None
|
||||
if not marker:
|
||||
continue
|
||||
|
||||
m = MarkerRequirement(marker[0])
|
||||
if m.vcs:
|
||||
vcs_reqs[m.name] = m
|
||||
try:
|
||||
pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
|
||||
if not r.startswith('pip=') and not r.startswith('virtualenv=')]
|
||||
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
|
||||
PackageManager._selected_manager = self.pip
|
||||
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
finally:
|
||||
PackageManager._selected_manager = self
|
||||
|
||||
self.requirements_manager.post_install(self.session)
|
||||
|
||||
def load_requirements(self, requirements):
|
||||
# if we are in read only mode, do not uninstall anything
|
||||
if self.env_read_only:
|
||||
print('Conda environment in read-only mode, skipping requirements installation.')
|
||||
return None
|
||||
|
||||
# if we have a full conda environment, use it and pass the pip to pip
|
||||
if requirements.get('conda_env_json'):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
conda_env_json = json.loads(requirements.get('conda_env_json'))
|
||||
print('Conda restoring full yaml environment')
|
||||
return self._load_conda_full_env(conda_env_json, requirements)
|
||||
except Exception:
|
||||
print('Could not load fully stored conda environment, falling back to requirements')
|
||||
|
||||
# create new environment file
|
||||
conda_env = dict()
|
||||
conda_env['channels'] = self.extra_channels
|
||||
@@ -276,6 +481,15 @@ class CondaAPI(PackageManager):
|
||||
if m.vcs:
|
||||
pip_requirements.append(m)
|
||||
continue
|
||||
# Skip over pip
|
||||
if m.name in ('pip', 'virtualenv', ):
|
||||
continue
|
||||
# python version, only major.minor
|
||||
if m.name == 'python' and m.specs:
|
||||
m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
|
||||
if '.' not in m.specs[0][1]:
|
||||
continue
|
||||
|
||||
conda_supported_req_names.append(m.name.lower())
|
||||
if m.req.name.lower() == 'matplotlib':
|
||||
has_matplotlib = True
|
||||
@@ -303,15 +517,20 @@ class CondaAPI(PackageManager):
|
||||
continue
|
||||
|
||||
m = MarkerRequirement(marker[0])
|
||||
# skip over local files (we cannot change the version to a local file)
|
||||
if m.local_file:
|
||||
continue
|
||||
m_name = m.name.lower()
|
||||
if m_name in conda_supported_req_names:
|
||||
# this package is in the conda list,
|
||||
# make sure that if we changed version and we match it in conda
|
||||
conda_supported_req_names.remove(m_name)
|
||||
## conda_supported_req_names.remove(m_name)
|
||||
for cr in reqs:
|
||||
if m_name == cr.name.lower():
|
||||
if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
|
||||
# match versions
|
||||
cr.specs = m.specs
|
||||
# # conda always likes "-" not "_" but only on pypi packages
|
||||
# cr.name = cr.name.lower().replace('_', '-')
|
||||
break
|
||||
else:
|
||||
# not in conda, it is a pip package
|
||||
@@ -319,29 +538,39 @@ class CondaAPI(PackageManager):
|
||||
if m_name == 'matplotlib':
|
||||
has_matplotlib = True
|
||||
|
||||
# remove any leftover conda packages (they were removed from the pip list)
|
||||
if conda_supported_req_names:
|
||||
reqs = [r for r in reqs if r.name.lower() not in conda_supported_req_names]
|
||||
|
||||
# Conda requirements Hacks:
|
||||
if has_matplotlib:
|
||||
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
|
||||
reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
|
||||
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
|
||||
|
||||
# remove specific cudatoolkit, it should have being preinstalled.
|
||||
# allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them
|
||||
reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]
|
||||
|
||||
if has_torch and cuda_version == 0:
|
||||
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
|
||||
|
||||
# make sure we have no double entries
|
||||
reqs = list(OrderedDict((r.name, r) for r in reqs).values())
|
||||
|
||||
# conform conda packages (version/name)
|
||||
for r in reqs:
|
||||
# change _ to - in name but not the prefix _ (as this is conda prefix)
|
||||
if not r.name.startswith('_') and not requirements.get('conda', None):
|
||||
r.name = r.name.replace('_', '-')
|
||||
# remove .post from version numbers, it fails ~= version, and change == to ~=
|
||||
if r.specs and r.specs[0]:
|
||||
r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
|
||||
# conda always likes "-" not "_"
|
||||
r.req.name = r.req.name.replace('_', '-')
|
||||
|
||||
while reqs:
|
||||
# notice, we give conda more freedom in version selection, to help it choose best combination
|
||||
conda_env['dependencies'] = [r.tostr() for r in reqs]
|
||||
def clean_ver(ar):
|
||||
if not ar.specs:
|
||||
return ar.tostr()
|
||||
ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
|
||||
return ar.tostr()
|
||||
conda_env['dependencies'] = [clean_ver(r) for r in reqs]
|
||||
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
|
||||
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
|
||||
result = self._run_command(
|
||||
@@ -371,12 +600,15 @@ class CondaAPI(PackageManager):
|
||||
|
||||
if pip_requirements:
|
||||
try:
|
||||
pip_req_str = [r.tostr() for r in pip_requirements]
|
||||
pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
|
||||
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
|
||||
self.pip.load_requirements('\n'.join(pip_req_str))
|
||||
PackageManager._selected_manager = self.pip
|
||||
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
finally:
|
||||
PackageManager._selected_manager = self
|
||||
|
||||
self.requirements_manager.post_install(self.session)
|
||||
return True
|
||||
@@ -441,8 +673,22 @@ class CondaAPI(PackageManager):
|
||||
def get_python_command(self, extra=()):
|
||||
return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
|
||||
|
||||
def _get_conda_sh(self):
|
||||
# type () -> Path
|
||||
base_conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
|
||||
if base_conda_env.is_file():
|
||||
return base_conda_env
|
||||
for path in os.environ.get('PATH', '').split(select_for_platform(windows=';', linux=':')):
|
||||
conda = find_executable("conda", path=path)
|
||||
if not conda:
|
||||
continue
|
||||
conda_env = Path(conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
|
||||
if conda_env.is_file():
|
||||
return conda_env
|
||||
return base_conda_env
|
||||
|
||||
# enable hashing with cmp=False because pdb fails on unhashable exceptions
|
||||
|
||||
# enable hashing with cmp=False because pdb fails on un-hashable exceptions
|
||||
exception = attrs(str=True, cmp=False)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains_agent.helper.base import select_for_platform, rm_tree
|
||||
from trains_agent.helper.base import select_for_platform, rm_tree, ExecutionInfo
|
||||
from trains_agent.helper.package.base import PackageManager
|
||||
from trains_agent.helper.process import Argv, PathLike
|
||||
from trains_agent.session import Session
|
||||
@@ -9,8 +11,8 @@ from ..requirements import RequirementsManager
|
||||
|
||||
|
||||
class VirtualenvPip(SystemPip, PackageManager):
|
||||
def __init__(self, session, python, requirements_manager, path, interpreter=None):
|
||||
# type: (Session, float, RequirementsManager, PathLike, PathLike) -> ()
|
||||
def __init__(self, session, python, requirements_manager, path, interpreter=None, execution_info=None, **kwargs):
|
||||
# type: (Session, float, RequirementsManager, PathLike, PathLike, ExecutionInfo, Any) -> ()
|
||||
"""
|
||||
Program interface to virtualenv pip.
|
||||
Must be given either path to virtualenv or source command.
|
||||
|
||||
@@ -38,3 +38,38 @@ class PriorityPackageRequirement(SimpleSubstitution):
|
||||
return Text('')
|
||||
PackageManager.out_of_scope_install_package(str(req))
|
||||
return Text(req)
|
||||
|
||||
|
||||
class PackageCollectorRequirement(SimpleSubstitution):
|
||||
"""
|
||||
This RequirementSubstitution class will allow you to have multiple instances of the same
|
||||
package, it will output the last one (by order) to be actually used.
|
||||
"""
|
||||
name = tuple()
|
||||
|
||||
def __init__(self, session, collect_package):
|
||||
super(PackageCollectorRequirement, self).__init__(session)
|
||||
self._collect_packages = collect_package or tuple()
|
||||
self._last_req = None
|
||||
|
||||
def match(self, req):
|
||||
# match package names
|
||||
return req.name and req.name.lower() in self._collect_packages
|
||||
|
||||
def replace(self, req):
|
||||
"""
|
||||
Replace a requirement
|
||||
:raises: ValueError if version is pre-release
|
||||
"""
|
||||
self._last_req = req.clone()
|
||||
return ''
|
||||
|
||||
def post_scan_add_req(self):
|
||||
"""
|
||||
Allows the RequirementSubstitution to add an extra line/requirements after
|
||||
the initial requirements scan is completed.
|
||||
Called only once per requirements.txt object
|
||||
"""
|
||||
last_req = self._last_req
|
||||
self._last_req = None
|
||||
return last_req
|
||||
|
||||
@@ -82,6 +82,8 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
||||
92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
|
||||
100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
|
||||
101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
|
||||
102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
|
||||
110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -117,20 +119,24 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
||||
|
||||
@classmethod
|
||||
def get_torch_page(cls, cuda_version, nightly=False):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cuda = int(cuda_version)
|
||||
except:
|
||||
except Exception:
|
||||
cuda = 0
|
||||
|
||||
if nightly:
|
||||
# then try the nightly builds, it might be there...
|
||||
torch_url = cls.nightly_page_lookup_template.format(cuda)
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
cls.torch_page_lookup[cuda] = torch_url
|
||||
return cls.torch_page_lookup[cuda], cuda
|
||||
except Exception:
|
||||
pass
|
||||
for c in range(cuda, max(-1, cuda-15), -1):
|
||||
# then try the nightly builds, it might be there...
|
||||
torch_url = cls.nightly_page_lookup_template.format(c)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
print('Torch nightly CUDA {} download page found'.format(c))
|
||||
cls.torch_page_lookup[c] = torch_url
|
||||
return cls.torch_page_lookup[c], c
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# first check if key is valid
|
||||
@@ -138,13 +144,16 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
||||
return cls.torch_page_lookup[cuda], cuda
|
||||
|
||||
# then try a new cuda version page
|
||||
torch_url = cls.page_lookup_template.format(cuda)
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
cls.torch_page_lookup[cuda] = torch_url
|
||||
return cls.torch_page_lookup[cuda], cuda
|
||||
except Exception:
|
||||
pass
|
||||
for c in range(cuda, max(-1, cuda-15), -1):
|
||||
torch_url = cls.page_lookup_template.format(c)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if requests.get(torch_url, timeout=10).ok:
|
||||
print('Torch CUDA {} download page found'.format(c))
|
||||
cls.torch_page_lookup[c] = torch_url
|
||||
return cls.torch_page_lookup[c], c
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
|
||||
for k in keys:
|
||||
@@ -157,7 +166,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
|
||||
class PytorchRequirement(SimpleSubstitution):
|
||||
|
||||
name = "torch"
|
||||
packages = ("torch", "torchvision", "torchaudio")
|
||||
packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
os_name = kwargs.pop("os_override", None)
|
||||
@@ -235,6 +244,7 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
py_ver = self.python_major_minor_str.replace('.', '')
|
||||
url = None
|
||||
last_v = None
|
||||
closest_v = None
|
||||
# search for our package
|
||||
for l in links_parser.links:
|
||||
parts = l.split('/')[-1].split('-')
|
||||
@@ -244,28 +254,40 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
continue
|
||||
# version (ignore +cpu +cu92 etc. + is %2B in the file link)
|
||||
# version ignore .postX suffix (treat as regular version)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
v = str(parts[1].split('%')[0].split('+')[0])
|
||||
except Exception:
|
||||
continue
|
||||
if len(parts) < 3 or not parts[2].endswith(py_ver):
|
||||
continue
|
||||
if len(parts) < 5 or platform_wheel not in parts[4]:
|
||||
continue
|
||||
# update the closest matched version (from above)
|
||||
if not closest_v:
|
||||
closest_v = v
|
||||
elif SimpleVersion.compare_versions(
|
||||
version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
|
||||
SimpleVersion.compare_versions(
|
||||
version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
|
||||
closest_v = v
|
||||
# check if this an actual match
|
||||
if not req.compare_version(v) or \
|
||||
(last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
|
||||
continue
|
||||
if not parts[2].endswith(py_ver):
|
||||
continue
|
||||
if platform_wheel not in parts[4]:
|
||||
continue
|
||||
|
||||
url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
|
||||
last_v = v
|
||||
# if we found an exact match, use it
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if req.specs[0][0] == '==' and \
|
||||
SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
|
||||
break
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return url
|
||||
return url, last_v or closest_v
|
||||
|
||||
def get_url_for_platform(self, req):
|
||||
# check if package is already installed with system packages
|
||||
@@ -298,23 +320,28 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
# assert op == "=="
|
||||
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||
if not url and self.config.get("agent.package_manager.torch_nightly", None):
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||
# try one more time, with a lower cuda version (never fallback to CPU):
|
||||
while not url and torch_url_key > 0:
|
||||
previous_cuda_key = torch_url_key
|
||||
print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
|
||||
req, previous_cuda_key, closest_matched_version))
|
||||
url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
|
||||
if url:
|
||||
break
|
||||
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
|
||||
# never fallback to CPU
|
||||
if torch_url_key < 1:
|
||||
print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format(
|
||||
req, previous_cuda_key))
|
||||
raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
|
||||
req, self.cuda_version))
|
||||
print('Warning! Could not locate PyTorch version {} matching CUDA version {}, '
|
||||
'trying CUDA version {}'.format(req, previous_cuda_key, torch_url_key))
|
||||
url = self._get_link_from_torch_page(req, torch_url)
|
||||
print(
|
||||
'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
|
||||
req, previous_cuda_key))
|
||||
raise ValueError(
|
||||
'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
|
||||
else:
|
||||
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
|
||||
|
||||
if not url:
|
||||
url = PytorchWheel(
|
||||
@@ -326,6 +353,8 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
if url:
|
||||
# normalize url (sometimes we will get ../ which we should not...
|
||||
url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
|
||||
# print found
|
||||
print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))
|
||||
|
||||
self.log.debug("checking url: %s", url)
|
||||
return url, requests.head(url, timeout=10).ok
|
||||
|
||||
@@ -4,7 +4,7 @@ import operator
|
||||
import os
|
||||
import re
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy, copy
|
||||
from itertools import chain, starmap
|
||||
from operator import itemgetter
|
||||
from os import path
|
||||
@@ -81,6 +81,9 @@ class MarkerRequirement(object):
|
||||
|
||||
return ''.join(parts)
|
||||
|
||||
def clone(self):
|
||||
return MarkerRequirement(copy(self.req))
|
||||
|
||||
__str__ = tostr
|
||||
|
||||
def __repr__(self):
|
||||
@@ -335,6 +338,14 @@ class RequirementSubstitution(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def post_scan_add_req(self): # type: () -> Optional[MarkerRequirement]
|
||||
"""
|
||||
Allows the RequirementSubstitution to add an extra line/requirements after
|
||||
the initial requirements scan is completed.
|
||||
Called only once per requirements.txt object
|
||||
"""
|
||||
return None
|
||||
|
||||
def post_install(self, session):
|
||||
pass
|
||||
|
||||
@@ -489,6 +500,14 @@ class RequirementsManager(object):
|
||||
)
|
||||
if not conda:
|
||||
result = map(self.translator.translate, result)
|
||||
|
||||
result = list(result)
|
||||
# add post scan add requirements call back
|
||||
for h in self.handlers:
|
||||
req = h.post_scan_add_req()
|
||||
if req:
|
||||
result.append(req.tostr())
|
||||
|
||||
return join_lines(result)
|
||||
|
||||
def post_install(self, session):
|
||||
|
||||
@@ -295,11 +295,9 @@ class VCS(object):
|
||||
if not self.session.config.agent.translate_ssh:
|
||||
return
|
||||
|
||||
ssh_agent_variable = "SSH_AUTH_SOCK"
|
||||
if not getenv(ssh_agent_variable) and (
|
||||
(ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and
|
||||
(ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None))
|
||||
):
|
||||
# if we have git_user / git_pass replace ssh credentials with https authentication
|
||||
if (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and \
|
||||
(ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None)):
|
||||
# only apply to a specific domain (if requested)
|
||||
config_domain = \
|
||||
ENV_AGENT_GIT_HOST.get() or self.session.config.get("git_host", None)
|
||||
@@ -383,9 +381,10 @@ class VCS(object):
|
||||
"""
|
||||
Run command with `input_` as stdin
|
||||
"""
|
||||
input_ = input_.encode("latin1")
|
||||
if not input_.endswith(b"\n"):
|
||||
input_ += b"\n"
|
||||
input_ = input_.encode("utf-8")
|
||||
# always add extra empty line
|
||||
# (there is no downside, and it solves empty lines issue at end of patch cause corrupt message)
|
||||
input_ += b"\n"
|
||||
process = self._call_subprocess(
|
||||
subprocess.Popen,
|
||||
argv,
|
||||
@@ -535,7 +534,7 @@ class Git(VCS):
|
||||
root=Argv(executable_name, "rev-parse", "--show-toplevel"),
|
||||
)
|
||||
|
||||
patch_base = ("apply",)
|
||||
patch_base = ("apply", "--unidiff-zero", )
|
||||
|
||||
|
||||
class Hg(VCS):
|
||||
|
||||
@@ -4,6 +4,8 @@ from time import sleep
|
||||
from glob import glob
|
||||
from tempfile import gettempdir, NamedTemporaryFile
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from trains_agent.definitions import ENV_DOCKER_HOST_MOUNT
|
||||
from trains_agent.helper.base import warning
|
||||
|
||||
@@ -84,11 +86,12 @@ class Singleton(object):
|
||||
|
||||
@classmethod
|
||||
def get_running_pids(cls):
|
||||
# type: () -> List[Tuple[int, Optional[str], Optional[int], str]]
|
||||
temp_folder = cls._get_temp_folder()
|
||||
files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
|
||||
pids = []
|
||||
for file in files:
|
||||
parts = file.split(cls.sep)
|
||||
parts = os.path.basename(file).split(cls.sep)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.1'
|
||||
__version__ = '0.16.3'
|
||||
|
||||
Reference in New Issue
Block a user