Mirror of https://github.com/clearml/clearml-agent (synced 2025-06-26 18:16:15 +00:00)
Compare commits
254 Commits
Commits in this comparison (SHA1):

dd42423482, 69eb25db1f, a41ea52f87, 259113c989, 1afa3a3914, 448e23825c, b0c0f41f62, d2c5fb6512, b89cf4ec23, 74b646af9e,
0cf485f7a9, ea63e4f66e, 58eb5fbd5f, a8c543ef7b, 64e198a57a, de332b9e6b, 60eeff292d, 52f30b306a, 6df0f81ca0, 40b3c1502d,
a61265effe, 92efea6b76, 216b3e2179, 293a92f486, 6bad2b5352, a09a638b9c, 24f57270ed, 1b7964ce98, 5a510882b8, 601ed03198,
90fe4570b9, 92fc8e838f, 89a3020c5e, fc3e47b67e, b2a80ca314, 14655f19a0, 47092c47db, 8e6fce8d63, 3c514e3418, 8a425b100b,
eb942cfedd, 0a7fc06108, 0ae35afa76, a2156e73bf, 9fe77f3c28, 6f078afafd, 15f4aa613e, 7cd9fa6c41, 234d5fac2c, 6cbfb96ff8,
6e54e55c31, 3ff85b7b85, 5640489f57, 8135a6facf, b6ae4f211d, a56f032ec4, 075736de20, d8543c892e, ca0870b048, c7a739fafa,
7170296162, 3bed0ef33c, d419fa1e4f, 31a56c71bd, 28f47419b0, 6a24da2849, 782668fd21, aaf8d802e7, ca89a1e322, 121dec2a62,
4aacf9005e, 6b333202e9, ce6831368f, e4111c830b, 52c1772b04, 699d13bbb3, 2c8d7d3d9a, b13cc1e8e7, 17d2bf2a3e, 94997f9c88,
c6d998c4df, f8ea445339, 712efa208b, 09b6b6a9de, 98ff9a50e6, 1f4d358316, f693fa165c, c43084825c, f1abee91dd, c6b04edc34,
50b847f4f7, 1f53a06299, 257dd95401, 1736d205bb, 6fef58df6c, 473a8de8bb, ff6272f48f, 1b5bcebd10, c4344d3afd, 45a44b087a,
c58ffdb9f8, 54d9d77294, ce02385420, 87ffd95eaa, 522dd85d7b, 3651c85fcd, 566427d550, cc99077c92, 5f112447f7, 22c5f043aa,
860ff8911c, 799b292146, fffe8e1c3f, 8245293f7f, 6563ce70c8, 829b1d8f15, f6be64a4b5, 21f6a73f66, 77c4c79a2f, 2ad929fa00,
53f511f536, 7c87797a40, 272fa07c29, 6ce9cf7c2a, abb30ac2b8, 5bb257c46c, c65b28ed92, fce8eb6782, 9cb71b9526, 38e02ca5cd,
06bfea80bc, e660c7f2be, fc28467080, 8d47905982, a6a0b01f71, 2b561f6066, 61232d05dd, b3418e4496, 5ef627165c, 98a983d9a2,
482007c4ce, 98198b8006, 94bb11a81a, 4158d08f6f, 58ab67ea31, ea0ed4807e, 389600b91e, 5fb2550212, 15e9e6b778, aa75b92e46,
757210d5b3, 00eb2f10ec, 3393372b9c, f2d2d702de, e3d0680d39, 618c2ac5c4, 0272c4c79c, ff8cf63abf, 2c7c7f5b44, 01f57c1e44,
47bcd3839a, 0a3a8a1c52, 231a907cff, 8f95eecf2e, 81008ee00e, 25bc44c0cf, f838c8fc70, 596093aac6, 8f23f3b4c0, 95d503afdd,
73ee33be99, ee3adf625f, afec38a50e, f9c60904f4, a09dc85c67, 5d74f4b376, d558c66d3c, 714c6a05d0, 43b2f7f41d, 28d752d568,
6d091d8e08, 5c6b3ccc94, df10e6ed46, 8ef78fd058, 640c83288a, 788c79a66f, bef87c7744, f139891276, 2afaff1713, a57a5b151c,
97f446d523, a88262c097, 284271c654, ae2775f7b8, eb012f5c24, 06897f7606, 599219b02d, b6e04ab982, 98fe162878, f829d80a49,
b7e568e299, 6912846326, 224868c9a4, b1ca90a303, dee2475698, aeede81474, 2d91d4cde6, 7a11c7c165, a9f479cfcd, c1d91b0d6a,
cbfba6acb2, f2e2e1f94a, 23668a403a, facbee0005, c486cfd09f, 119ecaa2e3, d6cc2be653, 41d75df40c, 901c4be9ae, 966b14f914,
847d35cbbb, 4022cb5c63, 2b239829de, 402856656f, 7b94ff410c, 0a03dced50, ffe653afc6, 8ce621cc44, 7c0a2c4d50, 5e063c9195,
24329a21fe, 3a301b0b6c, 1f0bb4906b, 88f1031e5d, fc2842c9a2, e9d3aab115, 0ed7b2a0c8, bd73be928a, 79babdd149, 02a21ba826,
e7b6eb5c5a, 619ee3e8cf, 8b05bb1605, 72c26499c0, 741be2ae42, 831b36c424, d715ec4b14, a7873705ec, 4a8f52b5a5, a5c7ff4ee1,
c352c2711c, 2aea36c864, 40b15dfe64, 25e49df121
README.md (174 changed lines)
````diff
@@ -1,5 +1,5 @@
-# TRAINS Agent
-## Deep Learning DevOps For Everyone
+# Allegro Trains Agent
+## Deep Learning DevOps For Everyone - Now supporting all platforms (Linux, macOS, and Windows)
 
 "All the Deep-Learning DevOps your research needs, and then some... Because ain't nobody got time for that"
 
@@ -8,27 +8,29 @@
 [](https://img.shields.io/pypi/v/trains-agent.svg)
 [](https://pypi.python.org/pypi/trains-agent/)
 
-TRAINS Agent is an AI experiment cluster solution.
+### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
+
+**Trains Agent is an AI experiment cluster solution.**
 
 It is a zero configuration fire-and-forget execution agent, which combined with trains-server provides a full AI cluster solution.
 
 **Full AutoML in 5 steps**
-1. Install the [TRAINS server](https://github.com/allegroai/trains-agent) (or use our [open server](https://demoapp.trains.allegro.ai))
-2. `pip install trains_agent` ([install](#installing-the-trains-agent) the TRAINS agent on any GPU machine: on-premises / cloud / ...)
-3. Add [TRAINS](https://github.com/allegroai/trains) to your code with just 2 lines & run it once (on your machine / laptop)
+1. Install the [Trains Server](https://github.com/allegroai/trains-agent) (or use our [open server](https://demoapp.trains.allegro.ai))
+2. `pip install trains-agent` ([install](#installing-the-trains-agent) the Trains Agent on any GPU machine: on-premises / cloud / ...)
+3. Add [Trains](https://github.com/allegroai/trains) to your code with just 2 lines & run it once (on your machine / laptop)
 4. Change the [parameters](#using-the-trains-agent) in the UI & schedule for [execution](#using-the-trains-agent) (or automate with an [AutoML pipeline](#automl-and-orchestration-pipelines-))
 5. :chart_with_downwards_trend: :chart_with_upwards_trend: :eyes: :beer:
 
-**Using the TRAINS agent, you can now set up a dynamic cluster with \*epsilon DevOps**
+**Using the Trains Agent, you can now set up a dynamic cluster with \*epsilon DevOps**
 
 *epsilon - Because we are scientists :triangular_ruler: and nothing is really zero work
 
-(Experience TRAINS live at [https://demoapp.trains.allegro.ai](https://demoapp.trains.allegro.ai))
+(Experience Trains live at [https://demoapp.trains.allegro.ai](https://demoapp.trains.allegro.ai))
 <a href="https://demoapp.trains.allegro.ai"><img src="https://raw.githubusercontent.com/allegroai/trains-agent/9f1e86c1ca45c984ee13edc9353c7b10c55d7257/docs/screenshots.gif" width="100%"></a>
 
 ## Simple, Flexible Experiment Orchestration
-**The TRAINS Agent was built to address the DL/ML R&D DevOps needs:**
+**The Trains Agent was built to address the DL/ML R&D DevOps needs:**
 
 * Easily add & remove machines from the cluster
 * Reuse machines without the need for any dedicated containers or images
@@ -49,30 +51,30 @@ If you are considering K8S for your research, also consider that you will soon b
 In our experience, handling and building the environments, having to package every experiment in a docker, managing those hundreds (or more) containers and building pipelines on top of it all, is very complicated (also, it’s usually out of scope for the research team, and overwhelming even for the DevOps team).
 
 We feel there has to be a better way, that can be just as powerful for R&D and at the same time allow integration with K8S **when the need arises**.
-(If you already have a K8S cluster for AI, detailed instructions on how to integrate TRAINS into your K8S cluster are *coming soon*.)
+(If you already have a K8S cluster for AI, detailed instructions on how to integrate Trains into your K8S cluster are [here](https://github.com/allegroai/trains-server-k8s/tree/master/trains-server-chart) with included [helm chart](https://github.com/allegroai/trains-server-helm))
 
 
-## Using the TRAINS Agent
+## Using the Trains Agent
 **Full scale HPC with a click of a button**
 
-TRAINS Agent is a job scheduler that listens on job queue(s), pulls jobs, sets the job environments, executes the job and monitors its progress.
+The Trains Agent is a job scheduler that listens on job queue(s), pulls jobs, sets the job environments, executes the job and monitors its progress.
 
-Any 'Draft' experiment can be scheduled for execution by a TRAINS agent.
+Any 'Draft' experiment can be scheduled for execution by a Trains agent.
 
 A previously run experiment can be put into 'Draft' state by either of two methods:
 * Using the **'Reset'** action from the experiment right-click context menu in the
-TRAINS UI - This will clear any results and artifacts the previous run had created.
+Trains UI - This will clear any results and artifacts the previous run had created.
 * Using the **'Clone'** action from the experiment right-click context menu in the
-TRAINS UI - This will create a new 'Draft' experiment with the same configuration as the original experiment.
+Trains UI - This will create a new 'Draft' experiment with the same configuration as the original experiment.
 
 An experiment is scheduled for execution using the **'Enqueue'** action from the experiment
-right-click context menu in the TRAINS UI and selecting the execution queue.
+right-click context menu in the Trains UI and selecting the execution queue.
 
 See [creating an experiment and enqueuing it for execution](#from-scratch).
 
-Once an experiment is enqueued, it will be picked up and executed by a TRAINS agent monitoring this queue.
+Once an experiment is enqueued, it will be picked up and executed by a Trains agent monitoring this queue.
 
-The TRAINS UI Workers & Queues page provides ongoing execution information:
+The Trains UI Workers & Queues page provides ongoing execution information:
 - Workers Tab: Monitor you cluster
   - Review available resources
   - Monitor machines statistics (CPU / GPU / Disk / Network)
@@ -81,16 +83,16 @@ The TRAINS UI Workers & Queues page provides ongoing execution information:
   - Cancel or abort job execution
   - Move jobs between execution queues
 
-### What The TRAINS Agent Actually Does
-The TRAINS agent executes experiments using the following process:
+### What The Trains Agent Actually Does
+The Trains Agent executes experiments using the following process:
 - Create a new virtual environment (or launch the selected docker image)
 - Clone the code into the virtual-environment (or inside the docker)
 - Install python packages based on the package requirements listed for the experiment
-  - Special note for PyTorch: The TRAINS agent will automatically select the
+  - Special note for PyTorch: The Trains Agent will automatically select the
     torch packages based on the CUDA_VERSION environment variable of the machine
 - Execute the code, while monitoring the process
-- Log all stdout/stderr in the TRAINS UI, including the cloning and installation process, for easy debugging
-- Monitor the execution and allow you to manually abort the job using the TRAINS UI (or, in the unfortunate case of a code crash, catch the error and signal the experiment has failed)
+- Log all stdout/stderr in the Trains UI, including the cloning and installation process, for easy debugging
+- Monitor the execution and allow you to manually abort the job using the Trains UI (or, in the unfortunate case of a code crash, catch the error and signal the experiment has failed)
 
 ### System Design & Flow
 ```text
@@ -98,24 +100,24 @@ The TRAINS agent executes experiments using the following process:
 | GPU Machine |
 Development Machine | |
 +------------------------+ | +-------------+ |
-| Data Scientist's | +--------------+ | |TRAINS Agent | |
+| Data Scientist's | +--------------+ | |Trains Agent | |
 | DL/ML Code | | WEB UI | | | | |
 | | | | | | +---------+ | |
 | | | | | | | DL/ML | | |
 | | +--------------+ | | | Code | | |
 | | User Clones Exp #1 / . . . . . . . / | | | | | |
 | +-------------------+ | into Exp #2 / . . . . . . . / | | +---------+ | |
-| | TRAINS | | +---------------/-_____________-/ | | | |
+| | Trains | | +---------------/-_____________-/ | | | |
 | +---------+---------+ | | | | ^ | |
 +-----------|------------+ | | +------|------+ |
 | | +--------|--------+
 Auto-Magically | |
-Creates Exp #1 | The TRAINS Agent
+Creates Exp #1 | The Trains Agent
 \ User Change Hyper-Parameters Pulls Exp #2, setup the
 | | environment & clone code.
 | | Start execution with the
 +------------|------------+ | +--------------------+ new set of Hyper-Parameters.
-| +---------v---------+ | | | TRAINS-SERVER | |
+| +---------v---------+ | | | Trains Server | |
 | | Experiment #1 | | | | | |
 | +-------------------+ | | | Execution Queue | |
 | || | | | | |
@@ -126,17 +128,17 @@ Development Machine |
 | | ------------->---------------+ | |
 | | User Send Exp #2 | |Execute Exp #2 +--------------------+
 | | For Execution | +---------------+ |
-| TRAINS-SERVER | | |
+| Trains Server | | |
 +-------------------------+ +--------------------+
 ```
 
-### Installing the TRAINS Agent
+### Installing the Trains Agent
 
 ```bash
-pip install trains_agent
+pip install trains-agent
 ```
 
-### TRAINS Agent Usage Examples
+### Trains Agent Usage Examples
 
 Full Interface and capabilities are available with
 ```bash
@@ -144,73 +146,76 @@ trains-agent --help
 trains-agent daemon --help
 ```
 
-### Configuring the TRAINS Agent
+### Configuring the Trains Agent
 
 ```bash
trains-agent init
 ```
 
-Note: The TRAINS agent uses a cache folder to cache pip packages, apt packages and cloned repositories. The default TRAINS Agent cache folder is `~/.trains`
+Note: The Trains Agent uses a cache folder to cache pip packages, apt packages and cloned repositories. The default Trains Agent cache folder is `~/.trains`
 
 See full details in your configuration file at `~/trains.conf`
 
-Note: The **TRAINS agent** extends the **TRAINS** configuration file `~/trains.conf`
+Note: The **Trains agent** extends the **Trains** configuration file `~/trains.conf`
 They are designed to share the same configuration file, see example [here](docs/trains.conf)
 
-### Running the TRAINS Agent
+### Running the Trains Agent
 
-For debug and experimentation, start the TRAINS agent in `foreground` mode, where all the output is printed to screen
+For debug and experimentation, start the Trains agent in `foreground` mode, where all the output is printed to screen
 ```bash
 trains-agent daemon --queue default --foreground
 ```
 
 For actual service mode, all the stdout will be stored automatically into a temporary file (no need to pipe)
+Notice: with `--detached` flag, the *trains-agent* will be running in the background
 ```bash
-trains-agent daemon --queue default
+trains-agent daemon --detached --queue default
 ```
 
-GPU allocation is controlled via the standard OS environment NVIDIA_VISIBLE_DEVICES.
+GPU allocation is controlled via the standard OS environment `NVIDIA_VISIBLE_DEVICES` or `--gpus` flag (or disabled with `--cpu-only`).
 
-If NVIDIA_VISIBLE_DEVICES variable doesn't exist, all GPU's will be allocated for the `trains-agent` <br>
-If NVIDIA_VISIBLE_DEVICES is an empty string ("") No gpu will be allocated for the `trains-agent`
+If no flag is set, and `NVIDIA_VISIBLE_DEVICES` variable doesn't exist, all GPU's will be allocated for the `trains-agent` <br>
+If `--cpu-only` flag is set, or `NVIDIA_VISIBLE_DEVICES` is an empty string (""), no gpu will be allocated for the `trains-agent`
 
 Example: spin two agents, one per gpu on the same machine:
+Notice: with `--detached` flag, the *trains-agent* will be running in the background
 ```bash
-NVIDIA_VISIBLE_DEVICES=0 trains-agent daemon --queue default &
-NVIDIA_VISIBLE_DEVICES=1 trains-agent daemon --queue default &
+trains-agent daemon --detached --gpus 0 --queue default
+trains-agent daemon --detached --gpus 1 --queue default
 ```
 
-Example: spin two agents, with two gpu's per agent:
+Example: spin two agents, pulling from dedicated `dual_gpu` queue, two gpu's per agent
 ```bash
-NVIDIA_VISIBLE_DEVICES=0,1 trains-agent daemon --queue default &
-NVIDIA_VISIBLE_DEVICES=2,3 trains-agent daemon --queue default &
+trains-agent daemon --detached --gpus 0,1 --queue dual_gpu
+trains-agent daemon --detached --gpus 2,3 --queue dual_gpu
 ```
 
-#### Starting the TRAINS Agent in docker mode
+#### Starting the Trains Agent in docker mode
 
-For debug and experimentation, start the TRAINS agent in `foreground` mode, where all the output is printed to screen
+For debug and experimentation, start the Trains agent in `foreground` mode, where all the output is printed to screen
 ```bash
 trains-agent daemon --queue default --docker --foreground
 ```
 
 For actual service mode, all the stdout will be stored automatically into a file (no need to pipe)
+Notice: with `--detached` flag, the *trains-agent* will be running in the background
 ```bash
-trains-agent daemon --queue default --docker
+trains-agent daemon --detached --queue default --docker
 ```
 
-Example: spin two agents, one per gpu on the same machine:
+Example: spin two agents, one per gpu on the same machine, with default nvidia/cuda docker:
 ```bash
-NVIDIA_VISIBLE_DEVICES=0 trains-agent daemon --queue default --docker &
-NVIDIA_VISIBLE_DEVICES=1 trains-agent daemon --queue default --docker &
+trains-agent daemon --detached --gpus 0 --queue default --docker nvidia/cuda
+trains-agent daemon --detached --gpus 1 --queue default --docker nvidia/cuda
 ```
 
-Example: spin two agents, with two gpu's per agent:
+Example: spin two agents, pulling from dedicated `dual_gpu` queue, two gpu's per agent, with default nvidia/cuda docker:
 ```bash
-NVIDIA_VISIBLE_DEVICES=0,1 trains-agent daemon --queue default --docker &
-NVIDIA_VISIBLE_DEVICES=2,3 trains-agent daemon --queue default --docker &
+trains-agent daemon --detached --gpus 0,1 --queue dual_gpu --docker nvidia/cuda
+trains-agent daemon --detached --gpus 2,3 --queue dual_gpu --docker nvidia/cuda
 ```
 
-#### Starting the TRAINS Agent - Priority Queues
+#### Starting the Trains Agent - Priority Queues
 
 Priority Queues are also supported, example use case:
 
@@ -218,12 +223,22 @@ High priority queue: `important_jobs` Low priority queue: `default`
 ```bash
 trains-agent daemon --queue important_jobs default
 ```
-The **TRAINS agent** will first try to pull jobs from the `important_jobs` queue, only then it will fetch a job from the `default` queue.
+The **Trains Agent** will first try to pull jobs from the `important_jobs` queue, only then it will fetch a job from the `default` queue.
 
-# How do I create an experiment on the TRAINS server? <a name="from-scratch"></a>
-* Integrate [TRAINS](https://github.com/allegroai/trains) with your code
+Adding queues, managing job order within a queue and moving jobs between queues, is available using the Web UI, see example on our [open server](https://demoapp.trains.allegro.ai/workers-and-queues/queues)
+
+#### Stopping the Trains Agent
+
+To stop a **Trains Agent** running in the background, run the same command line used to start the agent with `--stop` appended.
+For example, to stop the first of the above shown same machine, single gpu agents:
+```bash
+trains-agent daemon --detached --gpus 0 --queue default --docker nvidia/cuda --stop
+```
+
+## How do I create an experiment on the Trains Server? <a name="from-scratch"></a>
+* Integrate [Trains](https://github.com/allegroai/trains) with your code
 * Execute the code on your machine (Manually / PyCharm / Jupyter Notebook)
-* As your code is running, **TRAINS** creates an experiment logging all the necessary execution information:
+* As your code is running, **Trains** creates an experiment logging all the necessary execution information:
   - Git repository link and commit ID (or an entire jupyter notebook)
   - Git diff (we’re not saying you never commit and push, but still...)
   - Python packages used by your code (including specific versions used)
@@ -232,7 +247,7 @@ The **TRAINS agent** will first try to pull jobs from the `important_jobs` queue
 
 You now have a 'template' of your experiment with everything required for automated execution
 
-* In the TRAINS UI, Right click on the experiment and select 'clone'. A copy of your experiment will be created.
+* In the Trains UI, Right click on the experiment and select 'clone'. A copy of your experiment will be created.
 * You now have a new draft experiment cloned from your original experiment, feel free to edit it
   - Change the Hyper-Parameters
   - Switch to the latest code base of the repository
@@ -241,19 +256,44 @@ The **TRAINS agent** will first try to pull jobs from the `important_jobs` queue
   - Or simply change nothing to run the same experiment again...
 * Schedule the newly created experiment for execution: Right-click the experiment and select 'enqueue'
 
-# AutoML and Orchestration Pipelines <a name="automl-pipes"></a>
-The TRAINS Agent can also be used to implement AutoML orchestration and Experiment Pipelines in conjunction with the TRAINS package.
+## Trains-Agent Services Mode <a name="services"></a>
 
-Sample AutoML & Orchestration examples can be found in the TRAINS [example/automl](https://github.com/allegroai/trains/tree/master/examples/automl) folder.
+Trains-Agent Services is a special mode of Trains-Agent that provides the ability to launch long-lasting jobs
+that previously had to be executed on local / dedicated machines. It allows a single agent to
+launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
+when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
+Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
+increased data transparency)
+
+Trains-Agent Services mode will spin **any** task enqueued into the specified queue.
+Every task launched by Trains-Agent Services will be registered as a new node in the system,
+providing tracking and transparency capabilities.
+Currently trains-agent in services-mode supports cpu only configuration. Trains-agent services mode can be launched alongside GPU agents.
+
+```bash
+trains-agent daemon --services-mode --detached --queue services --create-queue --docker ubuntu:18.04 --cpu-only
+```
+
+**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the specified queue.
+
+
+## AutoML and Orchestration Pipelines <a name="automl-pipes"></a>
+The Trains Agent can also be used to implement AutoML orchestration and Experiment Pipelines in conjunction with the Trains package.
+
+Sample AutoML & Orchestration examples can be found in the Trains [example/automation](https://github.com/allegroai/trains/tree/master/examples/automation) folder.
 
 AutoML examples
-- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/automl/automl_base_template_keras_simple.py)
+- [Toy Keras training experiment](https://github.com/allegroai/trains/blob/master/examples/optimization/hyper-parameter-optimization/base_template_keras_simple.py)
   - In order to create an experiment-template in the system, this code must be executed once manually
-- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automl/automl_random_search_example.py)
+- [Random Search over the above Keras experiment-template](https://github.com/allegroai/trains/blob/master/examples/automation/manual_random_param_search_example.py)
   - This example will create multiple copies of the Keras experiment-template, with different hyper-parameter combinations
 
 Experiment Pipeline examples
-- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/task_piping_example.py)
+- [First step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/task_piping_example.py)
   - This example will "process data", and once done, will launch a copy of the 'second step' experiment-template
-- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automl/toy_base_task.py)
+- [Second step experiment](https://github.com/allegroai/trains/blob/master/examples/automation/toy_base_task.py)
   - In order to create an experiment-template in the system, this code must be executed once manually
 
 ## License
 
 Apache License, Version 2.0 (see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0.html) for more information)
````
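The README's queue-management note above points to the Web UI; the same queue state is also reachable programmatically through the `APIClient` shipped with trains-agent, which the examples later in this page use. A minimal sketch (assuming a server already configured, e.g. via `trains-agent init`; the queue calls mirror the ones used in the bundled notebook example):

```python
# Minimal sketch: list every queue and its pending tasks via the trains-agent
# APIClient. Assumes credentials/hosts are already configured (e.g. `trains-agent init`).
from trains_agent import APIClient

client = APIClient()

for queue in client.queues.get_all():
    # queues.get_by_id(...).entries holds the tasks currently waiting in the queue
    entries = client.queues.get_by_id(queue.id).entries or []
    print("queue '{}': {} pending task(s)".format(queue.name, len(entries)))
```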
docker/agent/Dockerfile (new file, 18 lines)
```dockerfile
# syntax = docker/dockerfile
FROM nvidia/cuda

WORKDIR /usr/agent

COPY . /usr/agent

RUN apt-get update
RUN apt-get dist-upgrade -y
RUN apt-get install -y curl python3-pip git
RUN curl -sSL https://get.docker.com/ | sh
RUN python3 -m pip install -U pip
RUN python3 -m pip install trains-agent
RUN python3 -m pip install -U "cryptography>=2.9"

ENV TRAINS_DOCKER_SKIP_GPUS_FLAG=1

ENTRYPOINT ["/usr/agent/entrypoint.sh"]
```
docker/agent/entrypoint.sh (new executable file, 19 lines)
```sh
#!/bin/sh

LOWER_PIP_UPDATE_VERSION="$(echo "$PIP_UPDATE_VERSION" | tr '[:upper:]' '[:lower:]')"
LOWER_TRAINS_AGENT_UPDATE_VERSION="$(echo "$TRAINS_AGENT_UPDATE_VERSION" | tr '[:upper:]' '[:lower:]')"

if [ "$LOWER_PIP_UPDATE_VERSION" = "yes" ] || [ "$LOWER_PIP_UPDATE_VERSION" = "true" ] ; then
    python3 -m pip install -U pip
elif [ ! -z "$LOWER_PIP_UPDATE_VERSION" ] ; then
    python3 -m pip install pip$LOWER_PIP_UPDATE_VERSION ;
fi

echo "TRAINS_AGENT_UPDATE_VERSION = $LOWER_TRAINS_AGENT_UPDATE_VERSION"
if [ "$LOWER_TRAINS_AGENT_UPDATE_VERSION" = "yes" ] || [ "$LOWER_TRAINS_AGENT_UPDATE_VERSION" = "true" ] ; then
    python3 -m pip install trains-agent -U
elif [ ! -z "$LOWER_TRAINS_AGENT_UPDATE_VERSION" ] ; then
    python3 -m pip install trains-agent$LOWER_TRAINS_AGENT_UPDATE_VERSION ;
fi

python3 -m trains_agent daemon --docker "$TRAINS_AGENT_DEFAULT_BASE_DOCKER" --force-current-version $TRAINS_AGENT_EXTRA_ARGS
```
docker/services/Dockerfile (new file, 25 lines)
```dockerfile
# syntax = docker/dockerfile
FROM ubuntu:18.04

WORKDIR /usr/agent

COPY . /usr/agent

ENV LC_ALL=en_US.UTF-8
ENV LANG=en_US.UTF-8
ENV LANGUAGE=en_US.UTF-8
ENV PYTHONIOENCODING=UTF-8

RUN apt-get update
RUN apt-get dist-upgrade -y
RUN apt-get install -y locales

RUN locale-gen en_US.UTF-8

RUN apt-get install -y curl python3-pip git
RUN curl -sSL https://get.docker.com/ | sh
RUN python3 -m pip install -U pip
RUN python3 -m pip install trains-agent
RUN python3 -m pip install -U "cryptography>=2.9"

ENTRYPOINT ["/usr/agent/entrypoint.sh"]
```
docker/services/entrypoint.sh (new executable file, 14 lines)
```sh
#!/bin/sh

if [ -z "$TRAINS_FILES_HOST" ]; then
    TRAINS_HOST_IP=${TRAINS_HOST_IP:-$(curl -s https://ifconfig.me/ip)}
fi

TRAINS_FILES_HOST=${TRAINS_FILES_HOST:-"http://$TRAINS_HOST_IP:8081"}
TRAINS_WEB_HOST=${TRAINS_WEB_HOST:-"http://$TRAINS_HOST_IP:8080"}
TRAINS_API_HOST=${TRAINS_API_HOST:-"http://$TRAINS_HOST_IP:8008"}

echo $TRAINS_FILES_HOST $TRAINS_WEB_HOST $TRAINS_API_HOST 1>&2

python3 -m pip install -q -U "trains-agent${TRAINS_AGENT_UPDATE_VERSION}"
trains-agent daemon --services-mode --queue services --create-queue --docker "$TRAINS_AGENT_DEFAULT_BASE_DOCKER" --cpu-only $TRAINS_AGENT_EXTRA_ARGS
```
docs/trains.conf

````diff
@@ -13,11 +13,18 @@ api {
 }
 
 agent {
-    # Set GIT user/pass credentials
-    # leave blank for GIT SSH credentials
+    # Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https)
+    # leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
     git_user=""
     git_pass=""
+    # Limit credentials to a single domain, for example: github.com,
+    # all other domains will use public access (no user/pass). Default: always send user/pass for any VCS domain
+    git_host=""
+
+    # Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
+    force_git_ssh_protocol: false
+    # Force a specific SSH port when converting http to ssh links (the domain is kept the same)
+    # force_git_ssh_port: ""
 
     # unique name of this worker, if None, created based on hostname:process_id
     # Overridden with os environment: TRAINS_WORKER_NAME
@@ -29,12 +36,21 @@ agent {
     # worker_name: "trains-agent-machine1"
     worker_name: ""
 
+    # Set the python version to use when creating the virtual environment and launching the experiment
+    # Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6"
+    # The default is the python executing the trains_agent
+    python_binary: ""
+
     # select python package manager:
     # currently supported pip and conda
+    # poetry is used if pip selected and repository contains poetry.lock file
     package_manager: {
-        # supported options: pip, conda
+        # supported options: pip, conda, poetry
         type: pip,
 
+        # specify pip version to use (examples "<20", "==19.3.1", "", empty string will install the latest version)
+        # pip_version: "<20"
+
         # virtual environment inheres packages from system
         system_site_packages: false,
         # install with --upgrade
@@ -46,6 +62,28 @@ agent {
 
         # additional conda channels to use when installing with conda package manager
         conda_channels: ["pytorch", "conda-forge", ]
+        # conda_full_env_update: false
+        # conda_env_as_base_docker: false
+
+        # set the priority packages to be installed before the rest of the required packages
+        # priority_packages: ["cython", "numpy", "setuptools", ]
+
+        # set the optional priority packages to be installed before the rest of the required packages,
+        # In case a package installation fails, the package will be ignored,
+        # and the virtual environment process will continue
+        # priority_optional_packages: ["pygobject", ]
+
+        # set the post packages to be installed after all the rest of the required packages
+        # post_packages: ["horovod", ]
+
+        # set the optional post packages to be installed after all the rest of the required packages,
+        # In case a package installation fails, the package will be ignored,
+        # and the virtual environment process will continue
+        # post_optional_packages: []
+
+        # set to True to support torch nightly build installation,
+        # notice: torch nightly builds are ephemeral and are deleted from time to time
+        torch_nightly: false,
     },
 
     # target folder for virtual environments builds, created when executing experiment
@@ -73,18 +111,34 @@ agent {
     # reload configuration file every daemon execution
     reload_config: false,
 
-    # pip cache folder used mapped into docker, for python package caching
+    # pip cache folder mapped into docker, used for python package caching
     docker_pip_cache = ~/.trains/pip-cache
-    # apt cache folder used mapped into docker, for ubuntu package caching
+    # apt cache folder mapped into docker, used for ubuntu package caching
     docker_apt_cache = ~/.trains/apt-cache
 
+    # optional arguments to pass to docker image
+    # these are local for this agent and will not be updated in the experiment's docker_cmd section
+    # extra_docker_arguments: ["--ipc=host", ]
+
+    # optional shell script to run in docker when started before the experiment is started
+    # extra_docker_shell_script: ["apt-get install -y bindfs", ]
+
+    # set to true in order to force "docker pull" before running an experiment using a docker image.
+    # This makes sure the docker image is updated.
+    docker_force_pull: false
+
     default_docker: {
         # default docker image to use when running in docker mode
-        image: "nvidia/cuda"
+        image: "nvidia/cuda:10.1-runtime-ubuntu18.04"
 
         # optional arguments to pass to docker image
         # arguments: ["--ipc=host"]
     }
 
+    # CUDA versions used for Conda setup & solving PyTorch wheel packages
+    # it Should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
+    # cuda_version: 10.1
+    # cudnn_version: 7.6
 }
 
 sdk {
@@ -110,12 +164,23 @@ sdk {
     # X files are stored in the upload destination for each metric/variant combination.
     file_history_size: 100
 
+    # Max history size for matplotlib imshow files per plot title.
+    # File names for the uploaded images will be recycled in such a way that no more than
+    # X images are stored in the upload destination for each matplotlib plot title.
+    matplotlib_untitled_history_size: 100
+
+    # Limit the number of digits after the dot in plot reporting (reducing plot report size)
+    # plot_max_num_digits: 5
+
     # Settings for generated debug images
     images {
         format: JPEG
         quality: 87
         subsampling: 0
     }
 
+    # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to True, each series should have its own graph)
+    tensorboard_single_series_per_graph: False
 }
 
 network {
@@ -219,6 +284,9 @@ sdk {
     # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
     support_stopping: True
 
+    # Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
+    default_output_uri: ""
+
     # Development mode worker
     worker {
         # Status report period in seconds
````
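Several of the settings above can be overridden from the OS environment, as the file's comments note (`TRAINS_WORKER_NAME` for `worker_name`, `CUDA_VERSION` / `CUDNN_VERSION` for the CUDA auto-detection). A minimal sketch of setting them from Python before any trains code runs; the values are illustrative only:

```python
# Minimal sketch: override configuration values named in the comments above
# through environment variables (values are illustrative, not recommendations).
import os

os.environ["TRAINS_WORKER_NAME"] = "trains-agent-machine1"  # overrides agent.worker_name
os.environ["CUDA_VERSION"] = "10.1"    # overrides the auto-detected CUDA version
os.environ["CUDNN_VERSION"] = "7.6"    # overrides the auto-detected cuDNN version
```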
examples/archive_experiments.py (new file, 59 lines)
```python
#!/usr/bin/python3
"""
An example script that cleans up failed experiments by moving them to the archive
"""
import argparse
from datetime import datetime

from trains_agent import APIClient

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--project", "-P", help="Project ID. Only clean up experiments from this project")
parser.add_argument("--user", "-U", help="User ID. Only clean up experiments assigned to this user")
parser.add_argument(
    "--status", "-S", default="failed",
    help="Experiment status. Only clean up experiments with this status (default %(default)s)"
)
parser.add_argument(
    "--iterations", "-I", type=int,
    help="Number of iterations. Only clean up experiments with less or equal number of iterations"
)
parser.add_argument(
    "--sec-from-start", "-T", type=int,
    help="Seconds from start time. "
         "Only clean up experiments if less or equal number of seconds have elapsed since started"
)

args = parser.parse_args()

client = APIClient()

# The "-archived" system-tag filter excludes tasks that are already archived
tasks = client.tasks.get_all(
    project=[args.project] if args.project else None,
    user=[args.user] if args.user else None,
    status=[args.status] if args.status else None,
    system_tags=["-archived"]
)

count = 0

for task in tasks:
    if args.iterations and (task.last_iteration or 0) > args.iterations:
        continue
    if args.sec_from_start:
        if not task.started:
            continue
        if (datetime.utcnow() - task.started.replace(tzinfo=None)).total_seconds() > args.sec_from_start:
            continue

    try:
        # Archiving is done by adding the "archived" system tag to the task
        client.tasks.edit(
            task=task.id,
            system_tags=(task.system_tags or []) + ["archived"],
            force=True
        )
        count += 1
    except Exception as ex:
        print("Failed editing experiment: {}".format(ex))

print("Cleaned up {} experiments".format(count))
```
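A typical invocation of the script above (using its own flags) would be `python3 examples/archive_experiments.py --status failed --iterations 10`, archiving failed experiments that ran for at most 10 iterations. Before bulk-archiving, a read-only pass can help; a minimal dry-run sketch reusing the same query:

```python
# Minimal dry-run sketch: print the tasks the script above would archive,
# using the same APIClient query, without modifying anything.
from trains_agent import APIClient

client = APIClient()
for task in client.tasks.get_all(status=["failed"], system_tags=["-archived"]):
    print("would archive: {} ({})".format(task.name, task.id))
```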
examples/dynamic_cloud_cluster.ipynb (new file, 587 lines; cells rendered below, excerpt ends where the page is truncated)
@@ -0,0 +1,587 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Auto-Magically Spin AWS EC2 Instances On Demand \n",
|
||||
"# and Create a Dynamic Cluster Running *Trains-Agent*\n",
|
||||
"\n",
|
||||
"### Define your budget and execute the notebook, that's it\n",
|
||||
"### You now have a fully managed cluster on AWS 🎉 🎊 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**trains-agent**'s main goal is to quickly pull a job from an execution queue, setup the environment (as defined in the experiment, including git cloning, python packages etc.) then execute the experiment and monitor it.\n",
|
||||
"\n",
|
||||
"This notebook defines a cloud budget (currently only AWS is supported, but feel free to expand with PRs), and spins an instance the minute a job is waiting for execution. It will also spin down idle machines, saving you some $$$ :)\n",
|
||||
"\n",
|
||||
"Configuration steps\n",
|
||||
"- Define maximum budget to be used (instance type / number of instances).\n",
|
||||
"- Create new execution *queues* in the **trains-server**.\n",
|
||||
"- Define mapping between the created the *queues* and an instance budget.\n",
|
||||
"\n",
|
||||
"**TL;DR - This notebook:**\n",
|
||||
"- Will spin instances if there are jobs in the execution *queues*, until it will hit the budget limit. \n",
|
||||
"- If machines are idle, it will spin them down.\n",
|
||||
"\n",
|
||||
"The controller implementation itself is stateless, meaning you can always re-execute the notebook, if for some reason it stopped.\n",
|
||||
"\n",
|
||||
"It is as simple as it sounds, but extremely powerful\n",
|
||||
"\n",
|
||||
"Enjoy your newly created dynamic cluster :)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Install & import required packages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install trains-agent\n",
|
||||
"!pip install boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Define AWS instance types and configuration (Instance Type, EBS, AMI etc.)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# AWS EC2 machines types - default AMI - NVIDIA Deep Learning AMI 19.11.3\n",
|
||||
"RESOURCE_CONFIGURATIONS = {\n",
|
||||
" \"amazon_ec2_normal\": {\n",
|
||||
" \"instance_type\": \"g4dn.4xlarge\",\n",
|
||||
" \"is_spot\": False,\n",
|
||||
" \"availability_zone\": \"us-east-1b\",\n",
|
||||
" \"ami_id\": \"ami-07c95cafbb788face\",\n",
|
||||
" \"ebs_device_name\": \"/dev/xvda\",\n",
|
||||
" \"ebs_volume_size\": 100,\n",
|
||||
" \"ebs_volume_type\": \"gp2\",\n",
|
||||
" },\n",
|
||||
" \"amazon_ec2_high\": {\n",
|
||||
" \"instance_type\": \"g4dn.8xlarge\",\n",
|
||||
" \"is_spot\": False,\n",
|
||||
" \"availability_zone\": \"us-east-1b\",\n",
|
||||
" \"ami_id\": \"ami-07c95cafbb788face\",\n",
|
||||
" \"ebs_device_name\": \"/dev/xvda\",\n",
|
||||
" \"ebs_volume_size\": 100,\n",
|
||||
" \"ebs_volume_type\": \"gp2\",\n",
|
||||
" },\n",
|
||||
"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Define machine budget per execution queue\n",
|
||||
"\n",
|
||||
"Now that we defined our budget, we need to connect it with the **Trains** cluster.\n",
|
||||
"\n",
|
||||
"We map each queue to a resource type (instance type).\n",
|
||||
"\n",
|
||||
"Create two queues in the WebUI:\n",
|
||||
"- Browse to http://your_trains_server_ip:8080/workers-and-queues/queues\n",
|
||||
"- Then click on the \"New Queue\" button and name your queues \"aws_normal\" and \"aws_high\" respectively\n",
|
||||
"\n",
|
||||
"The QUEUES dictionary hold the mapping between the queue name and the type/number of instances to spin connected to the specific queue.\n",
|
||||
"```\n",
|
||||
"QUEUES = {\n",
|
||||
" 'queue_name': [(\"instance-type-as-defined-in-RESOURCE_CONFIGURATIONS\", max_number_of_instances), ]\n",
|
||||
"}\n",
|
||||
"```\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Trains-Agent Queues - Machines budget per Queue\n",
|
||||
"# Per queue: list of (machine type as defined in RESOURCE_CONFIGURATIONS,\n",
|
||||
"# max instances for the specific queue). Order machines from most preferred to least.\n",
|
||||
"QUEUES = {\n",
|
||||
" \"aws_normal\": [(\"amazon_ec2_normal\", 2),],\n",
|
||||
" \"aws_high\": [(\"amazon_ec2_high\", 1)],\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Credentials for your AWS account, as well as for your **Trains-Server**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# AWS credentials (leave empty to use credentials set using the aws cli)\n",
|
||||
"CLOUD_CREDENTIALS_KEY = \"\"\n",
|
||||
"CLOUD_CREDENTIALS_SECRET = \"\"\n",
|
||||
"CLOUD_CREDENTIALS_REGION = \"us-east-1\"\n",
|
||||
"\n",
|
||||
"# TRAINS configuration\n",
|
||||
"TRAINS_SERVER_WEB_SERVER = \"http://localhost:8080\"\n",
|
||||
"TRAINS_SERVER_API_SERVER = \"http://localhost:8008\"\n",
|
||||
"TRAINS_SERVER_FILES_SERVER = \"http://localhost:8081\"\n",
|
||||
"# TRAINS credentials\n",
|
||||
"TRAINS_ACCESS_KEY = \"\"\n",
|
||||
"TRAINS_SECRET_KEY = \"\"\n",
|
||||
"# Git User/Pass to be used by trains-agent,\n",
|
||||
"# leave empty if image already contains git ssh-key\n",
|
||||
"TRAINS_GIT_USER = \"\"\n",
|
||||
"TRAINS_GIT_PASS = \"\"\n",
|
||||
"\n",
|
||||
"# Additional fields for trains.conf file created on the remote instance\n",
|
||||
"# for example: 'agent.default_docker.image: \"nvidia/cuda:10.0-cudnn7-runtime\"'\n",
|
||||
"EXTRA_TRAINS_CONF = \"\"\"\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Bash script to run on instances before running trains-agent\n",
|
||||
"# Example: \"\"\"\n",
|
||||
"# echo \"This is the first line\"\n",
|
||||
"# echo \"This is the second line\"\n",
|
||||
"# \"\"\"\n",
|
||||
"EXTRA_BASH_SCRIPT = \"\"\"\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Default docker for trains-agent when running in docker mode (requires docker v19.03 and above). \n",
|
||||
"# Leave empty to run trains-agent in non-docker mode.\n",
|
||||
"DEFAULT_DOCKER_IMAGE = \"nvidia/cuda\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Controller Internal Definitions\n",
|
||||
"\n",
|
||||
"# maximum idle time in minutes, after which the instance will be shutdown\n",
|
||||
"MAX_IDLE_TIME_MIN = 15\n",
|
||||
"# polling interval in minutes\n",
|
||||
"# make sure to increase in case bash commands were added in EXTRA_BASH_SCRIPT\n",
|
||||
"POLLING_INTERVAL_MIN = 5.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Import Packages and Budget Definition Sanity Check"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import base64\n",
|
||||
"import re\n",
|
||||
"import os\n",
|
||||
"from itertools import chain\n",
|
||||
"from operator import itemgetter\n",
|
||||
"from time import sleep, time\n",
|
||||
"\n",
|
||||
"import boto3\n",
|
||||
"from trains_agent.backend_api.session.client import APIClient"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Sanity Check - Validate Queue Resources\n",
|
||||
"if len(set(map(itemgetter(0), chain(*QUEUES.values())))) != sum(\n",
|
||||
" map(len, QUEUES.values())\n",
|
||||
"):\n",
|
||||
" print(\n",
|
||||
" \"Error: at least one resource name is used in multiple queues. \"\n",
|
||||
" \"A resource name can only appear in a single queue definition.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Encode EXTRA_TRAINS_CONF for later bash script usage\n",
|
||||
"EXTRA_TRAINS_CONF_ENCODED = \"\\\\\\\"\".join(EXTRA_TRAINS_CONF.split(\"\\\"\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Cloud specific implementation of spin up/down - currently supports AWS only"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Cloud-specific implementation (currently, only AWS EC2 is supported)\n",
|
||||
"def spin_up_worker(resource, worker_id_prefix, queue_name):\n",
|
||||
" \"\"\"\n",
|
||||
" Creates a new worker for trains.\n",
|
||||
" First, create an instance in the cloud and install some required packages.\n",
|
||||
" Then, define trains-agent environment variables and run \n",
|
||||
" trains-agent for the specified queue.\n",
|
||||
" NOTE: - Will wait until instance is running\n",
|
||||
" - This implementation assumes the instance image already has docker installed\n",
|
||||
"\n",
|
||||
" :param str resource: resource name, as defined in BUDGET and QUEUES.\n",
|
||||
" :param str worker_id_prefix: worker name prefix\n",
|
||||
" :param str queue_name: trains queue to listen to\n",
|
||||
" \"\"\"\n",
|
||||
" resource_conf = RESOURCE_CONFIGURATIONS[resource]\n",
|
||||
" # Add worker type and AWS instance type to the worker name.\n",
|
||||
" worker_id = \"{worker_id_prefix}:{worker_type}:{instance_type}\".format(\n",
|
||||
" worker_id_prefix=worker_id_prefix,\n",
|
||||
" worker_type=resource,\n",
|
||||
" instance_type=resource_conf[\"instance_type\"],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # user_data script will automatically run when the instance is started. \n",
|
||||
" # It will install the required packages for trains-agent configure it using \n",
|
||||
" # environment variables and run trains-agent on the required queue\n",
|
||||
" user_data = \"\"\"#!/bin/bash\n",
|
||||
" sudo apt-get update\n",
|
||||
" sudo apt-get install -y python3-dev\n",
|
||||
" sudo apt-get install -y python3-pip\n",
|
||||
" sudo apt-get install -y gcc\n",
|
||||
" sudo apt-get install -y git\n",
|
||||
" sudo apt-get install -y build-essential\n",
|
||||
" python3 -m pip install -U pip\n",
|
||||
" python3 -m pip install virtualenv\n",
|
||||
" python3 -m virtualenv trains_agent_venv\n",
|
||||
" source trains_agent_venv/bin/activate\n",
|
||||
" python -m pip install trains-agent\n",
|
||||
" echo 'agent.git_user=\\\"{git_user}\\\"' >> /root/trains.conf\n",
|
||||
" echo 'agent.git_pass=\\\"{git_pass}\\\"' >> /root/trains.conf\n",
|
||||
" echo \"{trains_conf}\" >> /root/trains.conf\n",
|
||||
" export TRAINS_API_HOST={api_server}\n",
|
||||
" export TRAINS_WEB_HOST={web_server}\n",
|
||||
" export TRAINS_FILES_HOST={files_server}\n",
|
||||
" export DYNAMIC_INSTANCE_ID=`curl http://169.254.169.254/latest/meta-data/instance-id`\n",
|
||||
" export TRAINS_WORKER_ID={worker_id}:$DYNAMIC_INSTANCE_ID\n",
|
||||
" export TRAINS_API_ACCESS_KEY='{access_key}'\n",
|
||||
" export TRAINS_API_SECRET_KEY='{secret_key}'\n",
|
||||
" {bash_script}\n",
|
||||
" source ~/.bashrc\n",
|
||||
" python -m trains_agent --config-file '/root/trains.conf' daemon --queue '{queue}' {docker}\n",
|
||||
" shutdown\n",
|
||||
" \"\"\".format(\n",
|
||||
" api_server=TRAINS_SERVER_API_SERVER,\n",
|
||||
" web_server=TRAINS_SERVER_WEB_SERVER,\n",
|
||||
" files_server=TRAINS_SERVER_FILES_SERVER,\n",
|
||||
" worker_id=worker_id,\n",
|
||||
" access_key=TRAINS_ACCESS_KEY,\n",
|
||||
" secret_key=TRAINS_SECRET_KEY,\n",
|
||||
" queue=queue_name,\n",
|
||||
" git_user=TRAINS_GIT_USER,\n",
|
||||
" git_pass=TRAINS_GIT_PASS,\n",
|
||||
" trains_conf=EXTRA_TRAINS_CONF_ENCODED,\n",
|
||||
" bash_script=EXTRA_BASH_SCRIPT,\n",
|
||||
" docker=\"--docker '{}'\".format(DEFAULT_DOCKER_IMAGE) if DEFAULT_DOCKER_IMAGE else \"\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" ec2 = boto3.client(\n",
|
||||
" \"ec2\",\n",
|
||||
" aws_access_key_id=CLOUD_CREDENTIALS_KEY or None,\n",
|
||||
" aws_secret_access_key=CLOUD_CREDENTIALS_SECRET or None,\n",
|
||||
" region_name=CLOUD_CREDENTIALS_REGION\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if resource_conf[\"is_spot\"]:\n",
|
||||
" # Create a request for a spot instance in AWS\n",
|
||||
" encoded_user_data = base64.b64encode(user_data.encode(\"ascii\")).decode(\"ascii\")\n",
|
||||
" instances = ec2.request_spot_instances(\n",
|
||||
" LaunchSpecification={\n",
|
||||
" \"ImageId\": resource_conf[\"ami_id\"],\n",
|
||||
" \"InstanceType\": resource_conf[\"instance_type\"],\n",
|
||||
" \"Placement\": {\"AvailabilityZone\": resource_conf[\"availability_zone\"]},\n",
|
||||
" \"UserData\": encoded_user_data,\n",
|
||||
" \"BlockDeviceMappings\": [\n",
|
||||
" {\n",
|
||||
" \"DeviceName\": resource_conf[\"ebs_device_name\"],\n",
|
||||
" \"Ebs\": {\n",
|
||||
" \"VolumeSize\": resource_conf[\"ebs_volume_size\"],\n",
|
||||
" \"VolumeType\": resource_conf[\"ebs_volume_type\"],\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Wait until spot request is fulfilled\n",
|
||||
" request_id = instances[\"SpotInstanceRequests\"][0][\"SpotInstanceRequestId\"]\n",
|
||||
" waiter = ec2.get_waiter(\"spot_instance_request_fulfilled\")\n",
|
||||
" waiter.wait(SpotInstanceRequestIds=[request_id])\n",
|
||||
" # Get the instance object for later use\n",
|
||||
" response = ec2.describe_spot_instance_requests(\n",
|
||||
" SpotInstanceRequestIds=[request_id]\n",
|
||||
" )\n",
|
||||
" instance_id = response[\"SpotInstanceRequests\"][0][\"InstanceId\"]\n",
|
||||
"\n",
|
||||
" else:\n",
|
||||
" # Create a new EC2 instance\n",
|
||||
" instances = ec2.run_instances(\n",
|
||||
" ImageId=resource_conf[\"ami_id\"],\n",
|
||||
" MinCount=1,\n",
|
||||
" MaxCount=1,\n",
|
||||
" InstanceType=resource_conf[\"instance_type\"],\n",
|
||||
" UserData=user_data,\n",
|
||||
" InstanceInitiatedShutdownBehavior='terminate',\n",
|
||||
" BlockDeviceMappings=[\n",
|
||||
" {\n",
|
||||
" \"DeviceName\": resource_conf[\"ebs_device_name\"],\n",
|
||||
" \"Ebs\": {\n",
|
||||
" \"VolumeSize\": resource_conf[\"ebs_volume_size\"],\n",
|
||||
" \"VolumeType\": resource_conf[\"ebs_volume_type\"],\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Get the instance object for later use\n",
|
||||
" instance_id = instances[\"Instances\"][0][\"InstanceId\"]\n",
|
||||
"\n",
|
||||
" instance = boto3.resource(\n",
|
||||
" \"ec2\",\n",
|
||||
" aws_access_key_id=CLOUD_CREDENTIALS_KEY or None,\n",
|
||||
" aws_secret_access_key=CLOUD_CREDENTIALS_SECRET or None,\n",
|
||||
" region_name=CLOUD_CREDENTIALS_REGION\n",
|
||||
" ).Instance(instance_id)\n",
|
||||
"\n",
|
||||
" # Wait until instance is in running state\n",
|
||||
" instance.wait_until_running()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Cloud-specific implementation (currently, only AWS EC2 is supported)\n",
|
||||
"def spin_down_worker(instance_id):\n",
|
||||
" \"\"\"\n",
|
||||
" Destroys the cloud instance.\n",
|
||||
"\n",
|
||||
" :param str instance_id: Cloud instance ID to be destroyed \n",
|
||||
" (currently, only AWS EC2 is supported)\n",
|
||||
" \"\"\"\n",
|
||||
" try:\n",
|
||||
" boto3.resource(\n",
|
||||
" \"ec2\",\n",
|
||||
" aws_access_key_id=CLOUD_CREDENTIALS_KEY or None,\n",
|
||||
" aws_secret_access_key=CLOUD_CREDENTIALS_SECRET or None,\n",
|
||||
" region_name=CLOUD_CREDENTIALS_REGION\n",
|
||||
" ).instances.filter(InstanceIds=[instance_id]).terminate()\n",
|
||||
" except Exception as ex:\n",
|
||||
" raise ex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"###### Controller Implementation and Logic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def supervisor():\n",
|
||||
" \"\"\"\n",
|
||||
" Spin up or down resources as necessary.\n",
|
||||
" - For every queue in QUEUES do the following:\n",
|
||||
" 1. Check if there are tasks waiting in the queue.\n",
|
||||
" 2. Check if there are enough idle workers available for those tasks.\n",
|
||||
" 3. In case more instances are required, and we haven't reached max instances allowed,\n",
|
||||
" create the required instances with regards to the maximum number defined in QUEUES\n",
|
||||
" Choose which instance to create according to their order QUEUES. Won't create \n",
|
||||
" more instances if maximum number defined has already reached.\n",
|
||||
" - spin down instances according to their idle time. instance which is idle for \n",
|
||||
" more than MAX_IDLE_TIME_MIN minutes would be removed.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Internal definitions\n",
|
||||
" workers_prefix = \"dynamic_aws\"\n",
|
||||
" # Worker's id in trains would be composed from:\n",
|
||||
" # prefix, name, instance_type and cloud_id separated by ';'\n",
|
||||
" workers_pattern = re.compile(\n",
|
||||
" r\"^(?P<prefix>[^:]+):(?P<name>[^:]+):(?P<instance_type>[^:]+):(?P<cloud_id>[^:]+)\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Set up the environment variables for trains\n",
|
||||
" os.environ[\"TRAINS_API_HOST\"] = TRAINS_SERVER_API_SERVER\n",
|
||||
" os.environ[\"TRAINS_WEB_HOST\"] = TRAINS_SERVER_WEB_SERVER\n",
|
||||
" os.environ[\"TRAINS_FILES_HOST\"] = TRAINS_SERVER_FILES_SERVER\n",
|
||||
" os.environ[\"TRAINS_API_ACCESS_KEY\"] = TRAINS_ACCESS_KEY\n",
|
||||
" os.environ[\"TRAINS_API_SECRET_KEY\"] = TRAINS_SECRET_KEY\n",
|
||||
" api_client = APIClient()\n",
|
||||
"\n",
|
||||
" # Verify the requested queues exist and create those that doesn't exist\n",
|
||||
" all_queues = [q.name for q in list(api_client.queues.get_all())]\n",
|
||||
" missing_queues = [q for q in QUEUES if q not in all_queues]\n",
|
||||
" for q in missing_queues:\n",
|
||||
" api_client.queues.create(q)\n",
|
||||
"\n",
|
||||
" idle_workers = {}\n",
|
||||
" while True:\n",
|
||||
" queue_name_to_id = {\n",
|
||||
" queue.name: queue.id for queue in api_client.queues.get_all()\n",
|
||||
" }\n",
|
||||
" resource_to_queue = {\n",
|
||||
" item[0]: queue\n",
|
||||
" for queue, resources in QUEUES.items()\n",
|
||||
" for item in resources\n",
|
||||
" }\n",
|
||||
" all_workers = [\n",
|
||||
" worker\n",
|
||||
" for worker in api_client.workers.get_all()\n",
|
||||
" if workers_pattern.match(worker.id)\n",
|
||||
" and workers_pattern.match(worker.id)[\"prefix\"] == workers_prefix\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # Workers without a task, are added to the idle list\n",
|
||||
" for worker in all_workers:\n",
|
||||
" if not hasattr(worker, \"task\") or not worker.task:\n",
|
||||
" if worker.id not in idle_workers:\n",
|
||||
" resource_name = workers_pattern.match(worker.id)[\"instance_type\"]\n",
|
||||
" idle_workers[worker.id] = (time(), resource_name, worker)\n",
|
||||
" elif hasattr(worker, \"task\") and worker.task and worker.id in idle_workers:\n",
|
||||
" idle_workers.pop(worker.id, None)\n",
|
||||
"\n",
|
||||
" required_idle_resources = [] # idle resources we'll need to keep running\n",
|
||||
" allocate_new_resources = [] # resources that will need to be started\n",
|
||||
" # Check if we have tasks waiting on one of the designated queues\n",
|
||||
" for queue in QUEUES:\n",
|
||||
" entries = api_client.queues.get_by_id(queue_name_to_id[queue]).entries\n",
|
||||
" if entries and len(entries) > 0:\n",
|
||||
" queue_resources = QUEUES[queue]\n",
|
||||
"\n",
|
||||
" # If we have an idle worker matching the required resource,\n",
|
||||
" # remove it from the required allocation resources\n",
|
||||
" free_queue_resources = [\n",
|
||||
" resource\n",
|
||||
" for _, resource, _ in idle_workers.values()\n",
|
||||
" if resource in queue_resources\n",
|
||||
" ]\n",
|
||||
" required_idle_resources.extend(free_queue_resources)\n",
|
||||
" spin_up_count = len(entries) - len(free_queue_resources)\n",
|
||||
" spin_up_resources = []\n",
|
||||
"\n",
|
||||
" # Add as many resources as possible to handle this queue's entries\n",
|
||||
" for resource, max_instances in queue_resources:\n",
|
||||
" if len(spin_up_resources) >= spin_up_count:\n",
|
||||
" break\n",
|
||||
" max_allowed = max_instances - len(\n",
|
||||
" [\n",
|
||||
" worker\n",
|
||||
" for worker in all_workers\n",
|
||||
" if workers_pattern.match(worker.id)[\"name\"] == resource\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" spin_up_resources.extend(\n",
|
||||
" [resource] * min(max_allowed, spin_up_count)\n",
|
||||
" )\n",
|
||||
" allocate_new_resources.extend(spin_up_resources)\n",
|
||||
"\n",
|
||||
" # Now we actually spin the new machines\n",
|
||||
" for resource in allocate_new_resources:\n",
|
||||
" spin_up_worker(resource, workers_prefix, resource_to_queue[resource])\n",
|
||||
"\n",
|
||||
" # Go over the idle workers list, and spin down idle workers\n",
|
||||
" for timestamp, resources, worker in idle_workers.values():\n",
|
||||
" # skip resource types that might be needed\n",
|
||||
" if resources in required_idle_resources:\n",
|
||||
" continue\n",
|
||||
" # Remove from both aws and trains all instances that are \n",
|
||||
" # idle for longer than MAX_IDLE_TIME_MIN\n",
|
||||
" if time() - timestamp > MAX_IDLE_TIME_MIN * 60.0:\n",
|
||||
" cloud_id = workers_pattern.match(worker.id)[\"cloud_id\"]\n",
|
||||
" spin_down_worker(cloud_id)\n",
|
||||
" worker.unregister()\n",
|
||||
"\n",
|
||||
" # Nothing else to do\n",
|
||||
" sleep(POLLING_INTERVAL_MIN * 60.0)"
|
||||
]
|
||||
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Execute Forever* (the controller is stateless, so you can always re-execute the notebook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loop forever, it is okay we are stateless\n",
    "while True:\n",
    "    try:\n",
    "        supervisor()\n",
    "    except Exception as ex:\n",
    "        print(\"Warning! exception occurred: {ex}\\nRetry in 15 seconds\".format(ex=ex))\n",
    "        sleep(15)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
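The supervisor above iterates a QUEUES mapping from queue name to an ordered list of (resource_name, max_instances) pairs, which is defined earlier in the notebook. A minimal sketch of such a configuration follows; the queue and resource names here are illustrative assumptions, not values from the notebook:

# Illustrative only: QUEUES maps each trains queue to an ordered list of
# (resource_name, max_instances) pairs, matching the tuples the supervisor
# unpacks in "for resource, max_instances in queue_resources".
QUEUES = {
    "aws_cpu": [("m5_xlarge", 3)],
    "aws_gpu": [("p3_2xlarge", 2), ("g4dn_xlarge", 4)],
}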
75
examples/k8s_glue_example.py
Normal file
@@ -0,0 +1,75 @@
"""
This example assumes you have preconfigured services with selectors in the form of
"ai.allegro.agent.serial=pod-<number>" and a targetPort of 10022.
The K8sIntegration component will label each pod accordingly.
"""
from argparse import ArgumentParser

from trains_agent.glue.k8s import K8sIntegration


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--queue", type=str, help="Queue to pull tasks from"
    )
    parser.add_argument(
        "--ports-mode", action='store_true', default=False,
        help="Ports-Mode will add a label to the pod which can be used as a service, in order to expose ports"
    )
    parser.add_argument(
        "--num-of-services", type=int, default=20,
        help="Specify the number of k8s services to be used. Use only with ports-mode."
    )
    parser.add_argument(
        "--base-port", type=int,
        help="Used in conjunction with ports-mode, specifies the base port exposed by the services. "
             "For pod #X, the port will be <base-port>+X"
    )
    parser.add_argument(
        "--gateway-address", type=str, default=None,
        help="Used in conjunction with ports-mode, specify the external address of the k8s ingress / ELB"
    )
    parser.add_argument(
        "--pod-trains-conf", type=str,
        help="Configuration file to be used by the pod itself (if not provided, the current configuration is used)"
    )
    parser.add_argument(
        "--overrides-yaml", type=str,
        help="YAML file containing pod overrides to be used when launching a new pod"
    )
    parser.add_argument(
        "--template-yaml", type=str,
        help="YAML file containing the pod template. If provided, the pod will be scheduled with kubectl apply "
             "and overrides are ignored, otherwise it will be scheduled with kubectl run"
    )
    parser.add_argument(
        "--ssh-server-port", type=int, default=0,
        help="If non-zero, every pod will also start an SSH server on the selected port (default: zero, not active)"
    )
    return parser.parse_args()


def main():
    args = parse_args()

    user_props_cb = None
    if args.ports_mode and args.base_port:
        def k8s_user_props_cb(pod_number):
            user_prop = {"k8s-pod-port": args.base_port + pod_number}
            if args.gateway_address:
                user_prop["k8s-gateway-address"] = args.gateway_address
            return user_prop
        user_props_cb = k8s_user_props_cb

    k8s = K8sIntegration(
        ports_mode=args.ports_mode, num_of_services=args.num_of_services, user_props_cb=user_props_cb,
        overrides_yaml=args.overrides_yaml, trains_conf_file=args.pod_trains_conf, template_yaml=args.template_yaml,
        extra_bash_init_script=K8sIntegration.get_ssh_server_bash(
            ssh_port_number=args.ssh_server_port) if args.ssh_server_port else None
    )
    k8s.k8s_daemon(args.queue)


if __name__ == "__main__":
    main()
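For orientation, a minimal programmatic sketch of what the example above wires together when ports-mode is enabled; the queue name and base port below are illustrative assumptions, not values from the example:

# Sketch only: roughly equivalent to running the example with
#   --queue k8s_scheduler --ports-mode --base-port 30000
# (queue name and port are assumptions). Uses the same API imported above.
from trains_agent.glue.k8s import K8sIntegration

def user_props_cb(pod_number):
    # same property shape the example's k8s_user_props_cb returns
    return {"k8s-pod-port": 30000 + pod_number}

k8s = K8sIntegration(ports_mode=True, num_of_services=20, user_props_cb=user_props_cb)
k8s.k8s_daemon("k8s_scheduler")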
@@ -3,7 +3,6 @@ enum34>=0.9 ; python_version < '3.6'
furl>=2.0.0
future>=0.16.0
humanfriendly>=2.1
jsonmodels>=2.2
jsonschema>=2.6.0
pathlib2>=2.3.0
psutil>=3.4.2
@@ -14,10 +13,8 @@ pyjwt>=1.6.4
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.20.0
requirements_parser>=0.2.0
semantic_version>=2.6.0
six>=1.11.0
tqdm>=4.19.5
typing>=3.6.4
urllib3>=1.21.1
virtualenv>=16
virtualenv>=16,<20
27
setup.py
@@ -4,28 +4,31 @@ TRAINS-AGENT DevOps for machine/deep learning
https://github.com/allegroai/trains-agent
"""

import os.path
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
from six import exec_
from pathlib2 import Path

def read_text(filepath):
    with open(filepath, "r") as f:
        return f.read()

here = Path(__file__).resolve().parent

here = os.path.dirname(__file__)
# Get the long description from the README file
long_description = (here / 'README.md').read_text()
long_description = read_text(os.path.join(here, 'README.md'))


def read_version_string():
    result = {}
    exec_((here / 'trains_agent/version.py').read_text(), result)
    return result['__version__']
def read_version_string(version_file):
    for line in read_text(version_file).splitlines():
        if line.startswith('__version__'):
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    else:
        raise RuntimeError("Unable to find version string.")


version = read_version_string()

requirements = (here / 'requirements.txt').read_text().splitlines()
version = read_version_string("trains_agent/version.py")

requirements = read_text(os.path.join(here, 'requirements.txt')).splitlines()

setup(
    name='trains_agent',
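The replacement read_version_string above scans for the __version__ assignment line instead of exec-ing version.py. A self-contained sketch of the parsing it performs (the version value below is an illustrative assumption):

# Standalone sketch of the parsing logic introduced in setup.py above.
def parse_version(text):
    for line in text.splitlines():
        if line.startswith('__version__'):
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")

assert parse_version('__version__ = "0.13.0"') == "0.13.0"  # hypothetical version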
@@ -35,7 +35,7 @@ def trains_agentyaml(tmpdir):
    def _method(template_file):
        file = tmpdir.join("trains_agent.yaml")
        with (PROJECT_ROOT / "tests/templates" / template_file).open() as f:
            code = yaml.load(f)
            code = yaml.load(f, Loader=yaml.SafeLoader)
        yield Namespace(code=code, file=file.strpath)
        file.write(yaml.dump(code))
    return _method

@@ -30,6 +30,6 @@ from trains_agent.helper.repo import VCS
    ),
)
def test(url, expected):
    result = VCS.resolve_ssh_url(url)
    result = VCS.replace_ssh_url(url)
    expected = expected or url
    assert result == expected

@@ -11,7 +11,7 @@ from contextlib import contextmanager
from typing import Iterator, ContextManager, Sequence, IO, Text
from uuid import uuid4

from trains_agent.backend_api.services.tasks import Script
from trains_agent.backend_api.services import tasks
from trains_agent.backend_api.session.client import APIClient
from pathlib2 import Path
from pytest import fixture
@@ -154,7 +154,7 @@ def test_entry_point_warning(client):
    """
    with create_task(
        client,
        script=Script(diff="print('hello')", entry_point="foo.py", repository=""),
        script=tasks.Script(diff="print('hello')", entry_point="foo.py", repository=""),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
        for line in output:
@@ -172,7 +172,7 @@ def test_run_no_dirs(client):
    script = "print('{}')".format(uuid)
    with create_task(
        client,
        script=Script(diff=script, entry_point="", repository="", working_dir=""),
        script=tasks.Script(diff=script, entry_point="", repository="", working_dir=""),
        **DEFAULT_TASK_ARGS
    ) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
        search_lines(
@@ -196,7 +196,7 @@ def test_run_working_dir(client):
    script = "print('{}')".format(uuid)
    with create_task(
        client,
        script=Script(
        script=tasks.Script(
            diff=script,
            entry_point="",
            repository="git@bitbucket.org:seematics/roee_test_git.git",
@@ -223,7 +223,7 @@ def test_regular_task(client):
    """
    with create_task(
        client,
        script=Script(
        script=tasks.Script(
            entry_point="noop.py",
            repository="git@bitbucket.org:seematics/roee_test_git.git",
        ),
@@ -241,7 +241,7 @@ def test_regular_task_nested(client):
    """
    with create_task(
        client,
        script=Script(
        script=tasks.Script(
            entry_point="noop_nested.py",
            working_dir="no_reqs",
            repository="git@bitbucket.org:seematics/roee_test_git.git",

@@ -1 +1 @@

from .backend_api.session.client import APIClient

@@ -20,6 +20,8 @@ from .interface import get_parser
def run_command(parser, args, command_name):

    debug = args.debug
    session.Session.set_debug_mode(debug)

    if command_name and command_name.lower() in ('config', 'init'):
        command_class = commands.Config
    elif len(command_name.split('.')) < 2:

@@ -9,17 +9,32 @@
# worker_name: "trains-agent-machine1"
worker_name: ""

# Set GIT user/pass credentials for cloning code, leave blank for GIT SSH credentials.
# Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https)
# leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
# git_user: ""
# git_pass: ""
# git_host: ""

# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# Force a specific SSH port when converting http to ssh links (the domain is kept the same)
# force_git_ssh_port: 0

# Set the python version to use when creating the virtual environment and launching the experiment
# Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6"
# The default is the python executing the trains_agent
python_binary: ""

# select python package manager:
# currently supported: pip and conda
# poetry is used if pip is selected and the repository contains a poetry.lock file
package_manager: {
    # supported options: pip, conda
    # supported options: pip, conda, poetry
    type: pip,

    # specify pip version to use (examples: "<20", "==19.3.1", ""; an empty string will install the latest version)
    pip_version: "<20.2",

    # virtual environment inherits packages from the system
    system_site_packages: false,

@@ -28,10 +43,33 @@

    # additional artifact repositories to use when installing python packages
    # extra_index_url: ["https://allegroai.jfrog.io/trainsai/api/pypi/public/simple"]
    extra_index_url: []

    # additional conda channels to use when installing with the conda package manager
    conda_channels: ["defaults", "conda-forge", "pytorch", ]

    # If set to true, the Task's "installed packages" are ignored,
    # and the repository's "requirements.txt" is used instead
    # force_repo_requirements_txt: false

    # set the priority packages to be installed before the rest of the required packages
    # priority_packages: ["cython", "numpy", "setuptools", ]

    # set the optional priority packages to be installed before the rest of the required packages;
    # in case a package installation fails, the package will be ignored,
    # and the virtual environment process will continue
    # priority_optional_packages: ["pygobject", ]

    # set the post packages to be installed after all the rest of the required packages
    # post_packages: ["horovod", ]

    # set the optional post packages to be installed after all the rest of the required packages;
    # in case a package installation fails, the package will be ignored,
    # and the virtual environment process will continue
    # post_optional_packages: []

    # set to True to support torch nightly build installation;
    # notice: torch nightly builds are ephemeral and are deleted from time to time
    torch_nightly: false,
},

# target folder for virtual environment builds, created when executing an experiment
@@ -59,19 +97,67 @@
# reload configuration file every daemon execution
reload_config: false,

# pip cache folder used mapped into docker, for python package caching
# pip cache folder mapped into docker, used for python package caching
docker_pip_cache = ~/.trains/pip-cache
# apt cache folder used mapped into docker, for ubuntu package caching
# apt cache folder mapped into docker, used for ubuntu package caching
docker_apt_cache = ~/.trains/apt-cache

# optional arguments to pass to docker image
# these are local for this agent and will not be updated in the experiment's docker_cmd section
# extra_docker_arguments: ["--ipc=host", ]

# optional shell script to run in docker when started before the experiment is started
# extra_docker_shell_script: ["apt-get install -y bindfs", ]

# optional uptime configuration, make sure to use only one of 'uptime/downtime' and not both.
# If uptime is specified, agent will actively poll (and execute) tasks in the time-spans defined here.
# Outside of the specified time-spans, the agent will be idle.
# Defined using a list of items of the format: "<hours> <days>".
# hours - use values 0-23, single values would count as start hour and end at midnight.
# days - use days in abbreviated format (SUN-SAT)
# use '-' for ranges and ',' to separate singular values.
# for example, to enable the workers every Sunday and Tuesday between 17:00-20:00 set uptime to:
# uptime: ["17-20 SUN,TUE"]

# optional downtime configuration, can be used only when uptime is not used.
# If downtime is specified, agent will be idle in the time-spans defined here.
# Outside of the specified time-spans, the agent will actively poll (and execute) tasks.
# Use the same format as described above for uptime
# downtime: []

# set to true in order to force "docker pull" before running an experiment using a docker image.
# This makes sure the docker image is updated.
docker_force_pull: false

default_docker: {
    # default docker image to use when running in docker mode
    image: "nvidia/cuda"
    image: "nvidia/cuda:10.1-runtime-ubuntu18.04"

    # optional arguments to pass to docker image
    # arguments: ["--ipc=host", ]
}

# set the initial bash script to execute at the startup of any docker.
# all lines will be executed regardless of their exit code.
# {python_single_digit} is translated to 'python3' or 'python2' according to requested python version
# docker_init_bash_script = [
#     "echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
#     "chown -R root /root/.cache/pip",
#     "apt-get update",
#     "apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
#     "(which {python_single_digit} && {python_single_digit} -m pip --version) || apt-get install -y {python_single_digit}-pip",
# ]

# set the preprocessing bash script to execute at the startup of any docker.
# all lines will be executed regardless of their exit code.
# docker_preprocess_bash_script = [
#     "echo \"starting docker\"",
# ]

# If False replace \r with \n and display full console output
# default is True, report a single \r line in a sequence of consecutive lines, per 5 seconds.
# suppress_carriage_return: true

# cuda versions used for solving pytorch wheel packages
# should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
# cuda_version: 10.1
@@ -26,12 +26,23 @@
    # X files are stored in the upload destination for each metric/variant combination.
    file_history_size: 100

    # Max history size for matplotlib imshow files per plot title.
    # File names for the uploaded images will be recycled in such a way that no more than
    # X images are stored in the upload destination for each matplotlib plot title.
    matplotlib_untitled_history_size: 100

    # Limit the number of digits after the dot in plot reporting (reducing plot report size)
    # plot_max_num_digits: 5

    # Settings for generated debug images
    images {
        format: JPEG
        quality: 87
        subsampling: 0
    }

    # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to true, each series should have its own graph)
    tensorboard_single_series_per_graph: false
}

network {
@@ -112,11 +123,11 @@

log {
    # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
    null_log_propagate: False
    null_log_propagate: false
    task_log_buffer_capacity: 66

    # disable urllib info and lower levels
    disable_urllib3_info: True
    disable_urllib3_info: true
}

development {
@@ -126,14 +137,30 @@
    task_reuse_time_window_in_hours: 72.0

    # Run VCS repository detection asynchronously
    vcs_repo_detect_async: True
    vcs_repo_detect_async: true

    # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
    # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
    store_uncommitted_code_diff_on_train: True
    store_uncommitted_code_diff: true

    # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
    support_stopping: True
    support_stopping: true

    # Default Task output_uri. If output_uri is not provided to Task.init, default_output_uri will be used instead.
    default_output_uri: ""

    # Default auto-generated requirements optimize for smaller requirements
    # If True, analyze the entire repository regardless of the entry point.
    # If False, first analyze the entry point script; if it does not contain references to other local files,
    # do not analyze the entire repository.
    force_analyze_entire_repo: false

    # If set to true, the *trains* update message will not be printed to the console
    # this value can be overwritten with os environment variable TRAINS_SUPPRESS_UPDATE_MESSAGE=1
    suppress_update_message: false

    # If this flag is true (default is false), instead of analyzing the code with Pigar, analyze with `pip freeze`
    detect_with_pip_freeze: false

    # Development mode worker
    worker {
@@ -144,7 +171,11 @@
        ping_period_sec: 30

        # Log all stdout & stderr
        log_stdout: True
        log_stdout: true

        # compatibility feature, report memory usage for the entire machine
        # default (false), report only on the running process and its sub-processes
        report_global_mem_used: false
    }
}
}
}
@@ -1,8 +1,10 @@
from .v2_4 import auth
from .v2_4 import debug
from .v2_4 import queues
from .v2_4 import tasks
from .v2_4 import workers
from .v2_5 import auth
from .v2_5 import debug
from .v2_5 import queues
from .v2_5 import tasks
from .v2_5 import workers
from .v2_5 import events
from .v2_5 import models

__all__ = [
    'auth',
@@ -10,4 +12,6 @@ __all__ = [
    'queues',
    'tasks',
    'workers',
    'events',
    'models',
]

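Note the import order above: both version modules bind the same names, so the later v2_5 imports shadow the v2_4 ones. A minimal, self-contained illustration of that shadowing behavior (the standard-library modules here are stand-ins, not part of the project):

# Minimal illustration: with two imports bound to the same name,
# the binding that executes last wins.
from json import dumps as impl
from pprint import pformat as impl  # this later import shadows the first

assert impl.__module__ == "pprint"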
@@ -151,7 +151,7 @@ class CreateCredentialsRequest(Request):

    _service = "auth"
    _action = "create_credentials"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'additionalProperties': False,
        'definitions': {},
@@ -169,7 +169,7 @@ class CreateCredentialsResponse(Response):
    """
    _service = "auth"
    _action = "create_credentials"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {
@@ -230,7 +230,7 @@ class EditUserRequest(Request):

    _service = "auth"
    _action = "edit_user"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -287,7 +287,7 @@ class EditUserResponse(Response):
    """
    _service = "auth"
    _action = "edit_user"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -347,7 +347,7 @@ class GetCredentialsRequest(Request):

    _service = "auth"
    _action = "get_credentials"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'additionalProperties': False,
        'definitions': {},
@@ -365,7 +365,7 @@ class GetCredentialsResponse(Response):
    """
    _service = "auth"
    _action = "get_credentials"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {
@@ -433,7 +433,7 @@ class LoginRequest(Request):

    _service = "auth"
    _action = "login"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -474,7 +474,7 @@ class LoginResponse(Response):
    """
    _service = "auth"
    _action = "login"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -510,7 +510,7 @@ class LogoutRequest(Request):

    _service = "auth"
    _action = "logout"
    _version = "2.2"
    _version = "2.4"
    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}


@@ -521,7 +521,7 @@ class LogoutResponse(Response):
    """
    _service = "auth"
    _action = "logout"
    _version = "2.2"
    _version = "2.4"

    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}

@@ -537,7 +537,7 @@ class RevokeCredentialsRequest(Request):

    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -577,7 +577,7 @@ class RevokeCredentialsResponse(Response):
    """
    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},

@@ -19,7 +19,7 @@ class ApiexRequest(Request):

    _service = "debug"
    _action = "apiex"
    _version = "1.5"
    _version = "2.4"
    _schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}


@@ -30,7 +30,7 @@ class ApiexResponse(Response):
    """
    _service = "debug"
    _action = "apiex"
    _version = "1.5"
    _version = "2.4"

    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}

@@ -43,7 +43,7 @@ class EchoRequest(Request):

    _service = "debug"
    _action = "echo"
    _version = "1.5"
    _version = "2.4"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}


@@ -54,7 +54,7 @@ class EchoResponse(Response):
    """
    _service = "debug"
    _action = "echo"
    _version = "1.5"
    _version = "2.4"

    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}

@@ -65,7 +65,7 @@ class ExRequest(Request):

    _service = "debug"
    _action = "ex"
    _version = "1.5"
    _version = "2.4"
    _schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}


@@ -76,7 +76,7 @@ class ExResponse(Response):
    """
    _service = "debug"
    _action = "ex"
    _version = "1.5"
    _version = "2.4"

    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}

@@ -89,7 +89,7 @@ class PingRequest(Request):

    _service = "debug"
    _action = "ping"
    _version = "1.5"
    _version = "2.4"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}


@@ -102,7 +102,7 @@ class PingResponse(Response):
    """
    _service = "debug"
    _action = "ping"
    _version = "1.5"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -141,7 +141,7 @@ class PingAuthRequest(Request):

    _service = "debug"
    _action = "ping_auth"
    _version = "1.5"
    _version = "2.4"
    _schema = {'definitions': {}, 'properties': {}, 'type': 'object'}


@@ -154,7 +154,7 @@ class PingAuthResponse(Response):
    """
    _service = "debug"
    _action = "ping_auth"
    _version = "1.5"
    _version = "2.4"

    _schema = {
        'definitions': {},

2977
trains_agent/backend_api/services/v2_4/events.py
Normal file
File diff suppressed because it is too large
Load Diff
2850
trains_agent/backend_api/services/v2_4/models.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1518,7 +1518,7 @@ class CloseRequest(Request):

    _service = "tasks"
    _action = "close"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -1612,7 +1612,7 @@ class CloseResponse(Response):
    """
    _service = "tasks"
    _action = "close"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -1682,7 +1682,7 @@ class CompletedRequest(Request):

    _service = "tasks"
    _action = "completed"
    _version = "2.2"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -1776,7 +1776,7 @@ class CompletedResponse(Response):
    """
    _service = "tasks"
    _action = "completed"
    _version = "2.2"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -1862,7 +1862,7 @@ class CreateRequest(Request):

    _service = "tasks"
    _action = "create"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {
            'artifact': {
@@ -2229,7 +2229,7 @@ class CreateResponse(Response):
    """
    _service = "tasks"
    _action = "create"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -2280,7 +2280,7 @@ class DeleteRequest(Request):

    _service = "tasks"
    _action = "delete"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -2403,7 +2403,7 @@ class DeleteResponse(Response):
    """
    _service = "tasks"
    _action = "delete"
    _version = "1.5"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -2547,7 +2547,7 @@ class DequeueRequest(Request):

    _service = "tasks"
    _action = "dequeue"
    _version = "1.5"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -2624,7 +2624,7 @@ class DequeueResponse(Response):
    """
    _service = "tasks"
    _action = "dequeue"
    _version = "1.5"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -2733,7 +2733,7 @@ class EditRequest(Request):

    _service = "tasks"
    _action = "edit"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {
            'artifact': {
@@ -3123,7 +3123,7 @@ class EditResponse(Response):
    """
    _service = "tasks"
    _action = "edit"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -3201,7 +3201,7 @@ class EnqueueRequest(Request):

    _service = "tasks"
    _action = "enqueue"
    _version = "1.5"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -3296,7 +3296,7 @@ class EnqueueResponse(Response):
    """
    _service = "tasks"
    _action = "enqueue"
    _version = "1.5"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -3386,7 +3386,7 @@ class FailedRequest(Request):

    _service = "tasks"
    _action = "failed"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -3480,7 +3480,7 @@ class FailedResponse(Response):
    """
    _service = "tasks"
    _action = "failed"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -3587,7 +3587,7 @@ class GetAllRequest(Request):

    _service = "tasks"
    _action = "get_all"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {
            'multi_field_pattern_data': {
@@ -3986,7 +3986,7 @@ class GetAllResponse(Response):
    """
    _service = "tasks"
    _action = "get_all"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {
@@ -4373,7 +4373,7 @@ class GetByIdRequest(Request):

    _service = "tasks"
    _action = "get_by_id"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {'task': {'description': 'Task ID', 'type': 'string'}},
@@ -4408,7 +4408,7 @@ class GetByIdResponse(Response):
    """
    _service = "tasks"
    _action = "get_by_id"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {
@@ -4792,7 +4792,7 @@ class PingRequest(Request):

    _service = "tasks"
    _action = "ping"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {'task': {'description': 'Task ID', 'type': 'string'}},
@@ -4825,7 +4825,7 @@ class PingResponse(Response):
    """
    _service = "tasks"
    _action = "ping"
    _version = "2.1"
    _version = "2.4"

    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}

@@ -4853,7 +4853,7 @@ class PublishRequest(Request):

    _service = "tasks"
    _action = "publish"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -4967,7 +4967,7 @@ class PublishResponse(Response):
    """
    _service = "tasks"
    _action = "publish"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -5057,7 +5057,7 @@ class ResetRequest(Request):

    _service = "tasks"
    _action = "reset"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -5160,7 +5160,7 @@ class ResetResponse(Response):
    """
    _service = "tasks"
    _action = "reset"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -5305,7 +5305,7 @@ class SetRequirementsRequest(Request):

    _service = "tasks"
    _action = "set_requirements"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -5362,7 +5362,7 @@ class SetRequirementsResponse(Response):
    """
    _service = "tasks"
    _action = "set_requirements"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -5431,7 +5431,7 @@ class StartedRequest(Request):

    _service = "tasks"
    _action = "started"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -5527,7 +5527,7 @@ class StartedResponse(Response):
    """
    _service = "tasks"
    _action = "started"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -5617,7 +5617,7 @@ class StopRequest(Request):

    _service = "tasks"
    _action = "stop"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -5711,7 +5711,7 @@ class StopResponse(Response):
    """
    _service = "tasks"
    _action = "stop"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -5780,7 +5780,7 @@ class StoppedRequest(Request):

    _service = "tasks"
    _action = "stopped"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -5874,7 +5874,7 @@ class StoppedResponse(Response):
    """
    _service = "tasks"
    _action = "stopped"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -5952,7 +5952,7 @@ class UpdateRequest(Request):

    _service = "tasks"
    _action = "update"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {},
        'properties': {
@@ -6120,7 +6120,7 @@ class UpdateResponse(Response):
    """
    _service = "tasks"
    _action = "update"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -6183,7 +6183,7 @@ class UpdateBatchRequest(BatchRequest):

    _service = "tasks"
    _action = "update_batch"
    _version = "2.1"
    _version = "2.4"
    _batched_request_cls = UpdateRequest


@@ -6196,7 +6196,7 @@ class UpdateBatchResponse(Response):
    """
    _service = "tasks"
    _action = "update_batch"
    _version = "2.1"
    _version = "2.4"

    _schema = {
        'definitions': {},
@@ -6261,7 +6261,7 @@ class ValidateRequest(Request):

    _service = "tasks"
    _action = "validate"
    _version = "2.1"
    _version = "2.4"
    _schema = {
        'definitions': {
            'artifact': {
@@ -6614,7 +6614,7 @@ class ValidateResponse(Response):
    """
    _service = "tasks"
    _action = "validate"
    _version = "2.1"
    _version = "2.4"

    _schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}

0
trains_agent/backend_api/services/v2_5/__init__.py
Normal file
623
trains_agent/backend_api/services/v2_5/auth.py
Normal file
@@ -0,0 +1,623 @@
"""
auth service

This service provides authentication management and authorization
validation for the entire system.
"""
import six
import types
from datetime import datetime
import enum

from dateutil.parser import parse as parse_datetime

from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum


class Credentials(NonStrictDataModel):
    """
    :param access_key: Credentials access key
    :type access_key: str
    :param secret_key: Credentials secret key
    :type secret_key: str
    """
    _schema = {
        'properties': {
            'access_key': {
                'description': 'Credentials access key',
                'type': ['string', 'null'],
            },
            'secret_key': {
                'description': 'Credentials secret key',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, access_key=None, secret_key=None, **kwargs):
        super(Credentials, self).__init__(**kwargs)
        self.access_key = access_key
        self.secret_key = secret_key

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        if value is None:
            self._property_access_key = None
            return

        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property('secret_key')
    def secret_key(self):
        return self._property_secret_key

    @secret_key.setter
    def secret_key(self, value):
        if value is None:
            self._property_secret_key = None
            return

        self.assert_isinstance(value, "secret_key", six.string_types)
        self._property_secret_key = value


class CredentialKey(NonStrictDataModel):
    """
    :param access_key:
    :type access_key: str
    :param last_used:
    :type last_used: datetime.datetime
    :param last_used_from:
    :type last_used_from: str
    """
    _schema = {
        'properties': {
            'access_key': {'description': '', 'type': ['string', 'null']},
            'last_used': {
                'description': '',
                'format': 'date-time',
                'type': ['string', 'null'],
            },
            'last_used_from': {'description': '', 'type': ['string', 'null']},
        },
        'type': 'object',
    }
    def __init__(
            self, access_key=None, last_used=None, last_used_from=None, **kwargs):
        super(CredentialKey, self).__init__(**kwargs)
        self.access_key = access_key
        self.last_used = last_used
        self.last_used_from = last_used_from

    @schema_property('access_key')
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        if value is None:
            self._property_access_key = None
            return

        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property('last_used')
    def last_used(self):
        return self._property_last_used

    @last_used.setter
    def last_used(self, value):
        if value is None:
            self._property_last_used = None
            return

        self.assert_isinstance(value, "last_used", six.string_types + (datetime,))
        if not isinstance(value, datetime):
            value = parse_datetime(value)
        self._property_last_used = value

    @schema_property('last_used_from')
    def last_used_from(self):
        return self._property_last_used_from

    @last_used_from.setter
    def last_used_from(self, value):
        if value is None:
            self._property_last_used_from = None
            return

        self.assert_isinstance(value, "last_used_from", six.string_types)
        self._property_last_used_from = value


class CreateCredentialsRequest(Request):
    """
    Creates a new set of credentials for the authenticated user.
    New key/secret is returned.
    Note: Secret will never be returned in any other API call.
    If a secret is lost or compromised, the key should be revoked
    and a new set of credentials can be created.

    """

    _service = "auth"
    _action = "create_credentials"
    _version = "2.5"
    _schema = {
        'additionalProperties': False,
        'definitions': {},
        'properties': {},
        'type': 'object',
    }


class CreateCredentialsResponse(Response):
    """
    Response of auth.create_credentials endpoint.

    :param credentials: Created credentials
    :type credentials: Credentials
    """
    _service = "auth"
    _action = "create_credentials"
    _version = "2.5"

    _schema = {
        'definitions': {
            'credentials': {
                'properties': {
                    'access_key': {
                        'description': 'Credentials access key',
                        'type': ['string', 'null'],
                    },
                    'secret_key': {
                        'description': 'Credentials secret key',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'credentials': {
                'description': 'Created credentials',
                'oneOf': [{'$ref': '#/definitions/credentials'}, {'type': 'null'}],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, credentials=None, **kwargs):
        super(CreateCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        if value is None:
            self._property_credentials = None
            return
        if isinstance(value, dict):
            value = Credentials.from_dict(value)
        else:
            self.assert_isinstance(value, "credentials", Credentials)
        self._property_credentials = value


class EditUserRequest(Request):
    """
    Edit a user's auth data properties

    :param user: User ID
    :type user: str
    :param role: The new user's role within the company
    :type role: str
    """

    _service = "auth"
    _action = "edit_user"
    _version = "2.5"
    _schema = {
        'definitions': {},
        'properties': {
            'role': {
                'description': "The new user's role within the company",
                'enum': ['admin', 'superuser', 'user', 'annotator'],
                'type': ['string', 'null'],
            },
            'user': {'description': 'User ID', 'type': ['string', 'null']},
        },
        'type': 'object',
    }
    def __init__(
            self, user=None, role=None, **kwargs):
        super(EditUserRequest, self).__init__(**kwargs)
        self.user = user
        self.role = role

    @schema_property('user')
    def user(self):
        return self._property_user

    @user.setter
    def user(self, value):
        if value is None:
            self._property_user = None
            return

        self.assert_isinstance(value, "user", six.string_types)
        self._property_user = value

    @schema_property('role')
    def role(self):
        return self._property_role

    @role.setter
    def role(self, value):
        if value is None:
            self._property_role = None
            return

        self.assert_isinstance(value, "role", six.string_types)
        self._property_role = value


class EditUserResponse(Response):
    """
    Response of auth.edit_user endpoint.

    :param updated: Number of users updated (0 or 1)
    :type updated: float
    :param fields: Updated fields names and values
    :type fields: dict
    """
    _service = "auth"
    _action = "edit_user"
    _version = "2.5"

    _schema = {
        'definitions': {},
        'properties': {
            'fields': {
                'additionalProperties': True,
                'description': 'Updated fields names and values',
                'type': ['object', 'null'],
            },
            'updated': {
                'description': 'Number of users updated (0 or 1)',
                'enum': [0, 1],
                'type': ['number', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, updated=None, fields=None, **kwargs):
        super(EditUserResponse, self).__init__(**kwargs)
        self.updated = updated
        self.fields = fields

    @schema_property('updated')
    def updated(self):
        return self._property_updated

    @updated.setter
    def updated(self, value):
        if value is None:
            self._property_updated = None
            return

        self.assert_isinstance(value, "updated", six.integer_types + (float,))
        self._property_updated = value

    @schema_property('fields')
    def fields(self):
        return self._property_fields

    @fields.setter
    def fields(self, value):
        if value is None:
            self._property_fields = None
            return

        self.assert_isinstance(value, "fields", (dict,))
        self._property_fields = value


class GetCredentialsRequest(Request):
    """
    Returns all existing credential keys for the authenticated user.
    Note: Only credential keys are returned.

    """

    _service = "auth"
    _action = "get_credentials"
    _version = "2.5"
    _schema = {
        'additionalProperties': False,
        'definitions': {},
        'properties': {},
        'type': 'object',
    }


class GetCredentialsResponse(Response):
    """
    Response of auth.get_credentials endpoint.

    :param credentials: List of credentials, each with an empty secret field.
    :type credentials: Sequence[CredentialKey]
    """
    _service = "auth"
    _action = "get_credentials"
    _version = "2.5"

    _schema = {
        'definitions': {
            'credential_key': {
                'properties': {
                    'access_key': {'description': '', 'type': ['string', 'null']},
                    'last_used': {
                        'description': '',
                        'format': 'date-time',
                        'type': ['string', 'null'],
                    },
                    'last_used_from': {
                        'description': '',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'credentials': {
                'description': 'List of credentials, each with an empty secret field.',
                'items': {'$ref': '#/definitions/credential_key'},
                'type': ['array', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, credentials=None, **kwargs):
        super(GetCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        if value is None:
            self._property_credentials = None
            return

        self.assert_isinstance(value, "credentials", (list, tuple))
        if any(isinstance(v, dict) for v in value):
            value = [CredentialKey.from_dict(v) if isinstance(v, dict) else v for v in value]
        else:
            self.assert_isinstance(value, "credentials", CredentialKey, is_array=True)
        self._property_credentials = value


class LoginRequest(Request):
    """
    Get a token based on supplied credentials (key/secret).
    Intended for use by users with key/secret credentials that wish to obtain a token
    for use with other services. Token will be limited by the same permissions that
    exist for the credentials used in this call.

    :param expiration_sec: Requested token expiration time in seconds. Not
        guaranteed, might be overridden by the service
    :type expiration_sec: int
    """

    _service = "auth"
    _action = "login"
    _version = "2.5"
    _schema = {
        'definitions': {},
        'properties': {
            'expiration_sec': {
                'description': 'Requested token expiration time in seconds. \n        Not guaranteed, might be overridden by the service',
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }
    def __init__(
            self, expiration_sec=None, **kwargs):
        super(LoginRequest, self).__init__(**kwargs)
        self.expiration_sec = expiration_sec

    @schema_property('expiration_sec')
    def expiration_sec(self):
        return self._property_expiration_sec

    @expiration_sec.setter
    def expiration_sec(self, value):
        if value is None:
            self._property_expiration_sec = None
            return
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "expiration_sec", six.integer_types)
        self._property_expiration_sec = value


class LoginResponse(Response):
    """
    Response of auth.login endpoint.

    :param token: Token string
    :type token: str
    """
    _service = "auth"
    _action = "login"
    _version = "2.5"

    _schema = {
        'definitions': {},
        'properties': {
            'token': {'description': 'Token string', 'type': ['string', 'null']},
        },
        'type': 'object',
    }
    def __init__(
            self, token=None, **kwargs):
        super(LoginResponse, self).__init__(**kwargs)
        self.token = token

    @schema_property('token')
    def token(self):
        return self._property_token

    @token.setter
    def token(self, value):
        if value is None:
            self._property_token = None
            return

        self.assert_isinstance(value, "token", six.string_types)
        self._property_token = value


class LogoutRequest(Request):
|
||||
"""
|
||||
Removes the authentication cookie from the current session
|
||||
|
||||
"""
|
||||
|
||||
_service = "auth"
|
||||
_action = "logout"
|
||||
_version = "2.5"
|
||||
_schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class LogoutResponse(Response):
|
||||
"""
|
||||
Response of auth.logout endpoint.
|
||||
|
||||
"""
|
||||
_service = "auth"
|
||||
_action = "logout"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {'additionalProperties': False, 'definitions': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class RevokeCredentialsRequest(Request):
|
||||
"""
|
||||
Revokes (and deletes) a set (key, secret) of credentials for
|
||||
the authenticated user.
|
||||
|
||||
:param access_key: Credentials key
|
||||
:type access_key: str
|
||||
"""
|
||||
|
||||
_service = "auth"
|
||||
_action = "revoke_credentials"
|
||||
_version = "2.5"
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'access_key': {
|
||||
'description': 'Credentials key',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'required': ['key_id'],
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, access_key=None, **kwargs):
|
||||
super(RevokeCredentialsRequest, self).__init__(**kwargs)
|
||||
self.access_key = access_key
|
||||
|
||||
@schema_property('access_key')
|
||||
def access_key(self):
|
||||
return self._property_access_key
|
||||
|
||||
@access_key.setter
|
||||
def access_key(self, value):
|
||||
if value is None:
|
||||
self._property_access_key = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "access_key", six.string_types)
|
||||
self._property_access_key = value
|
||||
|
||||
|
||||
class RevokeCredentialsResponse(Response):
|
||||
"""
|
||||
Response of auth.revoke_credentials endpoint.
|
||||
|
||||
:param revoked: Number of credentials revoked
|
||||
:type revoked: int
|
||||
"""
|
||||
_service = "auth"
|
||||
_action = "revoke_credentials"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'revoked': {
|
||||
'description': 'Number of credentials revoked',
|
||||
'enum': [0, 1],
|
||||
'type': ['integer', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, revoked=None, **kwargs):
|
||||
super(RevokeCredentialsResponse, self).__init__(**kwargs)
|
||||
self.revoked = revoked
|
||||
|
||||
@schema_property('revoked')
|
||||
def revoked(self):
|
||||
return self._property_revoked
|
||||
|
||||
@revoked.setter
|
||||
def revoked(self, value):
|
||||
if value is None:
|
||||
self._property_revoked = None
|
||||
return
|
||||
if isinstance(value, float) and value.is_integer():
|
||||
value = int(value)
|
||||
|
||||
self.assert_isinstance(value, "revoked", six.integer_types)
|
||||
self._property_revoked = value
|
||||
|
||||
|
||||
|
||||
|
||||
response_mapping = {
|
||||
LoginRequest: LoginResponse,
|
||||
LogoutRequest: LogoutResponse,
|
||||
CreateCredentialsRequest: CreateCredentialsResponse,
|
||||
GetCredentialsRequest: GetCredentialsResponse,
|
||||
RevokeCredentialsRequest: RevokeCredentialsResponse,
|
||||
EditUserRequest: EditUserResponse,
|
||||
}
|
||||
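For orientation, a request class, its response class, and response_mapping together drive one API round trip. The following is a minimal usage sketch, not code from this repository: the Session import path and a send() helper returning the mapped response model are assumptions.

# Hypothetical usage sketch - Session setup and send() are assumed, not shown in this diff
from trains_agent.backend_api.session import Session
from trains_agent.backend_api.services.v2_5 import auth

session = Session()  # credentials come from configuration / environment
# ask for a one-hour token; the schema notes the server may override it
request = auth.LoginRequest(expiration_sec=3600)
response = session.send(request)  # response_mapping pairs this with LoginResponse
print(response.token)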
194  trains_agent/backend_api/services/v2_5/debug.py  Normal file
@@ -0,0 +1,194 @@
"""
|
||||
debug service
|
||||
|
||||
Debugging utilities
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class ApiexRequest(Request):
|
||||
"""
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "apiex"
|
||||
_version = "2.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
|
||||
|
||||
|
||||
class ApiexResponse(Response):
|
||||
"""
|
||||
Response of debug.apiex endpoint.
|
||||
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "apiex"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class EchoRequest(Request):
|
||||
"""
|
||||
Return request data
|
||||
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "echo"
|
||||
_version = "2.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class EchoResponse(Response):
|
||||
"""
|
||||
Response of debug.echo endpoint.
|
||||
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "echo"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class ExRequest(Request):
|
||||
"""
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "ex"
|
||||
_version = "2.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
|
||||
|
||||
|
||||
class ExResponse(Response):
|
||||
"""
|
||||
Response of debug.ex endpoint.
|
||||
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "ex"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingRequest(Request):
|
||||
"""
|
||||
Return a message. Does not require authorization.
|
||||
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "ping"
|
||||
_version = "2.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingResponse(Response):
|
||||
"""
|
||||
Response of debug.ping endpoint.
|
||||
|
||||
:param msg: A friendly message
|
||||
:type msg: str
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "ping"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'msg': {
|
||||
'description': 'A friendly message',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, msg=None, **kwargs):
|
||||
super(PingResponse, self).__init__(**kwargs)
|
||||
self.msg = msg
|
||||
|
||||
@schema_property('msg')
|
||||
def msg(self):
|
||||
return self._property_msg
|
||||
|
||||
@msg.setter
|
||||
def msg(self, value):
|
||||
if value is None:
|
||||
self._property_msg = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "msg", six.string_types)
|
||||
self._property_msg = value
|
||||
|
||||
|
||||
class PingAuthRequest(Request):
|
||||
"""
|
||||
Return a message. Requires authorization.
|
||||
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "ping_auth"
|
||||
_version = "2.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingAuthResponse(Response):
|
||||
"""
|
||||
Response of debug.ping_auth endpoint.
|
||||
|
||||
:param msg: A friendly message
|
||||
:type msg: str
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "ping_auth"
|
||||
_version = "2.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'msg': {
|
||||
'description': 'A friendly message',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, msg=None, **kwargs):
|
||||
super(PingAuthResponse, self).__init__(**kwargs)
|
||||
self.msg = msg
|
||||
|
||||
@schema_property('msg')
|
||||
def msg(self):
|
||||
return self._property_msg
|
||||
|
||||
@msg.setter
|
||||
def msg(self, value):
|
||||
if value is None:
|
||||
self._property_msg = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "msg", six.string_types)
|
||||
self._property_msg = value
|
||||
|
||||
|
||||
response_mapping = {
|
||||
EchoRequest: EchoResponse,
|
||||
PingRequest: PingResponse,
|
||||
PingAuthRequest: PingAuthResponse,
|
||||
ApiexRequest: ApiexResponse,
|
||||
ExRequest: ExResponse,
|
||||
}
|
||||
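PingRequest takes no parameters and PingResponse carries a single msg field, so the pair works as a quick connectivity check. A sketch under the same assumptions as the auth example above:

# Hypothetical connectivity check - reuses the assumed `session` from the auth sketch
from trains_agent.backend_api.services.v2_5 import debug

response = session.send(debug.PingRequest())  # no authorization required, per the docstring
print(response.msg)  # the server's friendly message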
3000  trains_agent/backend_api/services/v2_5/events.py  Normal file (file diff suppressed because it is too large)
2850  trains_agent/backend_api/services/v2_5/models.py  Normal file (file diff suppressed because it is too large)
2198  trains_agent/backend_api/services/v2_5/queues.py  Normal file (file diff suppressed because it is too large)
7053  trains_agent/backend_api/services/v2_5/tasks.py  Normal file (file diff suppressed because it is too large)
2368  trains_agent/backend_api/services/v2_5/workers.py  Normal file (file diff suppressed because it is too large)
@@ -139,13 +139,25 @@ class Response(object):
        :param dest: if all of a response's data is contained in one field, use that field
        :type dest: Text
        """
        self.response = None
        self._result = result
        response = getattr(result, "response", result)
        if getattr(response, "_service") == "events" and \
                getattr(response, "_action") in ("scalar_metrics_iter_histogram",
                                                 "multi_task_scalar_metrics_iter_histogram",
                                                 "vector_metrics_iter_histogram",
                                                 ):
            # put all the response data under metrics:
            response.metrics = result.response_data
            if 'metrics' not in response.__class__._get_data_props():
                response.__class__._data_props_list['metrics'] = 'metrics'
        if dest:
            response = getattr(response, dest)
        self.response = response

    def __getattr__(self, attr):
        if self.response is None:
            return None
        return getattr(self.response, attr)

    @property
@@ -212,8 +224,8 @@ class TableResponse(Response):
        fields = fields or self.fields
        from trains_agent.helper.base import create_table
        return create_table(
-            (tuple(getter(item, attr) for attr in fields) for item in self),
-            titles=fields, headers=True,
+            (dict((attr, getter(item, attr)) for attr in fields) for item in self),
+            titles=fields, columns=fields, headers=True,
        )

    def display(self, fields=None):
@@ -493,6 +505,7 @@ class APIClient(object):
    queues = None # type: Any
    tasks = None # type: Any
    workers = None # type: Any
+    events = None # type: Any

    def __init__(self, session=None, api_version=None):
        self.session = session or StrictSession()

@@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "TRAINS_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", type=bool, default=True)
+ENV_CONDA_ENV_PACKAGE = EnvEntry("TRAINS_CONDA_ENV_PACKAGE", "TRAINS_CONDA_ENV_PACKAGE")
9  trains_agent/backend_api/session/jsonmodels/__init__.py  Normal file
@@ -0,0 +1,9 @@
# coding: utf-8

__author__ = 'Szczepan Cieślik'
__email__ = 'szczepan.cieslik@gmail.com'
__version__ = '2.4'

from . import models
from . import fields
from . import errors
230  trains_agent/backend_api/session/jsonmodels/builders.py  Normal file
@@ -0,0 +1,230 @@
"""Builders to generate in memory representation of model and fields tree."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import six
|
||||
|
||||
from . import errors
|
||||
from .fields import NotSet
|
||||
|
||||
|
||||
class Builder(object):
|
||||
|
||||
def __init__(self, parent=None, nullable=False, default=NotSet):
|
||||
self.parent = parent
|
||||
self.types_builders = {}
|
||||
self.types_count = defaultdict(int)
|
||||
self.definitions = set()
|
||||
self.nullable = nullable
|
||||
self.default = default
|
||||
|
||||
@property
|
||||
def has_default(self):
|
||||
return self.default is not NotSet
|
||||
|
||||
def register_type(self, type, builder):
|
||||
if self.parent:
|
||||
return self.parent.register_type(type, builder)
|
||||
|
||||
self.types_count[type] += 1
|
||||
if type not in self.types_builders:
|
||||
self.types_builders[type] = builder
|
||||
|
||||
def get_builder(self, type):
|
||||
if self.parent:
|
||||
return self.parent.get_builder(type)
|
||||
|
||||
return self.types_builders[type]
|
||||
|
||||
def count_type(self, type):
|
||||
if self.parent:
|
||||
return self.parent.count_type(type)
|
||||
|
||||
return self.types_count[type]
|
||||
|
||||
@staticmethod
|
||||
def maybe_build(value):
|
||||
return value.build() if isinstance(value, Builder) else value
|
||||
|
||||
def add_definition(self, builder):
|
||||
if self.parent:
|
||||
return self.parent.add_definition(builder)
|
||||
|
||||
self.definitions.add(builder)
|
||||
|
||||
|
||||
class ObjectBuilder(Builder):
|
||||
|
||||
def __init__(self, model_type, *args, **kwargs):
|
||||
super(ObjectBuilder, self).__init__(*args, **kwargs)
|
||||
self.properties = {}
|
||||
self.required = []
|
||||
self.type = model_type
|
||||
|
||||
self.register_type(self.type, self)
|
||||
|
||||
def add_field(self, name, field, schema):
|
||||
_apply_validators_modifications(schema, field)
|
||||
self.properties[name] = schema
|
||||
if field.required:
|
||||
self.required.append(name)
|
||||
|
||||
def build(self):
|
||||
builder = self.get_builder(self.type)
|
||||
if self.is_definition and not self.is_root:
|
||||
self.add_definition(builder)
|
||||
[self.maybe_build(value) for _, value in self.properties.items()]
|
||||
return '#/definitions/{name}'.format(name=self.type_name)
|
||||
else:
|
||||
return builder.build_definition(nullable=self.nullable)
|
||||
|
||||
@property
|
||||
def type_name(self):
|
||||
module_name = '{module}.{name}'.format(
|
||||
module=self.type.__module__,
|
||||
name=self.type.__name__,
|
||||
)
|
||||
return module_name.replace('.', '_').lower()
|
||||
|
||||
def build_definition(self, add_defintitions=True, nullable=False):
|
||||
properties = dict(
|
||||
(name, self.maybe_build(value))
|
||||
for name, value
|
||||
in self.properties.items()
|
||||
)
|
||||
schema = {
|
||||
'type': 'object',
|
||||
'additionalProperties': False,
|
||||
'properties': properties,
|
||||
}
|
||||
if self.required:
|
||||
schema['required'] = self.required
|
||||
if self.definitions and add_defintitions:
|
||||
schema['definitions'] = dict(
|
||||
(builder.type_name, builder.build_definition(False, False))
|
||||
for builder in self.definitions
|
||||
)
|
||||
return schema
|
||||
|
||||
@property
|
||||
def is_definition(self):
|
||||
if self.count_type(self.type) > 1:
|
||||
return True
|
||||
elif self.parent:
|
||||
return self.parent.is_definition
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_root(self):
|
||||
return not bool(self.parent)
|
||||
|
||||
|
||||
def _apply_validators_modifications(field_schema, field):
|
||||
for validator in field.validators:
|
||||
try:
|
||||
validator.modify_schema(field_schema)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
class PrimitiveBuilder(Builder):
|
||||
|
||||
def __init__(self, type, *args, **kwargs):
|
||||
super(PrimitiveBuilder, self).__init__(*args, **kwargs)
|
||||
self.type = type
|
||||
|
||||
def build(self):
|
||||
schema = {}
|
||||
if issubclass(self.type, six.string_types):
|
||||
obj_type = 'string'
|
||||
elif issubclass(self.type, bool):
|
||||
obj_type = 'boolean'
|
||||
elif issubclass(self.type, int):
|
||||
obj_type = 'number'
|
||||
elif issubclass(self.type, float):
|
||||
obj_type = 'number'
|
||||
else:
|
||||
raise errors.FieldNotSupported(
|
||||
"Can't specify value schema!", self.type
|
||||
)
|
||||
|
||||
if self.nullable:
|
||||
obj_type = [obj_type, 'null']
|
||||
schema['type'] = obj_type
|
||||
|
||||
if self.has_default:
|
||||
schema["default"] = self.default
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
class ListBuilder(Builder):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ListBuilder, self).__init__(*args, **kwargs)
|
||||
self.schemas = []
|
||||
|
||||
def add_type_schema(self, schema):
|
||||
self.schemas.append(schema)
|
||||
|
||||
def build(self):
|
||||
schema = {'type': 'array'}
|
||||
if self.nullable:
|
||||
self.add_type_schema({'type': 'null'})
|
||||
|
||||
if self.has_default:
|
||||
schema["default"] = [self.to_struct(i) for i in self.default]
|
||||
|
||||
schemas = [self.maybe_build(s) for s in self.schemas]
|
||||
if len(schemas) == 1:
|
||||
items = schemas[0]
|
||||
else:
|
||||
items = {'oneOf': schemas}
|
||||
|
||||
schema['items'] = items
|
||||
return schema
|
||||
|
||||
@property
|
||||
def is_definition(self):
|
||||
return self.parent.is_definition
|
||||
|
||||
@staticmethod
|
||||
def to_struct(item):
|
||||
from .models import Base
|
||||
if isinstance(item, Base):
|
||||
return item.to_struct()
|
||||
return item
|
||||
|
||||
|
||||
class EmbeddedBuilder(Builder):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EmbeddedBuilder, self).__init__(*args, **kwargs)
|
||||
self.schemas = []
|
||||
|
||||
def add_type_schema(self, schema):
|
||||
self.schemas.append(schema)
|
||||
|
||||
def build(self):
|
||||
if self.nullable:
|
||||
self.add_type_schema({'type': 'null'})
|
||||
|
||||
schemas = [self.maybe_build(schema) for schema in self.schemas]
|
||||
if len(schemas) == 1:
|
||||
schema = schemas[0]
|
||||
else:
|
||||
schema = {'oneOf': schemas}
|
||||
|
||||
if self.has_default:
|
||||
# The default value of EmbeddedField is expected to be an instance
|
||||
# of a subclass of models.Base, thus have `to_struct`
|
||||
schema["default"] = self.default.to_struct()
|
||||
|
||||
return schema
|
||||
|
||||
@property
|
||||
def is_definition(self):
|
||||
return self.parent.is_definition
|
||||
21  trains_agent/backend_api/session/jsonmodels/collections.py  Normal file
@@ -0,0 +1,21 @@
class ModelCollection(list):

    """`ModelCollection` is list which validates stored values.

    Validation is made with use of field passed to `__init__` at each point,
    when new value is assigned.

    """

    def __init__(self, field):
        self.field = field

    def append(self, value):
        self.field.validate_single_value(value)
        super(ModelCollection, self).append(value)

    def __setitem__(self, key, value):
        self.field.validate_single_value(value)
        super(ModelCollection, self).__setitem__(key, value)
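ModelCollection delegates per-item validation to the field that created it, so appending a wrong-typed value fails immediately. A small sketch using ListField from fields.py below:

# Sketch: ModelCollection rejects items its owning field does not accept
from trains_agent.backend_api.session.jsonmodels.fields import ListField
from trains_agent.backend_api.session.jsonmodels.errors import ValidationError

names = ListField(items_types=[str]).get_default_value()  # a ModelCollection bound to the field
names.append("alice")  # passes validate_single_value
try:
    names.append(42)  # wrong item type
except ValidationError as e:
    print(e)  # All items must be instances of "str", and not "int".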
15  trains_agent/backend_api/session/jsonmodels/errors.py  Normal file
@@ -0,0 +1,15 @@
class ValidationError(RuntimeError):

    pass


class FieldNotFound(RuntimeError):

    pass


class FieldNotSupported(ValueError):

    pass
488  trains_agent/backend_api/session/jsonmodels/fields.py  Normal file
@@ -0,0 +1,488 @@
import datetime
import re
from weakref import WeakKeyDictionary

import six
from dateutil.parser import parse

from .errors import ValidationError
from .collections import ModelCollection


# unique marker for "no default value specified". None is not good enough since
# it is a completely valid default value.
NotSet = object()


class BaseField(object):

    """Base class for all fields."""

    types = None

    def __init__(
            self,
            required=False,
            nullable=False,
            help_text=None,
            validators=None,
            default=NotSet,
            name=None):
        self.memory = WeakKeyDictionary()
        self.required = required
        self.help_text = help_text
        self.nullable = nullable
        self._assign_validators(validators)
        self.name = name
        self._validate_name()
        if default is not NotSet:
            self.validate(default)
        self._default = default

    @property
    def has_default(self):
        return self._default is not NotSet

    def _assign_validators(self, validators):
        if validators and not isinstance(validators, list):
            validators = [validators]
        self.validators = validators or []

    def __set__(self, instance, value):
        self._finish_initialization(type(instance))
        value = self.parse_value(value)
        self.validate(value)
        self.memory[instance._cache_key] = value

    def __get__(self, instance, owner=None):
        if instance is None:
            self._finish_initialization(owner)
            return self

        self._finish_initialization(type(instance))

        self._check_value(instance)
        return self.memory[instance._cache_key]

    def _finish_initialization(self, owner):
        pass

    def _check_value(self, obj):
        if obj._cache_key not in self.memory:
            self.__set__(obj, self.get_default_value())

    def validate_for_object(self, obj):
        value = self.__get__(obj)
        self.validate(value)

    def validate(self, value):
        self._check_types()
        self._validate_against_types(value)
        self._check_against_required(value)
        self._validate_with_custom_validators(value)

    def _check_against_required(self, value):
        if value is None and self.required:
            raise ValidationError('Field is required!')

    def _validate_against_types(self, value):
        if value is not None and not isinstance(value, self.types):
            raise ValidationError(
                'Value is wrong, expected type "{types}"'.format(
                    types=', '.join([t.__name__ for t in self.types])
                ),
                value,
            )

    def _check_types(self):
        if self.types is None:
            raise ValidationError(
                'Field "{type}" is not usable, try '
                'different field type.'.format(type=type(self).__name__))

    def to_struct(self, value):
        """Cast value to Python structure."""
        return value

    def parse_value(self, value):
        """Parse value from primitive to desired format.

        Each field can parse value to form it wants it to be (like string or
        int).

        """
        return value

    def _validate_with_custom_validators(self, value):
        if value is None and self.nullable:
            return

        for validator in self.validators:
            try:
                validator.validate(value)
            except AttributeError:
                validator(value)

    def get_default_value(self):
        """Get default value for field.

        Each field can specify its default.

        """
        return self._default if self.has_default else None

    def _validate_name(self):
        if self.name is None:
            return
        if not re.match('^[A-Za-z_](([\w\-]*)?\w+)?$', self.name):
            raise ValueError('Wrong name', self.name)

    def structue_name(self, default):
        return self.name if self.name is not None else default


class StringField(BaseField):

    """String field."""

    types = six.string_types


class IntField(BaseField):

    """Integer field."""

    types = (int,)

    def parse_value(self, value):
        """Cast value to `int`, e.g. from string or long"""
        parsed = super(IntField, self).parse_value(value)
        if parsed is None:
            return parsed
        return int(parsed)


class FloatField(BaseField):

    """Float field."""

    types = (float, int)


class BoolField(BaseField):

    """Bool field."""

    types = (bool,)

    def parse_value(self, value):
        """Cast value to `bool`."""
        parsed = super(BoolField, self).parse_value(value)
        return bool(parsed) if parsed is not None else None


class ListField(BaseField):

    """List field."""

    types = (list,)

    def __init__(self, items_types=None, *args, **kwargs):
        """Init.

        `ListField` is **always not required**. If you want to control number
        of items use validators.

        """
        self._assign_types(items_types)
        super(ListField, self).__init__(*args, **kwargs)
        self.required = False

    def get_default_value(self):
        default = super(ListField, self).get_default_value()
        if default is None:
            return ModelCollection(self)
        return default

    def _assign_types(self, items_types):
        if items_types:
            try:
                self.items_types = tuple(items_types)
            except TypeError:
                self.items_types = items_types,
        else:
            self.items_types = tuple()

        types = []
        for type_ in self.items_types:
            if isinstance(type_, six.string_types):
                types.append(_LazyType(type_))
            else:
                types.append(type_)
        self.items_types = tuple(types)

    def validate(self, value):
        super(ListField, self).validate(value)

        if len(self.items_types) == 0:
            return

        for item in value:
            self.validate_single_value(item)

    def validate_single_value(self, item):
        if len(self.items_types) == 0:
            return

        if not isinstance(item, self.items_types):
            raise ValidationError(
                'All items must be instances '
                'of "{types}", and not "{type}".'.format(
                    types=', '.join([t.__name__ for t in self.items_types]),
                    type=type(item).__name__,
                ))

    def parse_value(self, values):
        """Cast value to proper collection."""
        result = self.get_default_value()

        if not values:
            return result

        if not isinstance(values, list):
            return values

        return [self._cast_value(value) for value in values]

    def _cast_value(self, value):
        if isinstance(value, self.items_types):
            return value
        else:
            if len(self.items_types) != 1:
                tpl = 'Cannot decide which type to choose from "{types}".'
                raise ValidationError(
                    tpl.format(
                        types=', '.join([t.__name__ for t in self.items_types])
                    )
                )
            return self.items_types[0](**value)

    def _finish_initialization(self, owner):
        super(ListField, self)._finish_initialization(owner)

        types = []
        for type in self.items_types:
            if isinstance(type, _LazyType):
                types.append(type.evaluate(owner))
            else:
                types.append(type)
        self.items_types = tuple(types)

    def _elem_to_struct(self, value):
        try:
            return value.to_struct()
        except AttributeError:
            return value

    def to_struct(self, values):
        return [self._elem_to_struct(v) for v in values]


class EmbeddedField(BaseField):

    """Field for embedded models."""

    def __init__(self, model_types, *args, **kwargs):
        self._assign_model_types(model_types)
        super(EmbeddedField, self).__init__(*args, **kwargs)

    def _assign_model_types(self, model_types):
        if not isinstance(model_types, (list, tuple)):
            model_types = (model_types,)

        types = []
        for type_ in model_types:
            if isinstance(type_, six.string_types):
                types.append(_LazyType(type_))
            else:
                types.append(type_)
        self.types = tuple(types)

    def _finish_initialization(self, owner):
        super(EmbeddedField, self)._finish_initialization(owner)

        types = []
        for type in self.types:
            if isinstance(type, _LazyType):
                types.append(type.evaluate(owner))
            else:
                types.append(type)
        self.types = tuple(types)

    def validate(self, value):
        super(EmbeddedField, self).validate(value)
        try:
            value.validate()
        except AttributeError:
            pass

    def parse_value(self, value):
        """Parse value to proper model type."""
        if not isinstance(value, dict):
            return value

        embed_type = self._get_embed_type()
        return embed_type(**value)

    def _get_embed_type(self):
        if len(self.types) != 1:
            raise ValidationError(
                'Cannot decide which type to choose from "{types}".'.format(
                    types=', '.join([t.__name__ for t in self.types])
                )
            )
        return self.types[0]

    def to_struct(self, value):
        return value.to_struct()


class _LazyType(object):

    def __init__(self, path):
        self.path = path

    def evaluate(self, base_cls):
        module, type_name = _evaluate_path(self.path, base_cls)
        return _import(module, type_name)


def _evaluate_path(relative_path, base_cls):
    base_module = base_cls.__module__

    modules = _get_modules(relative_path, base_module)

    type_name = modules.pop()
    module = '.'.join(modules)
    if not module:
        module = base_module
    return module, type_name


def _get_modules(relative_path, base_module):
    canonical_path = relative_path.lstrip('.')
    canonical_modules = canonical_path.split('.')

    if not relative_path.startswith('.'):
        return canonical_modules

    parents_amount = len(relative_path) - len(canonical_path)
    parent_modules = base_module.split('.')
    parents_amount = max(0, parents_amount - 1)
    if parents_amount > len(parent_modules):
        raise ValueError("Can't evaluate path '{}'".format(relative_path))
    return parent_modules[:parents_amount * -1] + canonical_modules


def _import(module_name, type_name):
    module = __import__(module_name, fromlist=[type_name])
    try:
        return getattr(module, type_name)
    except AttributeError:
        raise ValueError(
            "Can't find type '{}.{}'.".format(module_name, type_name))


class TimeField(StringField):

    """Time field."""

    types = (datetime.time,)

    def __init__(self, str_format=None, *args, **kwargs):
        """Init.

        :param str str_format: Format to cast time to (if `None` - casting to
            ISO 8601 format).

        """
        self.str_format = str_format
        super(TimeField, self).__init__(*args, **kwargs)

    def to_struct(self, value):
        """Cast `time` object to string."""
        if self.str_format:
            return value.strftime(self.str_format)
        return value.isoformat()

    def parse_value(self, value):
        """Parse string into instance of `time`."""
        if value is None:
            return value
        if isinstance(value, datetime.time):
            return value
        return parse(value).timetz()


class DateField(StringField):

    """Date field."""

    types = (datetime.date,)
    default_format = '%Y-%m-%d'

    def __init__(self, str_format=None, *args, **kwargs):
        """Init.

        :param str str_format: Format to cast date to (if `None` - casting to
            %Y-%m-%d format).

        """
        self.str_format = str_format
        super(DateField, self).__init__(*args, **kwargs)

    def to_struct(self, value):
        """Cast `date` object to string."""
        if self.str_format:
            return value.strftime(self.str_format)
        return value.strftime(self.default_format)

    def parse_value(self, value):
        """Parse string into instance of `date`."""
        if value is None:
            return value
        if isinstance(value, datetime.date):
            return value
        return parse(value).date()


class DateTimeField(StringField):

    """Datetime field."""

    types = (datetime.datetime,)

    def __init__(self, str_format=None, *args, **kwargs):
        """Init.

        :param str str_format: Format to cast datetime to (if `None` - casting
            to ISO 8601 format).

        """
        self.str_format = str_format
        super(DateTimeField, self).__init__(*args, **kwargs)

    def to_struct(self, value):
        """Cast `datetime` object to string."""
        if self.str_format:
            return value.strftime(self.str_format)
        return value.isoformat()

    def parse_value(self, value):
        """Parse string into instance of `datetime`."""
        if isinstance(value, datetime.datetime):
            return value
        if value:
            return parse(value)
        else:
            return None
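parse_value and to_struct are symmetric: primitives are normalized on the way in and cast back to primitives on the way out. A short sketch:

# Sketch: primitives in, primitives out
from trains_agent.backend_api.session.jsonmodels.fields import DateTimeField, IntField

dt_field = DateTimeField()
dt = dt_field.parse_value('2019-06-26 18:16:15')  # dateutil parses the string
print(dt_field.to_struct(dt))  # '2019-06-26T18:16:15' (ISO 8601, no str_format given)

int_field = IntField()
print(int_field.parse_value('42'))  # 42, cast from string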
154  trains_agent/backend_api/session/jsonmodels/models.py  Normal file
@@ -0,0 +1,154 @@
import six

from . import parsers, errors
from .fields import BaseField
from .errors import ValidationError


class JsonmodelMeta(type):

    def __new__(cls, name, bases, attributes):
        cls.validate_fields(attributes)
        return super(cls, cls).__new__(cls, name, bases, attributes)

    @staticmethod
    def validate_fields(attributes):
        fields = {
            key: value for key, value in attributes.items()
            if isinstance(value, BaseField)
        }
        taken_names = set()
        for name, field in fields.items():
            structue_name = field.structue_name(name)
            if structue_name in taken_names:
                raise ValueError('Name taken', structue_name, name)
            taken_names.add(structue_name)


class Base(six.with_metaclass(JsonmodelMeta, object)):

    """Base class for all models."""

    def __init__(self, **kwargs):
        self._cache_key = _CacheKey()
        self.populate(**kwargs)

    def populate(self, **values):
        """Populate values to fields. Skip non-existing."""
        values = values.copy()
        fields = list(self.iterate_with_name())
        for _, structure_name, field in fields:
            if structure_name in values:
                field.__set__(self, values.pop(structure_name))
        for name, _, field in fields:
            if name in values:
                field.__set__(self, values.pop(name))

    def get_field(self, field_name):
        """Get field associated with given attribute."""
        for attr_name, field in self:
            if field_name == attr_name:
                return field

        raise errors.FieldNotFound('Field not found', field_name)

    def __iter__(self):
        """Iterate through fields and values."""
        for name, field in self.iterate_over_fields():
            yield name, field

    def validate(self):
        """Explicitly validate all the fields."""
        for name, field in self:
            try:
                field.validate_for_object(self)
            except ValidationError as error:
                raise ValidationError(
                    "Error for field '{name}'.".format(name=name),
                    error,
                )

    @classmethod
    def iterate_over_fields(cls):
        """Iterate through fields as `(attribute_name, field_instance)`."""
        for attr in dir(cls):
            clsattr = getattr(cls, attr)
            if isinstance(clsattr, BaseField):
                yield attr, clsattr

    @classmethod
    def iterate_with_name(cls):
        """Iterate over fields, but also give `structure_name`.

        Format is `(attribute_name, structue_name, field_instance)`.
        Structure name is name under which value is seen in structure and
        schema (in primitives) and only there.
        """
        for attr_name, field in cls.iterate_over_fields():
            structure_name = field.structue_name(attr_name)
            yield attr_name, structure_name, field

    def to_struct(self):
        """Cast model to Python structure."""
        return parsers.to_struct(self)

    @classmethod
    def to_json_schema(cls):
        """Generate JSON schema for model."""
        return parsers.to_json_schema(cls)

    def __repr__(self):
        attrs = {}
        for name, _ in self:
            try:
                attr = getattr(self, name)
                if attr is not None:
                    attrs[name] = repr(attr)
            except ValidationError:
                pass

        return '{class_name}({fields})'.format(
            class_name=self.__class__.__name__,
            fields=', '.join(
                '{0[0]}={0[1]}'.format(x) for x in sorted(attrs.items())
            ),
        )

    def __str__(self):
        return '{name} object'.format(name=self.__class__.__name__)

    def __setattr__(self, name, value):
        try:
            return super(Base, self).__setattr__(name, value)
        except ValidationError as error:
            raise ValidationError(
                "Error for field '{name}'.".format(name=name),
                error
            )

    def __eq__(self, other):
        if type(other) is not type(self):
            return False

        for name, _ in self.iterate_over_fields():
            try:
                our = getattr(self, name)
            except errors.ValidationError:
                our = None

            try:
                their = getattr(other, name)
            except errors.ValidationError:
                their = None

            if our != their:
                return False

        return True

    def __ne__(self, other):
        return not (self == other)


class _CacheKey(object):
    """Object to identify model in memory."""
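The Base metaclass plus the descriptor fields give declarative models: declare fields as class attributes, populate through the constructor, serialize with to_struct(). An illustrative sketch (the Person model is hypothetical):

# Sketch: a declarative jsonmodels model
from trains_agent.backend_api.session.jsonmodels import models, fields

class Person(models.Base):
    name = fields.StringField(required=True)
    age = fields.IntField()

person = Person(name='Szczepan', age=30)
person.validate()  # raises ValidationError if a required field is missing
print(person.to_struct())  # {'name': 'Szczepan', 'age': 30}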
106  trains_agent/backend_api/session/jsonmodels/parsers.py  Normal file
@@ -0,0 +1,106 @@
"""Parsers to change model structure into different ones."""
|
||||
import inspect
|
||||
|
||||
from . import fields, builders, errors
|
||||
|
||||
|
||||
def to_struct(model):
|
||||
"""Cast instance of model to python structure.
|
||||
|
||||
:param model: Model to be casted.
|
||||
:rtype: ``dict``
|
||||
|
||||
"""
|
||||
model.validate()
|
||||
|
||||
resp = {}
|
||||
for _, name, field in model.iterate_with_name():
|
||||
value = field.__get__(model)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
value = field.to_struct(value)
|
||||
resp[name] = value
|
||||
return resp
|
||||
|
||||
|
||||
def to_json_schema(cls):
|
||||
"""Generate JSON schema for given class.
|
||||
|
||||
:param cls: Class to be casted.
|
||||
:rtype: ``dict``
|
||||
|
||||
"""
|
||||
builder = build_json_schema(cls)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def build_json_schema(value, parent_builder=None):
|
||||
from .models import Base
|
||||
|
||||
cls = value if inspect.isclass(value) else value.__class__
|
||||
if issubclass(cls, Base):
|
||||
return build_json_schema_object(cls, parent_builder)
|
||||
else:
|
||||
return build_json_schema_primitive(cls, parent_builder)
|
||||
|
||||
|
||||
def build_json_schema_object(cls, parent_builder=None):
|
||||
builder = builders.ObjectBuilder(cls, parent_builder)
|
||||
if builder.count_type(builder.type) > 1:
|
||||
return builder
|
||||
for _, name, field in cls.iterate_with_name():
|
||||
if isinstance(field, fields.EmbeddedField):
|
||||
builder.add_field(name, field, _parse_embedded(field, builder))
|
||||
elif isinstance(field, fields.ListField):
|
||||
builder.add_field(name, field, _parse_list(field, builder))
|
||||
else:
|
||||
builder.add_field(
|
||||
name, field, _create_primitive_field_schema(field))
|
||||
return builder
|
||||
|
||||
|
||||
def _parse_list(field, parent_builder):
|
||||
builder = builders.ListBuilder(
|
||||
parent_builder, field.nullable, default=field._default)
|
||||
for type in field.items_types:
|
||||
builder.add_type_schema(build_json_schema(type, builder))
|
||||
return builder
|
||||
|
||||
|
||||
def _parse_embedded(field, parent_builder):
|
||||
builder = builders.EmbeddedBuilder(
|
||||
parent_builder, field.nullable, default=field._default)
|
||||
for type in field.types:
|
||||
builder.add_type_schema(build_json_schema(type, builder))
|
||||
return builder
|
||||
|
||||
|
||||
def build_json_schema_primitive(cls, parent_builder):
|
||||
builder = builders.PrimitiveBuilder(cls, parent_builder)
|
||||
return builder
|
||||
|
||||
|
||||
def _create_primitive_field_schema(field):
|
||||
if isinstance(field, fields.StringField):
|
||||
obj_type = 'string'
|
||||
elif isinstance(field, fields.IntField):
|
||||
obj_type = 'number'
|
||||
elif isinstance(field, fields.FloatField):
|
||||
obj_type = 'float'
|
||||
elif isinstance(field, fields.BoolField):
|
||||
obj_type = 'boolean'
|
||||
else:
|
||||
raise errors.FieldNotSupported(
|
||||
'Field {field} is not supported!'.format(
|
||||
field=type(field).__class__.__name__))
|
||||
|
||||
if field.nullable:
|
||||
obj_type = [obj_type, 'null']
|
||||
|
||||
schema = {'type': obj_type}
|
||||
|
||||
if field.has_default:
|
||||
schema["default"] = field._default
|
||||
|
||||
return schema
|
||||
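to_json_schema() walks the field tree through the builders above and emits a JSON-schema dict. Continuing the hypothetical Person sketch:

# Sketch: schema generation for the Person model sketched earlier
schema = Person.to_json_schema()
# roughly:
# {'type': 'object', 'additionalProperties': False,
#  'properties': {'name': {'type': 'string'}, 'age': {'type': 'number'}},
#  'required': ['name']}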
156  trains_agent/backend_api/session/jsonmodels/utilities.py  Normal file
@@ -0,0 +1,156 @@
from __future__ import absolute_import

import six
import re
from collections import namedtuple

SCALAR_TYPES = tuple(list(six.string_types) + [int, float, bool])

ECMA_TO_PYTHON_FLAGS = {
    'i': re.I,
    'm': re.M,
}

PYTHON_TO_ECMA_FLAGS = dict(
    (value, key) for key, value in ECMA_TO_PYTHON_FLAGS.items()
)

PythonRegex = namedtuple('PythonRegex', ['regex', 'flags'])


def _normalize_string_type(value):
    if isinstance(value, six.string_types):
        return six.text_type(value)
    else:
        return value


def _compare_dicts(one, two):
    if len(one) != len(two):
        return False

    for key, value in one.items():
        if key not in one or key not in two:
            return False

        if not compare_schemas(one[key], two[key]):
            return False
    return True


def _compare_lists(one, two):
    if len(one) != len(two):
        return False

    they_match = False
    for first_item in one:
        for second_item in two:
            if they_match:
                continue
            they_match = compare_schemas(first_item, second_item)
    return they_match


def _assert_same_types(one, two):
    if not isinstance(one, type(two)) or not isinstance(two, type(one)):
        raise RuntimeError('Types mismatch! "{type1}" and "{type2}".'.format(
            type1=type(one).__name__, type2=type(two).__name__))


def compare_schemas(one, two):
    """Compare two structures that represent JSON schemas.

    For comparison you can't use normal comparison, because in JSON schema
    lists DO NOT keep order (and Python lists do), so this must be taken into
    account during comparison.

    Note this won't check all configurations, only the first one that seems to
    match, which can lead to wrong results.

    :param one: First schema to compare.
    :param two: Second schema to compare.
    :rtype: `bool`

    """
    one = _normalize_string_type(one)
    two = _normalize_string_type(two)

    _assert_same_types(one, two)

    if isinstance(one, list):
        return _compare_lists(one, two)
    elif isinstance(one, dict):
        return _compare_dicts(one, two)
    elif isinstance(one, SCALAR_TYPES):
        return one == two
    elif one is None:
        return one is two
    else:
        raise RuntimeError('Not allowed type "{type}"'.format(
            type=type(one).__name__))


def is_ecma_regex(regex):
    """Check if given regex is of type ECMA 262 or not.

    :rtype: bool

    """
    parts = regex.split('/')

    if len(parts) == 1:
        return False

    if len(parts) < 3:
        raise ValueError('Given regex isn\'t ECMA regex nor Python regex.')
    parts.pop()
    parts.append('')

    raw_regex = '/'.join(parts)
    if raw_regex.startswith('/') and raw_regex.endswith('/'):
        return True
    return False


def convert_ecma_regex_to_python(value):
    """Convert ECMA 262 regex to Python tuple with regex and flags.

    If given value is already Python regex it will be returned unchanged.

    :param string value: ECMA regex.
    :return: 2-tuple with `regex` and `flags`
    :rtype: namedtuple

    """
    if not is_ecma_regex(value):
        return PythonRegex(value, [])

    parts = value.split('/')
    flags = parts.pop()

    try:
        result_flags = [ECMA_TO_PYTHON_FLAGS[f] for f in flags]
    except KeyError:
        raise ValueError('Wrong flags "{}".'.format(flags))

    return PythonRegex('/'.join(parts[1:]), result_flags)


def convert_python_regex_to_ecma(value, flags=[]):
    """Convert Python regex to ECMA 262 regex.

    If given value is already ECMA regex it will be returned unchanged.

    :param string value: Python regex.
    :param list flags: List of flags (allowed flags: `re.I`, `re.M`)
    :return: ECMA 262 regex
    :rtype: str

    """
    if is_ecma_regex(value):
        return value

    result_flags = [PYTHON_TO_ECMA_FLAGS[f] for f in flags]
    result_flags = ''.join(result_flags)

    return '/{value}/{flags}'.format(value=value, flags=result_flags)
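The two conversion helpers translate between ECMA-262 /pattern/flags literals and Python (pattern, flags) pairs:

# Sketch: ECMA <-> Python regex round trip
import re
from trains_agent.backend_api.session.jsonmodels.utilities import (
    convert_ecma_regex_to_python, convert_python_regex_to_ecma)

print(convert_ecma_regex_to_python('/^foo/i'))  # PythonRegex(regex='^foo', flags=[re.IGNORECASE])
print(convert_python_regex_to_ecma('^foo', [re.I]))  # '/^foo/i'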
202  trains_agent/backend_api/session/jsonmodels/validators.py  Normal file
@@ -0,0 +1,202 @@
"""Predefined validators."""
|
||||
import re
|
||||
|
||||
from six.moves import reduce
|
||||
|
||||
from .errors import ValidationError
|
||||
from . import utilities
|
||||
|
||||
|
||||
class Min(object):
|
||||
|
||||
"""Validator for minimum value."""
|
||||
|
||||
def __init__(self, minimum_value, exclusive=False):
|
||||
"""Init.
|
||||
|
||||
:param minimum_value: Minimum value for validator.
|
||||
:param bool exclusive: If `True`, then validated value must be strongly
|
||||
lower than given threshold.
|
||||
|
||||
"""
|
||||
self.minimum_value = minimum_value
|
||||
self.exclusive = exclusive
|
||||
|
||||
def validate(self, value):
|
||||
"""Validate value."""
|
||||
if self.exclusive:
|
||||
if value <= self.minimum_value:
|
||||
tpl = "'{value}' is lower or equal than minimum ('{min}')."
|
||||
raise ValidationError(
|
||||
tpl.format(value=value, min=self.minimum_value))
|
||||
else:
|
||||
if value < self.minimum_value:
|
||||
raise ValidationError(
|
||||
"'{value}' is lower than minimum ('{min}').".format(
|
||||
value=value, min=self.minimum_value))
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
"""Modify field schema."""
|
||||
field_schema['minimum'] = self.minimum_value
|
||||
if self.exclusive:
|
||||
field_schema['exclusiveMinimum'] = True
|
||||
|
||||
|
||||
class Max(object):
|
||||
|
||||
"""Validator for maximum value."""
|
||||
|
||||
def __init__(self, maximum_value, exclusive=False):
|
||||
"""Init.
|
||||
|
||||
:param maximum_value: Maximum value for validator.
|
||||
:param bool exclusive: If `True`, then validated value must be strongly
|
||||
bigger than given threshold.
|
||||
|
||||
"""
|
||||
self.maximum_value = maximum_value
|
||||
self.exclusive = exclusive
|
||||
|
||||
def validate(self, value):
|
||||
"""Validate value."""
|
||||
if self.exclusive:
|
||||
if value >= self.maximum_value:
|
||||
tpl = "'{val}' is bigger or equal than maximum ('{max}')."
|
||||
raise ValidationError(
|
||||
tpl.format(val=value, max=self.maximum_value))
|
||||
else:
|
||||
if value > self.maximum_value:
|
||||
raise ValidationError(
|
||||
"'{value}' is bigger than maximum ('{max}').".format(
|
||||
value=value, max=self.maximum_value))
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
"""Modify field schema."""
|
||||
field_schema['maximum'] = self.maximum_value
|
||||
if self.exclusive:
|
||||
field_schema['exclusiveMaximum'] = True
|
||||
|
||||
|
||||
class Regex(object):
|
||||
|
||||
"""Validator for regular expressions."""
|
||||
|
||||
FLAGS = {
|
||||
'ignorecase': re.I,
|
||||
'multiline': re.M,
|
||||
}
|
||||
|
||||
def __init__(self, pattern, **flags):
|
||||
"""Init.
|
||||
|
||||
Note, that if given pattern is ECMA regex, given flags will be
|
||||
**completely ignored** and taken from given regex.
|
||||
|
||||
|
||||
:param string pattern: Pattern of regex.
|
||||
:param bool flags: Flags used for the regex matching.
|
||||
Allowed flag names are in the `FLAGS` attribute. The flag value
|
||||
does not matter as long as it evaluates to True.
|
||||
Flags with False values will be ignored.
|
||||
Invalid flags will be ignored.
|
||||
|
||||
"""
|
||||
if utilities.is_ecma_regex(pattern):
|
||||
result = utilities.convert_ecma_regex_to_python(pattern)
|
||||
self.pattern, self.flags = result
|
||||
else:
|
||||
self.pattern = pattern
|
||||
self.flags = [self.FLAGS[key] for key, value in flags.items()
|
||||
if key in self.FLAGS and value]
|
||||
|
||||
def validate(self, value):
|
||||
"""Validate value."""
|
||||
flags = self._calculate_flags()
|
||||
|
||||
try:
|
||||
result = re.search(self.pattern, value, flags)
|
||||
except TypeError as te:
|
||||
raise ValidationError(*te.args)
|
||||
|
||||
if not result:
|
||||
raise ValidationError(
|
||||
'Value "{value}" did not match pattern "{pattern}".'.format(
|
||||
value=value, pattern=self.pattern
|
||||
))
|
||||
|
||||
def _calculate_flags(self):
|
||||
return reduce(lambda x, y: x | y, self.flags, 0)
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
"""Modify field schema."""
|
||||
field_schema['pattern'] = utilities.convert_python_regex_to_ecma(
|
||||
self.pattern, self.flags)
|
||||
|
||||
|
||||
class Length(object):
|
||||
|
||||
"""Validator for length."""
|
||||
|
||||
def __init__(self, minimum_value=None, maximum_value=None):
|
||||
"""Init.
|
||||
|
||||
Note that if no `minimum_value` neither `maximum_value` will be
|
||||
specified, `ValueError` will be raised.
|
||||
|
||||
:param int minimum_value: Minimum value (optional).
|
||||
:param int maximum_value: Maximum value (optional).
|
||||
|
||||
"""
|
||||
if minimum_value is None and maximum_value is None:
|
||||
raise ValueError(
|
||||
"Either 'minimum_value' or 'maximum_value' must be specified.")
|
||||
|
||||
self.minimum_value = minimum_value
|
||||
self.maximum_value = maximum_value
|
||||
|
||||
def validate(self, value):
|
||||
"""Validate value."""
|
||||
len_ = len(value)
|
||||
|
||||
if self.minimum_value is not None and len_ < self.minimum_value:
|
||||
tpl = "Value '{val}' length is lower than allowed minimum '{min}'."
|
||||
raise ValidationError(tpl.format(
|
||||
val=value, min=self.minimum_value
|
||||
))
|
||||
|
||||
if self.maximum_value is not None and len_ > self.maximum_value:
|
||||
raise ValidationError(
|
||||
"Value '{val}' length is bigger than "
|
||||
"allowed maximum '{max}'.".format(
|
||||
val=value,
|
||||
max=self.maximum_value,
|
||||
))
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
"""Modify field schema."""
|
||||
if self.minimum_value:
|
||||
field_schema['minLength'] = self.minimum_value
|
||||
|
||||
if self.maximum_value:
|
||||
field_schema['maxLength'] = self.maximum_value
|
||||
|
||||
|
||||
class Enum(object):
|
||||
|
||||
"""Validator for enums."""
|
||||
|
||||
def __init__(self, *choices):
|
||||
"""Init.
|
||||
|
||||
:param [] choices: Valid choices for the field.
|
||||
"""
|
||||
|
||||
self.choices = list(choices)
|
||||
|
||||
def validate(self, value):
|
||||
if value not in self.choices:
|
||||
tpl = "Value '{val}' is not a valid choice."
|
||||
raise ValidationError(tpl.format(val=value))
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
field_schema['enum'] = self.choices
|
||||
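Validators attach to fields and do double duty: validate() runs on assignment, modify_schema() runs at schema-generation time. A short sketch reusing the model style above:

# Sketch: attaching validators to a field
from trains_agent.backend_api.session.jsonmodels import models, fields, errors
from trains_agent.backend_api.session.jsonmodels.validators import Min, Max

class Worker(models.Base):
    slots = fields.IntField(validators=[Min(1), Max(16)])

w = Worker(slots=8)  # accepted
try:
    w.slots = 0  # Min(1) rejects it
except errors.ValidationError as e:
    print(e)  # wraps "Error for field 'slots'." around the Min message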
@@ -1,10 +1,8 @@
import requests

import six
-import jsonmodels.models
-import jsonmodels.fields
-import jsonmodels.errors

+from . import jsonmodels
from .apimodel import ApiModel
from .datamodel import NonStrictDataModelMixin

@@ -16,6 +16,7 @@ from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
+from ...backend_config.environment import backward_compatibility_support
from ...version import __version__

@@ -38,9 +39,10 @@ class Session(TokenManager):
    _async_status_code = 202
    _session_requests = 0
    _session_initial_timeout = (3.0, 10.)
-    _session_timeout = (10.0, 300.)
+    _session_timeout = (10.0, 30.)
+    _session_initial_connect_retry = 4
    _write_session_data_size = 15000
-    _write_session_timeout = (300.0, 300.)
+    _write_session_timeout = (30.0, 30.)

    api_version = '2.1'
    default_host = "https://demoapi.trains.allegro.ai"
@@ -84,15 +86,18 @@ class Session(TokenManager):
        initialize_logging=True,
        client=None,
        config=None,
+        http_retries_config=None,
        **kwargs
    ):
        # add backward compatibility support for old environment variables
        backward_compatibility_support()

        if config is not None:
            self.config = config
        else:
            self.config = load()
        if initialize_logging:
-            self.config.initialize_logging()
+            self.config.initialize_logging(debug=kwargs.get('debug', False))

        token_expiration_threshold_sec = self.config.get(
            "auth.token_expiration_threshold_sec", 60
@@ -126,11 +131,10 @@ class Session(TokenManager):
            raise ValueError("host is required in init or config")

        self.__host = host.strip("/")
-        http_retries_config = self.config.get(
+        http_retries_config = http_retries_config or self.config.get(
            "api.http.retries", ConfigTree()
        ).as_plain_ordered_dict()
        http_retries_config["status_forcelist"] = self._retry_codes
-        self.__http_session = get_http_session_with_retry(**http_retries_config)

        self.__worker = worker or gethostname()

@@ -140,7 +144,14 @@ class Session(TokenManager):

        self.client = client or "api-{}".format(__version__)

+        # limit the reconnect retries, so we get an error if we are starting the session
+        http_no_retries_config = dict(**http_retries_config)
+        http_no_retries_config['connect'] = self._session_initial_connect_retry
+        self.__http_session = get_http_session_with_retry(**http_no_retries_config)
+        # try to connect with the server
        self.refresh_token()
+        # create the default session with many retries
+        self.__http_session = get_http_session_with_retry(**http_retries_config)

        # update api version from server response
        try:
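The intent of this change: bootstrap with a small connect-retry budget so a misconfigured api_server fails fast, then switch to the patient session for normal traffic. get_http_session_with_retry itself is outside this hunk; conceptually it wraps urllib3's Retry, along these lines (an illustrative assumption, not this repo's implementation):

# Conceptual sketch only - the real get_http_session_with_retry is not shown in this diff
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def make_session(connect=4, total=10, status_forcelist=(500, 502, 503, 504)):
    session = requests.Session()
    retry = Retry(total=total, connect=connect, status_forcelist=status_forcelist)
    adapter = HTTPAdapter(max_retries=retry)  # retry policy applied per scheme
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session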
@@ -423,16 +434,15 @@ class Session(TokenManager):
|
||||
@classmethod
|
||||
def get_api_server_host(cls, config=None):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
return None
|
||||
|
||||
return ENV_HOST.get(default=(config.get("api.api_server", None) or
|
||||
config.get("api.host", None) or cls.default_host))
|
||||
|
||||
@classmethod
|
||||
def get_app_server_host(cls, config=None):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
return None
|
||||
|
||||
# get from config/environment
|
||||
web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
|
||||
@@ -459,8 +469,8 @@ class Session(TokenManager):
|
||||
@classmethod
|
||||
def get_files_server_host(cls, config=None):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
return None
|
||||
|
||||
# get from config/environment
|
||||
files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
|
||||
if files_host:
|
||||
@@ -542,6 +552,9 @@ class Session(TokenManager):
|
||||
else:
|
||||
raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
|
||||
"exception: {}".format(res, ex))
|
||||
except requests.ConnectionError as ex:
|
||||
raise ValueError('Connection Error: it seems *api_server* is misconfigured. '
|
||||
'Is this the TRAINS API server {} ?'.format('/'.join(ex.request.url.split('/')[:3])))
|
||||
except Exception as ex:
|
||||
raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
|
||||
|
||||
|
||||
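The hunk above bootstraps the HTTP session twice: first with a small connect-retry budget so a misconfigured server fails fast during login, then with the patient retry policy for the rest of the run. A minimal sketch of that pattern using plain requests/urllib3 (make_session and the retry numbers are illustrative, not the agent's own helper):

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def make_session(connect_retries, total=10, status_forcelist=(500, 502, 524)):
    # build a session whose adapter retries with exponential backoff
    retry = Retry(total=total, connect=connect_retries,
                  status_forcelist=list(status_forcelist), backoff_factor=0.5)
    session = requests.Session()
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session

bootstrap = make_session(connect_retries=3)   # fail fast while logging in
patient = make_session(connect_retries=10)    # long-lived session afterwards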
@@ -190,7 +190,7 @@ class Config(object):
    def reload(self):
        self.replace(self._reload())

    def initialize_logging(self):
    def initialize_logging(self, debug=False):
        logging_config = self._config.get("logging", None)
        if not logging_config:
            return False
@@ -217,6 +217,8 @@ class Config(object):
        )
        for logger in loggers:
            handlers = logger.get("handlers", None)
            if debug:
                logger['level'] = 'DEBUG'
            if not handlers:
                continue
            logger["handlers"] = [h for h in handlers if h not in deleted]
@@ -46,6 +46,15 @@ class Environment(object):
    local = 'local'


class UptimeConf(object):
    min_api_version = "2.10"
    queue_tag_on = "force_workers:on"
    queue_tag_off = "force_workers:off"
    worker_key = "force"
    worker_value_off = ["off"]
    worker_value_on = ["on"]


CONFIG_FILE_EXTENSION = '.conf'
@@ -23,3 +23,31 @@ class EnvEntry(Entry):

    def error(self, message):
        print("Environment configuration: {}".format(message))


def backward_compatibility_support():
    from ..definitions import ENVIRONMENT_CONFIG, ENVIRONMENT_SDK_PARAMS, ENVIRONMENT_BACKWARD_COMPATIBLE
    if not ENVIRONMENT_BACKWARD_COMPATIBLE.get():
        return

    # Add an ALG_-prefixed alias for every TRAINS_ environment variable we support
    for k, v in ENVIRONMENT_CONFIG.items():
        try:
            trains_vars = [var for var in v.vars if var.startswith('TRAINS_')]
            if not trains_vars:
                continue
            alg_var = trains_vars[0].replace('TRAINS_', 'ALG_', 1)
            if alg_var not in v.vars:
                v.vars = tuple(list(v.vars) + [alg_var])
        except:
            continue
    for k, v in ENVIRONMENT_SDK_PARAMS.items():
        try:
            trains_vars = [var for var in v if var.startswith('TRAINS_')]
            if not trains_vars:
                continue
            alg_var = trains_vars[0].replace('TRAINS_', 'ALG_', 1)
            if alg_var not in v:
                ENVIRONMENT_SDK_PARAMS[k] = tuple(list(v) + [alg_var])
        except:
            continue
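To see what the aliasing above does, here is a small self-contained illustration (FakeEntry is a made-up stand-in for the real EnvEntry, which keeps its accepted variable names in a `vars` tuple):

class FakeEntry:
    def __init__(self, *names):
        self.vars = names  # accepted environment variable names, in priority order

entry = FakeEntry('TRAINS_WORKER_ID')
trains_vars = [v for v in entry.vars if v.startswith('TRAINS_')]
entry.vars = tuple(list(entry.vars) + [trains_vars[0].replace('TRAINS_', 'ALG_', 1)])
print(entry.vars)  # ('TRAINS_WORKER_ID', 'ALG_WORKER_ID')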
@@ -94,9 +94,20 @@ class ServiceCommandSection(BaseCommandSection):

    def __init__(self, *args, **kwargs):
        super(ServiceCommandSection, self).__init__()
        kwargs = self._verify_command_states(kwargs)
        self._session = self._get_session(*args, **kwargs)
        self._list_formatter = ListFormatter(self.service)

    @classmethod
    def _verify_command_states(cls, kwargs):
        """
        Conform and enforce command arguments.
        This is where you can automatically turn switches on/off based on different states.
        :param kwargs:
        :return: kwargs
        """
        return kwargs

    @staticmethod
    def _get_session(*args, **kwargs):
        return Session(*args, **kwargs)
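The new _verify_command_states hook runs before the session is created, so a subclass can adjust argument state in one place. A hypothetical override (MyCommandSection and the debug switch are illustrative, not part of the source):

class MyCommandSection(ServiceCommandSection):
    @classmethod
    def _verify_command_states(cls, kwargs):
        # force logging on whenever the (hypothetical) debug switch is set
        if kwargs.get('debug'):
            kwargs['initialize_logging'] = True
        return kwargs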
@@ -358,7 +369,7 @@ class ServiceCommandSection(BaseCommandSection):
            **locals())
        self.exit(message)

    message = 'Could not find {} with name "{}"'.format(service.rstrip('s'), name)
    message = 'Could not find {} with name/id "{}"'.format(service.rstrip('s'), name)

    if not response:
        raise NameResolutionError(message)
@@ -1,19 +1,21 @@
from __future__ import print_function

from six.moves import input
from pyhocon import ConfigFactory
from pyhocon import ConfigFactory, ConfigMissingException
from pathlib2 import Path
from six.moves.urllib.parse import urlparse

from trains_agent.backend_api.session import Session
from trains_agent.backend_api.session.defs import ENV_HOST
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILES


description = """
Please create new credentials using the web app: {}/profile
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
Please create new trains credentials through the profile page in your trains web app (e.g. https://demoapp.trains.allegro.ai/profile)
In the profile page, press "Create new credentials", then press "Copy to clipboard".

Paste credentials here: """
Paste copied configuration here:
"""

def_host = 'http://localhost:8080'
try:
@@ -38,96 +40,66 @@ def main():
        print('Leaving setup, feel free to edit the configuration file.')
        return

    print(host_description)
    web_host = input_url('Web Application Host', '')
    parsed_host = verify_url(web_host)
    print(description, end='')
    sentinel = ''
    parse_input = '\n'.join(iter(input, sentinel))
    credentials = None
    api_server = None
    web_server = None
    # noinspection PyBroadException
    try:
        parsed = ConfigFactory.parse_string(parse_input)
        if parsed:
            # Take the credentials in raw form or from the api section
            credentials = get_parsed_field(parsed, ["credentials"])
            api_server = get_parsed_field(parsed, ["api_server", "host"])
            web_server = get_parsed_field(parsed, ["web_server"])
    except Exception:
        credentials = credentials or None
        api_server = api_server or None
        web_server = web_server or None

    if parsed_host.port == 8008:
        print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
    elif parsed_host.port == 8080:
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapp.'):
        # this is our demo server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('app.'):
        # this is our application server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapi.'):
        print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('api.'):
        print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
    while not credentials or set(credentials) != {"access_key", "secret_key"}:
        print('Could not parse credentials, please try entering them manually.')
        credentials = read_manual_credentials()

    print('Detected credentials key="{}" secret="{}"'.format(credentials['access_key'],
                                                             credentials['secret_key'][0:4] + "***"))
    web_input = True
    if web_server:
        host = input_url('WEB Host', web_server)
    elif api_server:
        web_input = False
        host = input_url('API Host', api_server)
    else:
        api_host = ''
        web_host = ''
        files_host = ''
        if not parsed_host.port:
            print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
            replace_port = input().lower()
            if not replace_port or replace_port == 'y' or replace_port == 'yes':
                api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
                web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
                files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
        if not api_host:
            api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        print(host_description)
        host = input_url('WEB Host', '')

    parsed_host = verify_url(host)
    api_host, files_host, web_host = parse_host(parsed_host, allow_input=True)

    # one of these two was already configured above
    if not web_input:
        web_host = input_url('Web Application Host', web_host)
    else:
        api_host = input_url('API Host', api_host)

    api_host = input_url('API Host', api_host)
    files_host = input_url('File Store Host', files_host)

    print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
        api_host, web_host, files_host))
    print('\nTRAINS Hosts configuration:\nWeb App: {}\nAPI: {}\nFile Store: {}\n'.format(
        web_host, api_host, files_host))

    while True:
        print(description.format(web_host), end='')
        parse_input = input()
        # check if these are valid credentials
        credentials = None
        # noinspection PyBroadException
        try:
            parsed = ConfigFactory.parse_string(parse_input)
            if parsed:
                credentials = parsed.get("credentials", None)
        except Exception:
            credentials = None

        if not credentials or set(credentials) != {"access_key", "secret_key"}:
            print('Could not parse user credentials, please enter them one at a time.')
            credentials = {}
            # parse individual fields
            print('Enter user access key: ', end='')
            credentials['access_key'] = input()
            print('Enter user secret: ', end='')
            credentials['secret_key'] = input()

        print('Detected credentials key="{}" secret="{}"'.format(credentials['access_key'],
                                                                 credentials['secret_key'], ))

    from trains_agent.backend_api.session import Session
    # noinspection PyBroadException
    try:
        print('Verifying credentials ...')
        Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host)
        print('Credentials verified!')
    retry = 1
    max_retries = 2
    while retry <= max_retries:  # Up to 2 tries by the user
        if verify_credentials(api_host, credentials):
            break
    except Exception:
        print('Error: could not verify credentials: host={} access={} secret={}'.format(
            api_host, credentials['access_key'], credentials['secret_key']))
        retry += 1
        if retry < max_retries + 1:
            credentials = read_manual_credentials()
        else:
            print('Exiting setup without creating configuration file')
            return

    # get GIT User/Pass for cloning
    print('Enter git username for repository cloning (leave blank for SSH key authentication): [] ', end='')
@@ -140,6 +112,18 @@ def main():
        git_user = None
        git_pass = None

    # get extra-index-url for pip installations
    extra_index_urls = []
    print('\nEnter additional artifact repository (extra-index-url) to use when installing python packages '
          '(leave blank if not required):', end='')
    index_url = input().strip()
    while index_url:
        extra_index_urls.append(index_url)
        print('Another artifact repository? (enter another url or leave blank if done):', end='')
        index_url = input().strip()
    if len(extra_index_urls):
        print("The following artifact repositories will be added:\n\t- {}".format("\n\t- ".join(extra_index_urls)))

    # noinspection PyBroadException
    try:
        conf_folder = Path(__file__).parent.absolute() / '..' / 'backend_api' / 'config' / 'default'
@@ -158,6 +142,7 @@ def main():
        with open(str(conf_file), 'wt') as f:
            header = '# TRAINS-AGENT configuration file\n' \
                     'api {\n' \
                     '    # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
                     '    api_server: %s\n' \
                     '    web_server: %s\n' \
                     '    files_server: %s\n' \
@@ -173,6 +158,10 @@ def main():
                     'agent.git_pass=\"{}\"\n' \
                     '\n'.format(git_user or '', git_pass or '')
            f.write(git_credentials)
            extra_index_str = '# extra_index_url: ["https://allegroai.jfrog.io/trainsai/api/pypi/public/simple"]\n' \
                              'agent.package_manager.extra_index_url= ' \
                              '[\n{}\n]\n\n'.format("\n".join(map("\"{}\"".format, extra_index_urls)))
            f.write(extra_index_str)
            f.write(default_conf)
    except Exception:
        print('Error! Could not write configuration file at: {}'.format(str(conf_file)))
@@ -182,19 +171,132 @@ def main():
    print('TRAINS-AGENT setup completed successfully.')


def parse_host(parsed_host, allow_input=True):
    if parsed_host.netloc.startswith('demoapp.'):
        # this is our demo server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.',
                                                                             1) + parsed_host.path
    elif parsed_host.netloc.startswith('app.'):
        # this is our application server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
    elif parsed_host.netloc.startswith('demoapi.'):
        print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.',
                                                                             1) + parsed_host.path
    elif parsed_host.netloc.startswith('api.'):
        print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
            parsed_host.netloc))
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
    elif parsed_host.port == 8008:
        print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
    elif parsed_host.port == 8080:
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
    elif allow_input:
        api_host = ''
        web_host = ''
        files_host = ''
        if not parsed_host.port:
            print('Host port not detected, do you wish to use the default 8080 port n/[y]? ', end='')
            replace_port = input().lower()
            if not replace_port or replace_port == 'y' or replace_port == 'yes':
                api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
                web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
                files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
            elif not replace_port or replace_port.lower() == 'n' or replace_port.lower() == 'no':
                web_host = input_host_port("Web", parsed_host)
                api_host = input_host_port("API", parsed_host)
                files_host = input_host_port("Files", parsed_host)
        if not api_host:
            api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
    else:
        raise ValueError("Could not parse host name")

    return api_host, files_host, web_host
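A quick illustration of the mapping parse_host() implements for an app.-style address (example.com is a made-up host):

from six.moves.urllib.parse import urlparse

api_host, files_host, web_host = parse_host(urlparse('https://app.example.com'))
# api_host   == 'https://api.example.com'
# files_host == 'https://files.example.com'
# web_host   == 'https://app.example.com'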
def verify_credentials(api_host, credentials):
    """check if the credentials are valid"""
    # noinspection PyBroadException
    try:
        print('Verifying credentials ...')
        if api_host:
            Session(api_key=credentials['access_key'], secret_key=credentials['secret_key'], host=api_host,
                    http_retries_config={"total": 2})
            print('Credentials verified!')
            return True
        else:
            print("Can't verify credentials")
            return False
    except Exception:
        print('Error: could not verify credentials: key={} secret={}'.format(
            credentials.get('access_key'), credentials.get('secret_key')))
        return False
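The wizard calls this with the host derived earlier and the parsed credentials; a hypothetical direct call (key and secret values are made up):

ok = verify_credentials('https://api.example.com',
                        {'access_key': 'KEY123', 'secret_key': 'SECRET456'})
# ok is True only when a Session could actually be opened against the host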
def get_parsed_field(parsed_config, fields):
    """
    Parse the value copied from the web profile page ('copy to clipboard' option)
    :param parsed_config: The parsed value from the web ui
    :type parsed_config: Config object
    :param fields: list of values to parse, will parse by the list order
    :type fields: List[str]
    :return: parsed value if found, None otherwise
    """
    try:
        return parsed_config.get("api").get(fields[0])
    except ConfigMissingException:  # fallback - try to parse the field as it was in the older web version
        if len(fields) == 1:
            return parsed_config.get(fields[0])
        elif len(fields) == 2:
            return parsed_config.get(fields[1])
        else:
            return None
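A hedged example of the two clipboard formats this handles, parsed with pyhocon (all values are made up):

from pyhocon import ConfigFactory

new_style = ConfigFactory.parse_string(
    'api { web_server: "https://app.example.com", '
    'credentials {access_key: k, secret_key: s} }')
old_style = ConfigFactory.parse_string('credentials {access_key: k, secret_key: s}')

get_parsed_field(new_style, ["credentials"])  # reads api.credentials directly
get_parsed_field(old_style, ["credentials"])  # api section missing, falls back to the top level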
def read_manual_credentials():
    print('Enter user access key: ', end='')
    access_key = input()
    print('Enter user secret: ', end='')
    secret_key = input()
    return {"access_key": access_key, "secret_key": secret_key}
def input_url(host_type, host=None):
    while True:
        print('{} configured to: [{}] '.format(host_type, host), end='')
        print('{} configured to: {}'.format(host_type, '[{}] '.format(host) if host else ''), end='')
        parse_input = input()
        if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'):
            break
        if parse_input and verify_url(parse_input):
            host = parse_input
        parsed_host = verify_url(parse_input) if parse_input else None
        if parse_input and parsed_host:
            host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
            break
    return host


def input_host_port(host_type, parsed_host):
    print('Enter port for {} host '.format(host_type), end='')
    replace_port = input().lower()
    return parsed_host.scheme + "://" + parsed_host.netloc + (
        ':{}'.format(replace_port) if replace_port else '') + parsed_host.path
def verify_url(parse_input):
    # noinspection PyBroadException
    try:
        if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
            # if we have a specific port, use http prefix, otherwise assume https
File diff suppressed because it is too large
@@ -55,35 +55,41 @@ class EnvironmentConfig(object):


ENVIRONMENT_CONFIG = {
    "api.api_server": EnvironmentConfig("TRAINS_API_HOST", "ALG_API_HOST"),
    "api.api_server": EnvironmentConfig("TRAINS_API_HOST", ),
    "api.credentials.access_key": EnvironmentConfig(
        "TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY"
        "TRAINS_API_ACCESS_KEY",
    ),
    "api.credentials.secret_key": EnvironmentConfig(
        "TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY"
        "TRAINS_API_SECRET_KEY",
    ),
    "agent.worker_name": EnvironmentConfig("TRAINS_WORKER_NAME", "ALG_WORKER_NAME"),
    "agent.worker_id": EnvironmentConfig("TRAINS_WORKER_ID", "ALG_WORKER_ID"),
    "agent.worker_name": EnvironmentConfig("TRAINS_WORKER_NAME", ),
    "agent.worker_id": EnvironmentConfig("TRAINS_WORKER_ID", ),
    "agent.cuda_version": EnvironmentConfig(
        "TRAINS_CUDA_VERSION", "ALG_CUDA_VERSION", "CUDA_VERSION"
        "TRAINS_CUDA_VERSION", "CUDA_VERSION"
    ),
    "agent.cudnn_version": EnvironmentConfig(
        "TRAINS_CUDNN_VERSION", "ALG_CUDNN_VERSION", "CUDNN_VERSION"
        "TRAINS_CUDNN_VERSION", "CUDNN_VERSION"
    ),
    "agent.cpu_only": EnvironmentConfig(
        "TRAINS_CPU_ONLY", "ALG_CPU_ONLY", "CPU_ONLY", type=bool
        "TRAINS_CPU_ONLY", "CPU_ONLY", type=bool
    ),
    "sdk.aws.s3.key": EnvironmentConfig("AWS_ACCESS_KEY_ID"),
    "sdk.aws.s3.secret": EnvironmentConfig("AWS_SECRET_ACCESS_KEY"),
    "sdk.aws.s3.region": EnvironmentConfig("AWS_DEFAULT_REGION"),
    "sdk.azure.storage.containers.0": {'account_name': EnvironmentConfig("AZURE_STORAGE_ACCOUNT"),
                                       'account_key': EnvironmentConfig("AZURE_STORAGE_KEY")},
    "sdk.google.storage.credentials_json": EnvironmentConfig("GOOGLE_APPLICATION_CREDENTIALS"),
}

CONFIG_FILE_ENV = EnvironmentConfig("ALG_CONFIG_FILE")

ENVIRONMENT_SDK_PARAMS = {
    "task_id": ("TRAINS_TASK_ID", "ALG_TASK_ID"),
    "config_file": ("TRAINS_CONFIG_FILE", "ALG_CONFIG_FILE", "TRAINS_CONFIG_FILE"),
    "log_level": ("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL"),
    "log_to_backend": ("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND"),
    "task_id": ("TRAINS_TASK_ID", ),
    "config_file": ("TRAINS_CONFIG_FILE", ),
    "log_level": ("TRAINS_LOG_LEVEL", ),
    "log_to_backend": ("TRAINS_LOG_TASK_TO_BACKEND", ),
}

ENVIRONMENT_BACKWARD_COMPATIBLE = EnvironmentConfig("TRAINS_AGENT_ALG_ENV", type=bool)
VIRTUAL_ENVIRONMENT_PATH = {
    "python2": normalize_path(CONFIG_DIR, "py2venv"),
    "python3": normalize_path(CONFIG_DIR, "py3venv"),
@@ -107,13 +113,19 @@ HTTP_HEADERS = {
METADATA_EXTENSION = ".json"

DEFAULT_VENV_UPDATE_URL = (
    "https://raw.githubusercontent.com/Yelp/venv-update/v3.2.2/venv_update.py"
    "https://raw.githubusercontent.com/Yelp/venv-update/v3.2.4/venv_update.py"
)
WORKING_REPOSITORY_DIR = "task_repository"
DEFAULT_VCS_CACHE = normalize_path(CONFIG_DIR, "vcs-cache")
PIP_EXTRA_INDICES = [
]
DEFAULT_PIP_DOWNLOAD_CACHE = normalize_path(CONFIG_DIR, "pip-download-cache")
ENV_AGENT_GIT_USER = EnvironmentConfig('TRAINS_AGENT_GIT_USER')
ENV_AGENT_GIT_PASS = EnvironmentConfig('TRAINS_AGENT_GIT_PASS')
ENV_AGENT_GIT_HOST = EnvironmentConfig('TRAINS_AGENT_GIT_HOST')
ENV_TASK_EXECUTE_AS_USER = 'TRAINS_AGENT_EXEC_USER'
ENV_TASK_EXTRA_PYTHON_PATH = 'TRAINS_AGENT_EXTRA_PYTHON_PATH'
ENV_DOCKER_HOST_MOUNT = EnvironmentConfig('TRAINS_AGENT_K8S_HOST_MOUNT', 'TRAINS_AGENT_DOCKER_HOST_MOUNT')


class FileBuffering(IntEnum):
trains_agent/external/__init__.py (vendored, new file, 0 lines)
trains_agent/external/requirements_parser/__init__.py (vendored, new file, 22 lines)
@@ -0,0 +1,22 @@
from .parser import parse  # noqa

_MAJOR = 0
_MINOR = 2
_PATCH = 0


def version_tuple():
    '''
    Returns a 3-tuple of ints that represent the version
    '''
    return (_MAJOR, _MINOR, _PATCH)


def version():
    '''
    Returns a string representation of the version
    '''
    return '%d.%d.%d' % (version_tuple())


__version__ = version()
trains_agent/external/requirements_parser/fragment.py (vendored, new file, 44 lines)
@@ -0,0 +1,44 @@
import re

# Copied from pip
# https://github.com/pypa/pip/blob/281eb61b09d87765d7c2b92f6982b3fe76ccb0af/pip/index.py#L947
HASH_ALGORITHMS = set(['sha1', 'sha224', 'sha384', 'sha256', 'sha512', 'md5'])

extras_require_search = re.compile(
    r'(?P<name>.+)\[(?P<extras>[^\]]+)\]').search


def parse_fragment(fragment_string):
    """Takes a fragment string and returns a dict of the components"""
    fragment_string = fragment_string.lstrip('#')

    try:
        return dict(
            key_value_string.split('=')
            for key_value_string in fragment_string.split('&')
        )
    except ValueError:
        raise ValueError(
            'Invalid fragment string {fragment_string}'.format(
                fragment_string=fragment_string
            )
        )


def get_hash_info(d):
    """Returns the first matching hashlib name and value from a dict"""
    for key in d.keys():
        if key.lower() in HASH_ALGORITHMS:
            return key, d[key]

    return None, None


def parse_extras_require(egg):
    if egg is not None:
        match = extras_require_search(egg)
        if match is not None:
            name = match.group('name')
            extras = match.group('extras')
            return name, [extra.strip() for extra in extras.split(',')]
    return egg, []
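A worked example of the three fragment helpers above (the fragment contents are illustrative):

fragment = parse_fragment('#egg=mypkg[gpu,tests]&sha256=abc123&subdirectory=src')
# {'egg': 'mypkg[gpu,tests]', 'sha256': 'abc123', 'subdirectory': 'src'}
name, extras = parse_extras_require(fragment.get('egg'))  # ('mypkg', ['gpu', 'tests'])
hash_name, hash_value = get_hash_info(fragment)           # ('sha256', 'abc123')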
trains_agent/external/requirements_parser/parser.py (vendored, new file, 50 lines)
@@ -0,0 +1,50 @@
import os
import warnings

from .requirement import Requirement


def parse(reqstr):
    """
    Parse a requirements file into a list of Requirements

    See: pip/req.py:parse_requirements()

    :param reqstr: a string or file like object containing requirements
    :returns: a *generator* of Requirement objects
    """
    filename = getattr(reqstr, 'name', None)
    try:
        # Python 2.x compatibility
        if not isinstance(reqstr, basestring):
            reqstr = reqstr.read()
    except NameError:
        # Python 3.x only
        if not isinstance(reqstr, str):
            reqstr = reqstr.read()

    for line in reqstr.splitlines():
        line = line.strip()
        if line == '':
            continue
        elif not line or line.startswith('#'):
            # comments are lines that start with # only
            continue
        elif line.startswith('-r') or line.startswith('--requirement'):
            _, new_filename = line.split()
            new_file_path = os.path.join(os.path.dirname(filename or '.'),
                                         new_filename)
            with open(new_file_path) as f:
                for requirement in parse(f):
                    yield requirement
        elif line.startswith('-f') or line.startswith('--find-links') or \
                line.startswith('-i') or line.startswith('--index-url') or \
                line.startswith('--extra-index-url') or \
                line.startswith('--no-index'):
            warnings.warn('Private repos not supported. Skipping.')
            continue
        elif line.startswith('-Z') or line.startswith('--always-unzip'):
            warnings.warn('Unused option --always-unzip. Skipping.')
            continue
        else:
            yield Requirement.parse(line)
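Usage sketch for the vendored parser (the requirement strings are made up):

reqs = list(parse("requests>=2.20\n# a comment\nnumpy==1.18.0\n"))
for r in reqs:
    print(r.name, r.specs)
# requests [('>=', '2.20')]
# numpy [('==', '1.18.0')]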
trains_agent/external/requirements_parser/requirement.py (vendored, new file, 241 lines)
@@ -0,0 +1,241 @@
from __future__ import unicode_literals
import re
from pkg_resources import Requirement as Req

from .fragment import get_hash_info, parse_fragment, parse_extras_require
from .vcs import VCS, VCS_SCHEMES


URI_REGEX = re.compile(
    r'^(?P<scheme>https?|file|ftps?)://(?P<path>[^#]+)'
    r'(#(?P<fragment>\S+))?'
)

VCS_REGEX = re.compile(
    r'^(?P<scheme>{0})://'.format(r'|'.join(
        [scheme.replace('+', r'\+') for scheme in VCS_SCHEMES])) +
    r'((?P<login>[^/@]+)@)?'
    r'(?P<path>[^#@]+)'
    r'(@(?P<revision>[^#]+))?'
    r'(#(?P<fragment>\S+))?'
)

# This matches just about everything
LOCAL_REGEX = re.compile(
    r'^((?P<scheme>file)://)?'
    r'(?P<path>[^#]+)' +
    r'(#(?P<fragment>\S+))?'
)


class Requirement(object):
    """
    Represents a single requirement.

    Typically instances of this class are created with ``Requirement.parse``.
    For local file requirements, there's no verification that the file
    exists. This class attempts to be *dict-like*.

    See: http://www.pip-installer.org/en/latest/logic.html

    **Members**:

    * ``line`` - the actual requirement line being parsed
    * ``editable`` - a boolean whether this requirement is "editable"
    * ``local_file`` - a boolean whether this requirement is a local file/path
    * ``specifier`` - a boolean whether this requirement used a requirement
      specifier (eg. "django>=1.5" or "requirements")
    * ``vcs`` - a string specifying the version control system
    * ``revision`` - a version control system specifier
    * ``name`` - the name of the requirement
    * ``uri`` - the URI if this requirement was specified by URI
    * ``subdirectory`` - the subdirectory fragment of the URI
    * ``path`` - the local path to the requirement
    * ``hash_name`` - the type of hashing algorithm indicated in the line
    * ``hash`` - the hash value indicated by the requirement line
    * ``extras`` - a list of extras for this requirement
      (eg. "mymodule[extra1, extra2]")
    * ``specs`` - a list of specs for this requirement
      (eg. "mymodule>1.5,<1.6" => [('>', '1.5'), ('<', '1.6')])
    """

    def __init__(self, line):
        # Do not call this private method
        self.line = line
        self.editable = False
        self.local_file = False
        self.specifier = False
        self.vcs = None
        self.name = None
        self.subdirectory = None
        self.uri = None
        self.path = None
        self.revision = None
        self.hash_name = None
        self.hash = None
        self.extras = []
        self.specs = []

    def __repr__(self):
        return '<Requirement: "{0}">'.format(self.line)

    def __getitem__(self, key):
        return getattr(self, key)

    def keys(self):
        return self.__dict__.keys()

    @classmethod
    def parse_editable(cls, line):
        """
        Parses a Requirement from an "editable" requirement which is either
        a local project path or a VCS project URI.

        See: pip/req.py:from_editable()

        :param line: an "editable" requirement
        :returns: a Requirement instance for the given line
        :raises: ValueError on an invalid requirement
        """

        req = cls('-e {0}'.format(line))
        req.editable = True
        vcs_match = VCS_REGEX.match(line)
        local_match = LOCAL_REGEX.match(line)

        if vcs_match is not None:
            groups = vcs_match.groupdict()
            if groups.get('login'):
                req.uri = '{scheme}://{login}@{path}'.format(**groups)
            else:
                req.uri = '{scheme}://{path}'.format(**groups)
            req.revision = groups['revision']
            if groups['fragment']:
                fragment = parse_fragment(groups['fragment'])
                egg = fragment.get('egg')
                req.name, req.extras = parse_extras_require(egg)
                req.hash_name, req.hash = get_hash_info(fragment)
                req.subdirectory = fragment.get('subdirectory')
            for vcs in VCS:
                if req.uri.startswith(vcs):
                    req.vcs = vcs
        else:
            assert local_match is not None, 'This should match everything'
            groups = local_match.groupdict()
            req.local_file = True
            if groups['fragment']:
                fragment = parse_fragment(groups['fragment'])
                egg = fragment.get('egg')
                req.name, req.extras = parse_extras_require(egg)
                req.hash_name, req.hash = get_hash_info(fragment)
                req.subdirectory = fragment.get('subdirectory')
            req.path = groups['path']

        return req

    @classmethod
    def parse_line(cls, line):
        """
        Parses a Requirement from a non-editable requirement.

        See: pip/req.py:from_line()

        :param line: a "non-editable" requirement
        :returns: a Requirement instance for the given line
        :raises: ValueError on an invalid requirement
        """

        req = cls(line)

        vcs_match = VCS_REGEX.match(line)
        uri_match = URI_REGEX.match(line)
        local_match = LOCAL_REGEX.match(line)

        if vcs_match is not None:
            groups = vcs_match.groupdict()
            if groups.get('login'):
                req.uri = '{scheme}://{login}@{path}'.format(**groups)
            else:
                req.uri = '{scheme}://{path}'.format(**groups)
            req.revision = groups['revision']
            if groups['fragment']:
                fragment = parse_fragment(groups['fragment'])
                egg = fragment.get('egg')
                req.name, req.extras = parse_extras_require(egg)
                req.hash_name, req.hash = get_hash_info(fragment)
                req.subdirectory = fragment.get('subdirectory')
            for vcs in VCS:
                if req.uri.startswith(vcs):
                    req.vcs = vcs
        elif uri_match is not None:
            groups = uri_match.groupdict()
            req.uri = '{scheme}://{path}'.format(**groups)
            if groups['fragment']:
                fragment = parse_fragment(groups['fragment'])
                egg = fragment.get('egg')
                req.name, req.extras = parse_extras_require(egg)
                req.hash_name, req.hash = get_hash_info(fragment)
                req.subdirectory = fragment.get('subdirectory')
            if groups['scheme'] == 'file':
                req.local_file = True
        elif '#egg=' in line:
            # Assume a local file match
            assert local_match is not None, 'This should match everything'
            groups = local_match.groupdict()
            req.local_file = True
            if groups['fragment']:
                fragment = parse_fragment(groups['fragment'])
                egg = fragment.get('egg')
                name, extras = parse_extras_require(egg)
                req.name = fragment.get('egg')
                req.hash_name, req.hash = get_hash_info(fragment)
                req.subdirectory = fragment.get('subdirectory')
            req.path = groups['path']
        else:
            # This is a requirement specifier.
            # Delegate to pkg_resources and hope for the best
            req.specifier = True
            pkg_req = Req.parse(line)
            req.name = pkg_req.unsafe_name
            req.extras = list(pkg_req.extras)
            req.specs = pkg_req.specs
        return req

    @classmethod
    def parse(cls, line):
        """
        Parses a Requirement from a line of a requirement file.

        :param line: a line of a requirement file
        :returns: a Requirement instance for the given line
        :raises: ValueError on an invalid requirement
        """
        line = line.lstrip()
        if line.startswith('-e') or line.startswith('--editable'):
            # Editable installs are either a local project path
            # or a VCS project URI
            return cls.parse_editable(
                re.sub(r'^(-e|--editable=?)\s*', '', line))
        elif '@' in line and ('#' not in line or line.index('#') > line.index('@')):
            # Allegro bug fix: support 'name @ git+' entries
            name, uri = line.split('@', 1)
            name = name.strip()
            uri = uri.strip()
            # noinspection PyBroadException
            try:
                # check if the name is valid & parsed
                Req.parse(name)
                # if we are here, name is a valid package name, check if the vcs part is valid
                if VCS_REGEX.match(uri):
                    req = cls.parse_line(uri)
                    req.name = name
                    return req
                elif URI_REGEX.match(uri):
                    req = cls.parse_line(uri)
                    req.name = name
                    req.line = line
                    return req
            except Exception:
                pass

        return cls.parse_line(line)
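The 'name @ URI' branch added at the end is the Allegro fix called out in the comment; a hypothetical parse (the repository URL is made up):

req = Requirement.parse('mypkg @ git+https://github.com/example/mypkg.git@v1.0#egg=mypkg')
# req.name == 'mypkg', req.vcs == 'git', req.revision == 'v1.0'
# req.uri == 'git+https://github.com/example/mypkg.git'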
trains_agent/external/requirements_parser/vcs.py (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@
from __future__ import unicode_literals

VCS = [
    'git',
    'hg',
    'svn',
    'bzr',
]

VCS_SCHEMES = [
    'git',
    'git+https',
    'git+ssh',
    'git+git',
    'hg+http',
    'hg+https',
    'hg+static-http',
    'hg+ssh',
    'svn',
    'svn+svn',
    'svn+http',
    'svn+https',
    'svn+ssh',
    'bzr+http',
    'bzr+https',
    'bzr+ssh',
    'bzr+sftp',
    'bzr+ftp',
    'bzr+lp',
]
trains_agent/glue/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
trains_agent/glue/k8s.py (new file, 499 lines)
@@ -0,0 +1,499 @@
from __future__ import print_function, division, unicode_literals

import base64
import logging
import os
import subprocess
import tempfile
from copy import deepcopy

import yaml
import json
from time import sleep
from typing import Text, List

from trains_agent.commands.events import Events
from trains_agent.commands.worker import Worker
from trains_agent.errors import APIError
from trains_agent.helper.base import safe_remove_file
from trains_agent.helper.dicts import merge_dicts
from trains_agent.helper.process import get_bash_output
from trains_agent.helper.resource_monitor import ResourceMonitor
from trains_agent.interface.base import ObjectID


class K8sIntegration(Worker):
    K8S_PENDING_QUEUE = "k8s_scheduler"

    KUBECTL_APPLY_CMD = "kubectl apply -f"

    KUBECTL_RUN_CMD = "kubectl run trains-id-{task_id} " \
                      "--image {docker_image} " \
                      "--restart=Never --replicas=1 " \
                      "--generator=run-pod/v1 " \
                      "--namespace=trains"

    KUBECTL_DELETE_CMD = "kubectl delete pods " \
                         "--selector=TRAINS=agent " \
                         "--field-selector=status.phase!=Pending,status.phase!=Running " \
                         "--namespace=trains"

    BASH_INSTALL_SSH_CMD = [
        "apt-get install -y openssh-server",
        "mkdir -p /var/run/sshd",
        "echo 'root:training' | chpasswd",
        "echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config",
        "sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config",
        r"sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd",
        "echo 'AcceptEnv TRAINS_API_ACCESS_KEY TRAINS_API_SECRET_KEY' >> /etc/ssh/sshd_config",
        'echo "export VISIBLE=now" >> /etc/profile',
        'echo "export PATH=$PATH" >> /etc/profile',
        'echo "ldconfig" >> /etc/profile',
        "/usr/sbin/sshd -p {port}"]
    CONTAINER_BASH_SCRIPT = [
        "export DEBIAN_FRONTEND='noninteractive'",
        "echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/docker-clean",
        "chown -R root /root/.cache/pip",
        "apt-get update",
        "apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0",
        "declare LOCAL_PYTHON",
        "for i in {{10..5}}; do which python3.$i && python3.$i -m pip --version && "
        "export LOCAL_PYTHON=$(which python3.$i) && break ; done",
        "[ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip",
        "[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3",
        "$LOCAL_PYTHON -m pip install trains-agent",
        "{extra_bash_init_cmd}",
        "$LOCAL_PYTHON -m trains_agent execute --full-monitoring --require-queue --id {task_id}"
    ]

    AGENT_LABEL = "TRAINS=agent"
    LIMIT_POD_LABEL = "ai.allegro.agent.serial=pod-{pod_number}"

    _edit_hyperparams_version = "2.9"
    def __init__(
            self,
            k8s_pending_queue_name=None,
            kubectl_cmd=None,
            container_bash_script=None,
            debug=False,
            ports_mode=False,
            num_of_services=20,
            user_props_cb=None,
            overrides_yaml=None,
            template_yaml=None,
            trains_conf_file=None,
            extra_bash_init_script=None,
    ):
        """
        Initialize the k8s integration glue layer daemon

        :param str k8s_pending_queue_name: queue name to use when task is pending in the k8s scheduler
        :param str|callable kubectl_cmd: kubectl command line str, supports formatting (default: KUBECTL_RUN_CMD)
            example: "task={task_id} image={docker_image} queue_id={queue_id}"
            or a callable function: kubectl_cmd(task_id, docker_image, queue_id, task_data)
        :param str container_bash_script: container bash script to be executed in k8s (default: CONTAINER_BASH_SCRIPT)
            Notice this string will go through a format() call, so any literal curly brackets should be doubled { -> {{
            Format arguments passed: {task_id} and {extra_bash_init_cmd}
        :param bool debug: Switch logging on
        :param bool ports_mode: Adds a label to each pod which can be used in services in order to expose ports.
            Requires the `num_of_services` parameter.
        :param int num_of_services: Number of k8s services configured in the cluster. Required if `ports_mode` is True.
            (default: 20)
        :param callable user_props_cb: An optional callable allowing additional user properties to be specified
            when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
            a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
        :param str overrides_yaml: YAML file containing the overrides for the pod (optional)
        :param str template_yaml: YAML file containing the template for the pod (optional).
            If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
        :param str trains_conf_file: trains.conf file to be used by the pod itself (optional)
        :param str extra_bash_init_script: Additional bash script to run before starting the Task inside the container
        """
        super(K8sIntegration, self).__init__()
        self.k8s_pending_queue_name = k8s_pending_queue_name or self.K8S_PENDING_QUEUE
        self.kubectl_cmd = kubectl_cmd or self.KUBECTL_RUN_CMD
        self.container_bash_script = container_bash_script or self.CONTAINER_BASH_SCRIPT
        # Always use system packages, because we will be running inside a docker
        self._session.config.put("agent.package_manager.system_site_packages", True)
        # Add debug logging
        if debug:
            self.log.logger.disabled = False
            self.log.logger.setLevel(logging.INFO)
        self.ports_mode = ports_mode
        self.num_of_services = num_of_services
        self._edit_hyperparams_support = None
        self._user_props_cb = user_props_cb
        self.trains_conf_file = None
        self.overrides_json_string = None
        self.template_dict = None
        self.extra_bash_init_script = extra_bash_init_script or None
        if self.extra_bash_init_script and not isinstance(self.extra_bash_init_script, str):
            self.extra_bash_init_script = ' ; '.join(self.extra_bash_init_script)  # noqa
        self.pod_limits = []
        self.pod_requests = []
        if overrides_yaml:
            with open(os.path.expandvars(os.path.expanduser(str(overrides_yaml))), 'rt') as f:
                overrides = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))
            if overrides:
                containers = overrides.get('spec', {}).get('containers', [])
                for c in containers:
                    resources = {str(k).lower(): v for k, v in c.get('resources', {}).items()}
                    if not resources:
                        continue
                    if resources.get('limits'):
                        self.pod_limits += ['{}={}'.format(k, v) for k, v in resources['limits'].items()]
                    if resources.get('requests'):
                        self.pod_requests += ['{}={}'.format(k, v) for k, v in resources['requests'].items()]
                # remove double entries
                self.pod_limits = list(set(self.pod_limits))
                self.pod_requests = list(set(self.pod_requests))
                if self.pod_limits or self.pod_requests:
                    self.log.warning('Found pod container requests={} limits={}'.format(
                        self.pod_limits, self.pod_requests))
                if containers:
                    self.log.warning('Removing containers section: {}'.format(overrides['spec'].pop('containers')))
                self.overrides_json_string = json.dumps(overrides)
        if template_yaml:
            with open(os.path.expandvars(os.path.expanduser(str(template_yaml))), 'rt') as f:
                self.template_dict = yaml.load(f, Loader=getattr(yaml, 'FullLoader', None))

        if trains_conf_file:
            with open(os.path.expandvars(os.path.expanduser(str(trains_conf_file))), 'rt') as f:
                self.trains_conf_file = f.read()
            # make sure we use system packages!
            self.trains_conf_file += '\nagent.package_manager.system_site_packages=true\n'
    def _set_task_user_properties(self, task_id: str, **properties: str):
        if self._edit_hyperparams_support is not True:
            # either not supported or never tested
            if self._edit_hyperparams_support == self._session.api_version:
                # tested against latest api_version, not supported
                return
            if not self._session.check_min_api_version(self._edit_hyperparams_version):
                # not supported due to insufficient api_version
                self._edit_hyperparams_support = self._session.api_version
                return
        try:
            self._session.get(
                service="tasks",
                action="edit_hyper_params",
                task=task_id,
                hyperparams=[
                    {
                        "section": "properties",
                        "name": k,
                        "value": str(v),
                    }
                    for k, v in properties.items()
                ],
            )
            # definitely supported (mark the flag this method actually checks above)
            self._edit_hyperparams_support = True
        except APIError as error:
            if error.code == 404:
                self._edit_hyperparams_support = self._session.api_version
    def run_one_task(self, queue: Text, task_id: Text, worker_args=None, **_):
        print('Pulling task {}, launching on kubernetes cluster'.format(task_id))
        task_data = self._session.api_client.tasks.get_all(id=[task_id])[0]

        # push task into the k8s queue, so we have visibility on pending tasks in the k8s scheduler
        try:
            print('Pushing task {} into temporary pending queue'.format(task_id))
            self._session.api_client.tasks.reset(task_id)
            self._session.api_client.tasks.enqueue(task_id, queue=self.k8s_pending_queue_name,
                                                   status_reason='k8s pending scheduler')
        except Exception as e:
            self.log.error("ERROR: Could not push back task [{}] to k8s pending queue [{}], error: {}".format(
                task_id, self.k8s_pending_queue_name, e))
            return

        if task_data.execution.docker_cmd:
            docker_parts = task_data.execution.docker_cmd
        else:
            docker_parts = str(os.environ.get("TRAINS_DOCKER_IMAGE") or
                               self._session.config.get("agent.default_docker.image", "nvidia/cuda"))

        # take the first part, this is the docker image name (not arguments)
        docker_parts = docker_parts.split()
        docker_image = docker_parts[0]
        docker_args = docker_parts[1:] if len(docker_parts) > 1 else []

        # get the trains.conf encoded file
        # noinspection PyProtectedMember
        hocon_config_encoded = (self.trains_conf_file or self._session._config_file).encode('ascii')
        create_trains_conf = "echo '{}' | base64 --decode >> ~/trains.conf".format(
            base64.b64encode(
                hocon_config_encoded
            ).decode('ascii')
        )

        if self.ports_mode:
            print("Kubernetes looking for available pod to use")

        # Search for a free pod number
        pod_number = 1
        while self.ports_mode:
            kubectl_cmd_new = "kubectl get pods -l {pod_label},{agent_label} -n trains".format(
                pod_label=self.LIMIT_POD_LABEL.format(pod_number=pod_number),
                agent_label=self.AGENT_LABEL
            )
            process = subprocess.Popen(kubectl_cmd_new.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            output, error = process.communicate()
            output = '' if not output else output if isinstance(output, str) else output.decode('utf-8')
            error = '' if not error else error if isinstance(error, str) else error.decode('utf-8')

            if not output:
                # No such pod exists, so we can use the pod_number we found
                break
            if pod_number >= self.num_of_services:
                # All pod numbers are taken, exit
                self.log.warning(
                    "kubectl last result: {}\n{}\nAll k8s services are in use, task '{}' "
                    "will be enqueued back to queue '{}'".format(
                        error, output, task_id, queue
                    )
                )
                self._session.api_client.tasks.reset(task_id)
                self._session.api_client.tasks.enqueue(task_id, queue=queue)
                return
            pod_number += 1

        labels = ([self.LIMIT_POD_LABEL.format(pod_number=pod_number)] if self.ports_mode else []) + [self.AGENT_LABEL]

        if self.ports_mode:
            print("Kubernetes scheduling task id={} on pod={}".format(task_id, pod_number))
        else:
            print("Kubernetes scheduling task id={}".format(task_id))

        if self.template_dict:
            output, error = self._kubectl_apply(
                create_trains_conf=create_trains_conf,
                labels=labels, docker_image=docker_image, docker_args=docker_args,
                task_id=task_id, queue=queue)
        else:
            output, error = self._kubectl_run(
                create_trains_conf=create_trains_conf,
                labels=labels, docker_image=docker_image,
                task_data=task_data,
                task_id=task_id, queue=queue)

        error = '' if not error else (error if isinstance(error, str) else error.decode('utf-8'))
        output = '' if not output else (output if isinstance(output, str) else output.decode('utf-8'))
        print('kubectl output:\n{}\n{}'.format(error, output))

        if error:
            self.log.error("Running kubectl encountered an error: {}".format(error))
        elif self.ports_mode:
            user_props = {"k8s-pod-number": pod_number, "k8s-pod-label": labels[0]}
            if self._user_props_cb:
                # noinspection PyBroadException
                try:
                    custom_props = self._user_props_cb(pod_number) if self.ports_mode else self._user_props_cb()
                    user_props.update(custom_props)
                except Exception:
                    pass
            self._set_task_user_properties(
                task_id=task_id,
                **user_props
            )
    def _parse_docker_args(self, docker_args):
        # type: (list) -> dict
        kube_args = {'env': []}
        # scan left to right so each -e/--env flag is paired with the value that follows it
        args = list(docker_args)
        while args:
            cmd = args.pop(0).strip()
            if cmd in ('-e', '--env',):
                env = args.pop(0).strip()
                key, value = env.split('=', 1)
                # collect as a k8s container env entry (the original appended to the wrong key)
                kube_args['env'].append({'name': key, 'value': value})
            else:
                self.log.warning('skipping docker argument {} (only -e --env supported)'.format(cmd))
        return kube_args
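With the fixes above, parsing docker-style environment arguments yields kubernetes container env entries, e.g.:

# _parse_docker_args(['-e', 'MY_VAR=1', '--env', 'OTHER=2'])
# -> {'env': [{'name': 'MY_VAR', 'value': '1'}, {'name': 'OTHER', 'value': '2'}]}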
    def _kubectl_apply(self, create_trains_conf, docker_image, docker_args, labels, queue, task_id):
        template = deepcopy(self.template_dict)
        template.setdefault('apiVersion', 'v1')
        template['kind'] = 'Pod'
        template.setdefault('metadata', {})
        name = 'trains-id-{task_id}'.format(task_id=task_id)
        template['metadata']['name'] = name
        template.setdefault('spec', {})
        template['spec'].setdefault('containers', [])
        if labels:
            labels_dict = dict(pair.split('=', 1) for pair in labels)
            template['metadata'].setdefault('labels', {})
            template['metadata']['labels'].update(labels_dict)
        container = self._parse_docker_args(docker_args)

        container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
            else self.container_bash_script

        script_encoded = '\n'.join(
            ['#!/bin/bash', ] +
            [line.format(extra_bash_init_cmd=self.extra_bash_init_script or '', task_id=task_id)
             for line in container_bash_script])

        create_init_script = \
            "echo '{}' | base64 --decode >> ~/__start_agent__.sh ; " \
            "/bin/bash ~/__start_agent__.sh".format(
                base64.b64encode(
                    script_encoded.encode('ascii')
                ).decode('ascii'))

        container = merge_dicts(
            container,
            dict(name=name, image=docker_image,
                 command=['/bin/bash'],
                 args=['-c', '{} ; {}'.format(create_trains_conf, create_init_script)])
        )

        if template['spec']['containers']:
            template['spec']['containers'][0] = merge_dicts(template['spec']['containers'][0], container)
        else:
            template['spec']['containers'].append(container)

        fp, yaml_file = tempfile.mkstemp(prefix='trains_k8stmpl_', suffix='.yml')
        os.close(fp)
        with open(yaml_file, 'wt') as f:
            yaml.dump(template, f)

        kubectl_cmd = self.KUBECTL_APPLY_CMD.format(
            task_id=task_id,
            docker_image=docker_image,
            queue_id=queue,
        )
        # make sure we provide a list
        if isinstance(kubectl_cmd, str):
            kubectl_cmd = kubectl_cmd.split()

        # add the template file at the end
        kubectl_cmd += [yaml_file]
        try:
            process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            output, error = process.communicate()
        except Exception as ex:
            return None, str(ex)
        finally:
            safe_remove_file(yaml_file)

        return output, error
def _kubectl_run(self, create_trains_conf, docker_image, labels, queue, task_data, task_id):
|
||||
if callable(self.kubectl_cmd):
|
||||
kubectl_cmd = self.kubectl_cmd(task_id, docker_image, queue, task_data)
|
||||
else:
|
||||
kubectl_cmd = self.kubectl_cmd.format(
|
||||
task_id=task_id,
|
||||
docker_image=docker_image,
|
||||
queue_id=queue
|
||||
)
|
||||
# make sure we provide a list
|
||||
if isinstance(kubectl_cmd, str):
|
||||
kubectl_cmd = kubectl_cmd.split()
|
||||
|
||||
if self.overrides_json_string:
|
||||
kubectl_cmd += ['--overrides=' + self.overrides_json_string]
|
||||
|
||||
if self.pod_limits:
|
||||
kubectl_cmd += ['--limits', ",".join(self.pod_limits)]
|
||||
if self.pod_requests:
|
||||
kubectl_cmd += ['--requests', ",".join(self.pod_requests)]
|
||||
|
||||
container_bash_script = [self.container_bash_script] if isinstance(self.container_bash_script, str) \
|
||||
else self.container_bash_script
|
||||
container_bash_script = ' ; '.join(container_bash_script)
|
||||
|
||||
kubectl_cmd += [
|
||||
"--labels=" + ",".join(labels),
|
||||
"--command",
|
||||
"--",
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"{} ; {}".format(create_trains_conf, container_bash_script.format(
|
||||
extra_bash_init_cmd=self.extra_bash_init_script, task_id=task_id)),
|
||||
]
|
||||
process = subprocess.Popen(kubectl_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
return output, error

    def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):
        """
        :summary: Pull and run tasks from queues.
        :description: 1. Go through ``queues`` in order.
            2. Try getting the next task for each queue, and run the first one that returns.
            3. Go back to step 1.
        :param queues: IDs of queues to pull tasks from
        :type queues: list of ``Text``
        :param worker_params: Worker command line arguments
        :type worker_params: ``trains_agent.helper.process.WorkerParams``
        """
        events_service = self.get_service(Events)

        # make sure we have a k8s pending queue
        # noinspection PyBroadException
        try:
            self._session.api_client.queues.create(self.k8s_pending_queue_name)
        except Exception:
            pass
        # get the queue id
        self.k8s_pending_queue_name = self._resolve_name(self.k8s_pending_queue_name, "queues")

        _last_machine_update_ts = 0
        while True:
            # iterate over queues (priority style, queues[0] is the highest)
            for queue in queues:
                # delete old completed / failed pods
                get_bash_output(self.KUBECTL_DELETE_CMD)

                # get the next task in the queue
                try:
                    response = self._session.api_client.queues.get_next_task(queue=queue)
                except Exception as e:
                    print("Warning: Could not access task queue [{}], error: {}".format(queue, e))
                    continue
                else:
                    try:
                        task_id = response.entry.task
                    except AttributeError:
                        print("No tasks in queue {}".format(queue))
                        continue
                    events_service.send_log_events(
                        self.worker_id,
                        task_id=task_id,
                        lines="task {} pulled from {} by worker {}".format(
                            task_id, queue, self.worker_id
                        ),
                        level="INFO",
                    )

                    self.report_monitor(ResourceMonitor.StatusReport(queues=queues, queue=queue, task=task_id))
                    self.run_one_task(queue, task_id, worker_params)
                    self.report_monitor(ResourceMonitor.StatusReport(queues=self.queues))
                    break
            else:
                # no task was pulled from any queue; sleep and retry polling
                print("No tasks in Queues, sleeping for {:.1f} seconds".format(self._polling_interval))
                sleep(self._polling_interval)

            if self._session.config["agent.reload_config"]:
                self.reload_config()

    def k8s_daemon(self, queue):
        """
        Start the k8s Glue service.
        This service pulls tasks from *queue* and schedules them for execution using kubectl.
        Note that all scheduled tasks are pushed back into the K8S_PENDING_QUEUE,
        and popped when execution actually starts. This creates full visibility into the k8s scheduler.
        Manually popping a task from the K8S_PENDING_QUEUE will cause the k8s scheduler to skip
        that task's execution once it is due to run.

        :param str queue: queue name to pull from
        """
        return self.daemon(queues=[ObjectID(name=queue)] if queue else None,
                           log_level=logging.INFO, foreground=True, docker=False)

    @classmethod
    def get_ssh_server_bash(cls, ssh_port_number):
        return ' ; '.join(line.format(port=ssh_port_number) for line in cls.BASH_INSTALL_SSH_CMD)

@@ -157,6 +157,10 @@ def is_windows_platform():
    return any(platform.win32_ver())


def is_linux_platform():
    return 'linux' in platform.system().lower()


def normalize_path(*paths):
    """
    normalize_path
@@ -169,13 +173,64 @@ def normalize_path(*paths):


def safe_remove_file(filename, error_message=None):
    # noinspection PyBroadException
    try:
-        os.remove(filename)
+        if filename:
+            os.remove(filename)
    except Exception:
        if error_message:
            print(error_message)


def safe_remove_tree(filename):
    if not filename:
        return
    # noinspection PyBroadException
    try:
        shutil.rmtree(filename, ignore_errors=True)
    except Exception:
        pass
    # noinspection PyBroadException
    try:
        os.remove(filename)
    except Exception:
        pass


def get_python_path(script_dir, entry_point, package_api, is_conda_env=False):
    # noinspection PyBroadException
    try:
        python_path_sep = ';' if is_windows_platform() else ':'
        python_path_cmd = package_api.get_python_command(
            ["-c", "import sys; print('{}'.join(sys.path))".format(python_path_sep)])
        org_python_path = python_path_cmd.get_output(cwd=script_dir)
        # Add path of the script directory and executable directory
        python_path = '{}{python_path_sep}{}{python_path_sep}'.format(
            Path(script_dir).absolute().as_posix(),
            (Path(script_dir) / Path(entry_point)).parent.absolute().as_posix(),
            python_path_sep=python_path_sep)
        if is_windows_platform():
            python_path = python_path.replace('/', '\\')

        return python_path if is_conda_env else (python_path + org_python_path)
    except Exception:
        return None


def add_python_path(base_path, extra_path):
    try:
        if not extra_path:
            return base_path
        python_path_sep = ';' if is_windows_platform() else ':'
        base_path = base_path or ''
        if not base_path.endswith(python_path_sep):
            base_path += python_path_sep
        base_path += extra_path.replace(':', python_path_sep)
    except:
        pass
    return base_path
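
# Illustrative usage (assumes a POSIX platform, so the ':' separator; not from the
# original sources): add_python_path appends extra_path to a PYTHONPATH-style
# string, inserting the separator only when the base does not already end with one.
assert add_python_path('/srv/app', '/srv/app/src') == '/srv/app:/srv/app/src'
assert add_python_path('/srv/app:', '/srv/app/src') == '/srv/app:/srv/app/src'
assert add_python_path(None, '/srv/app/src') == ':/srv/app/src'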


class Singleton(ABCMeta):
    _instances = {}

@@ -405,9 +460,9 @@ def chain_map(*args):
    return reduce(lambda x, y: x.update(y) or x, args, {})


-def check_directory_path(path):
+def check_directory_path(path, check_whitespace_in_path=True):
    message = 'Could not create directory "{}": {}'
-    if not is_windows_platform():
+    if not is_windows_platform() and check_whitespace_in_path:
        match = re.search(r'\s', path)
        if match:
            raise CommandFailedError(
@@ -440,6 +495,17 @@ def rm_tree(root):  # type: (Union[Path, Text]) -> None
    return shutil.rmtree(os.path.expanduser(os.path.expandvars(Text(root))), onerror=on_error)


def rm_file(filename):  # type: (Union[Path, Text]) -> None
    """
    A version of os.unlink that will not raise an error
    """
    try:
        os.unlink(os.path.expanduser(os.path.expandvars(Text(filename))))
    except:
        return False
    return True


def is_conda(config):
    return config['agent.package_manager.type'].lower() == 'conda'

@@ -489,6 +555,7 @@ class ExecutionInfo(NonStrictAttrs):
    branch = nullable_string
    version_num = nullable_string
    tag = nullable_string
    docker_cmd = nullable_string

    @classmethod
    def from_task(cls, task_info):
@@ -506,4 +573,24 @@ class ExecutionInfo(NonStrictAttrs):
        execution.entry_point = entry_point
        execution.working_dir = working_dir or ""

        # noinspection PyBroadException
        try:
            execution.docker_cmd = task_info.execution.docker_cmd
        except Exception:
            pass

        return execution


class safe_furl(furl.furl):

    @property
    def port(self):
        return self._port

    @port.setter
    def port(self, port):
        """
        Any port value is valid
        """
        self._port = port
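
# Illustrative sketch (furl's documented port validation is assumed; not from the
# original sources): plain furl.furl rejects non-numeric or out-of-range ports,
# while safe_furl stores any value untouched.
#     u = safe_furl('http://localhost')
#     u.port = '{port}'    # a template placeholder; furl.furl would raise ValueError
#     assert u.port == '{port}'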

@@ -4,7 +4,7 @@ from time import sleep
import requests
import json
from threading import Thread
-from semantic_version import Version
+from .package.requirements import SimpleVersion
from ..version import __version__

__check_update_thread = None
@@ -30,11 +30,11 @@ def _check_new_version_available():
        return None
    trains_answer = update_server_releases.get("trains-agent", {})
    latest_version = trains_answer.get("version")
-    cur_version = Version(cur_version)
-    latest_version = Version(latest_version)
-    if cur_version >= latest_version:
+    cur_version = cur_version
+    latest_version = latest_version or ''
+    if SimpleVersion.compare_versions(cur_version, '>=', latest_version):
        return None
-    patch_upgrade = latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
+    patch_upgrade = True  # latest_version.major == cur_version.major and latest_version.minor == cur_version.minor
    return str(latest_version), patch_upgrade, trains_answer.get("description").split("\r\n")
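
# Illustrative sketch, inferred only from the call sites above (the exact
# SimpleVersion API is an assumption): compare_versions(current, op, required)
# replaces the semantic_version.Version comparison and tolerates an empty string.
#     SimpleVersion.compare_versions('0.16.1', '>=', '0.16.0')   # True
#     SimpleVersion.compare_versions('4.3.30', '<', '4.7.12')    # True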

@@ -9,7 +9,7 @@ from attr import attrs, attrib

import six
from six import binary_type, text_type
-from trains_agent.helper.base import nonstrict_in_place_sort, create_tree
+from trains_agent.helper.base import nonstrict_in_place_sort


def print_text(text, newline=True):
@@ -22,6 +22,24 @@ def print_text(text, newline=True):
    sys.stdout.write(data)


def decode_binary_lines(binary_lines, encoding='utf-8', replace_cr=False, overwrite_cr=False):
    # decode per line; if decoding a line fails, skip it
    lines = []
    for b in binary_lines:
        # noinspection PyBroadException
        try:
            line = b.decode(encoding=encoding, errors='replace')
            if replace_cr:
                line = line.replace('\r', '\n')
            elif overwrite_cr:
                cr_lines = line.split('\r')
                line = cr_lines[-1] if cr_lines[-1] or len(cr_lines) < 2 else cr_lines[-2]
        except Exception:
            line = ''
        lines.append(line + '\n' if not line or line[-1] != '\n' else line)
    return lines
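
# Illustrative usage (not from the original sources): overwrite_cr=True keeps only
# the final rewrite of a '\r'-overwritten progress line, while replace_cr=True
# expands every '\r' into a real newline.
assert decode_binary_lines([b'epoch 1\r50%\r100%'], overwrite_cr=True) == ['100%\n']
assert decode_binary_lines([b'epoch 1\r50%'], replace_cr=True) == ['epoch 1\n50%\n']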


def ensure_text(s, encoding='utf-8', errors='strict'):
    """Coerce *s* to six.text_type.
    For Python 2:

@@ -3,3 +3,15 @@ from typing import Callable, Dict, Any

def filter_keys(filter_, dct):  # type: (Callable[[Any], bool], Dict) -> Dict
    return {key: value for key, value in dct.items() if filter_(key)}


def merge_dicts(dict1, dict2):
    """ Recursively merges dict2 into dict1 """
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        return dict2
    for k in dict2:
        if k in dict1:
            dict1[k] = merge_dicts(dict1[k], dict2[k])
        else:
            dict1[k] = dict2[k]
    return dict1
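
# Illustrative usage (not from the original sources): merge_dicts folds dict2 into
# dict1 in place; on a key clash where either side is not a dict, dict2 wins. This
# is what lets the k8s glue overlay its container fields onto a user-provided pod
# template.
base = {'spec': {'containers': [], 'restartPolicy': 'Never'}}
override = {'spec': {'restartPolicy': 'OnFailure'}, 'metadata': {'labels': {'app': 'demo'}}}
merged = merge_dicts(base, override)
assert merged is base
assert merged['spec'] == {'containers': [], 'restartPolicy': 'OnFailure'}
assert merged['metadata'] == {'labels': {'app': 'demo'}}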

@@ -200,24 +200,30 @@ class GPUStatCollection(object):
                GPUStatCollection.global_processes[nv_process.pid] = \
                    psutil.Process(pid=nv_process.pid)
            ps_process = GPUStatCollection.global_processes[nv_process.pid]
-            process['username'] = ps_process.username()
-            # cmdline returns full path;
-            # as in `ps -o comm`, get short cmdnames.
-            _cmdline = ps_process.cmdline()
-            if not _cmdline:
-                # sometimes, zombie or unknown (e.g. [kworker/8:2H])
-                process['command'] = '?'
-                process['full_command'] = ['?']
-            else:
-                process['command'] = os.path.basename(_cmdline[0])
-                process['full_command'] = _cmdline
-            # Bytes to MBytes
-            process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
-            process['cpu_percent'] = ps_process.cpu_percent()
-            process['cpu_memory_usage'] = \
-                round((ps_process.memory_percent() / 100.0) *
-                      psutil.virtual_memory().total)
            process['pid'] = nv_process.pid
+            # noinspection PyBroadException
+            try:
+                # we do not actually use these, so no point in collecting them
+                # process['username'] = ps_process.username()
+                # # cmdline returns full path;
+                # # as in `ps -o comm`, get short cmdnames.
+                # _cmdline = ps_process.cmdline()
+                # if not _cmdline:
+                #     # sometimes, zombie or unknown (e.g. [kworker/8:2H])
+                #     process['command'] = '?'
+                #     process['full_command'] = ['?']
+                # else:
+                #     process['command'] = os.path.basename(_cmdline[0])
+                #     process['full_command'] = _cmdline
+                # process['cpu_percent'] = ps_process.cpu_percent()
+                # process['cpu_memory_usage'] = \
+                #     round((ps_process.memory_percent() / 100.0) *
+                #           psutil.virtual_memory().total)
+                # Bytes to MBytes
+                process['gpu_memory_usage'] = nv_process.usedGpuMemory // MB
+            except Exception:
+                # insufficient permissions
+                pass
            return process

        if not GPUStatCollection._gpu_device_info.get(index):
@@ -285,12 +291,13 @@ class GPUStatCollection(object):
                # e.g. nvidia-smi reset or reboot the system
                pass

-        # TODO: Do not block if full process info is not requested
-        time.sleep(0.1)
-        for process in processes:
-            pid = process['pid']
-            cache_process = GPUStatCollection.global_processes[pid]
-            process['cpu_percent'] = cache_process.cpu_percent()
+        # we do not actually use these, so no point in collecting them
+        # # TODO: Do not block if full process info is not requested
+        # time.sleep(0.1)
+        # for process in processes:
+        #     pid = process['pid']
+        #     cache_process = GPUStatCollection.global_processes[pid]
+        #     process['cpu_percent'] = cache_process.cpu_percent()

        index = N.nvmlDeviceGetIndex(handle)
        gpu_info = {

trains_agent/helper/os/__init__.py    (new file, empty)
trains_agent/helper/os/daemonize.py   (new file, 74 lines)
@@ -0,0 +1,74 @@
import os


def daemonize_process(redirect_fd=None):
    """
    Detach a process from the controlling terminal and run it in the background as a daemon.
    """
    assert redirect_fd is None or isinstance(redirect_fd, int)

    # re-spawn in the same directory
    WORKDIR = os.getcwd()

    # The standard I/O file descriptors are redirected to /dev/null by default.
    if hasattr(os, "devnull"):
        devnull = os.devnull
    else:
        devnull = "/dev/null"

    try:
        # Fork a child process so the parent can exit. This returns control to
        # the command-line or shell. It also guarantees that the child will not
        # be a process group leader, since the child receives a new process ID
        # and inherits the parent's process group ID. This step is required
        # to ensure that the next call to os.setsid is successful.
        pid = os.fork()
    except OSError as e:
        raise Exception("%s [%d]" % (e.strerror, e.errno))

    if pid == 0:  # The first child.
        # To become the session leader of this new session and the process group
        # leader of the new process group, we call os.setsid().
        # The process is also guaranteed not to have a controlling terminal.
        os.setsid()

        # Is ignoring SIGHUP necessary? (Set handlers for asynchronous events.)
        # import signal
        # signal.signal(signal.SIGHUP, signal.SIG_IGN)

        try:
            # Fork a second child and exit immediately to prevent zombies. This
            # causes the second child process to be orphaned, making the init
            # process responsible for its cleanup.
            pid = os.fork()  # Fork a second child.
        except OSError as e:
            raise Exception("%s [%d]" % (e.strerror, e.errno))

        if pid == 0:  # The second child.
            # A conventional daemon would chdir to the root directory here; instead
            # we return to the original working directory so the agent keeps its context.
            os.chdir(WORKDIR)
            # We probably don't want the file mode creation mask inherited from
            # the parent, so we give the child complete control over permissions.
            os.umask(0)
        else:
            # Exit parent (the first child) of the second child.
            os._exit(0)
    else:
        # Exit parent of the first child.
        os._exit(0)

    # notice we count on the fact that we keep all file descriptors open,
    # since we opened them in the parent process, and the daemon process will use them

    # Redirect the standard I/O file descriptors to the specified file or /dev/null.
    if redirect_fd is None:
        redirect_fd = os.open(devnull, os.O_RDWR)

    # Duplicate the redirect file descriptor to standard output (1) and standard error (2).
    os.dup2(redirect_fd, 1)
    os.dup2(redirect_fd, 2)

    return 0
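
# Illustrative usage (hypothetical caller, not from the original sources): after
# the double fork only the grandchild returns, detached from the terminal, with
# fd 1 and fd 2 pointing at redirect_fd (or /dev/null when it is None).
#     log_fd = os.open('/tmp/agent.log', os.O_WRONLY | os.O_CREAT | os.O_APPEND)
#     daemonize_process(redirect_fd=log_fd)
#     run_agent_loop()   # now a session leader with no controlling terminal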

@@ -5,7 +5,7 @@ from contextlib import contextmanager
from typing import Text, Iterable, Union

import six
-from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines
+from trains_agent.helper.base import mkstemp, safe_remove_file, join_lines, select_for_platform
from trains_agent.helper.process import Executable, Argv, PathLike


@@ -16,6 +16,8 @@ class PackageManager(object):
    """

    _selected_manager = None
+    _cwd = None
+    _pip_version = None

    @abc.abstractproperty
    def bin(self):
@@ -64,7 +66,20 @@ class PackageManager(object):
        pass

    def upgrade_pip(self):
-        return self._install("pip", "--upgrade")
+        result = self._install(
+            select_for_platform(windows='"pip{}"', linux='pip{}').format(self.get_pip_version()), "--upgrade")
+        packages = self.run_with_env(('list',), output=True).splitlines()
+        # each p.split() is ['pip', 'x.y.z']
+        pip = [p.split() for p in packages if len(p.split()) == 2 and p.split()[0] == 'pip']
+        if pip:
+            # noinspection PyBroadException
+            try:
+                from .requirements import MarkerRequirement
+                pip = pip[0][1].split('.')
+                MarkerRequirement.pip_new_version = bool(int(pip[0]) >= 20)
+            except Exception:
+                pass
+        return result

    def get_python_command(self, extra=()):
        # type: (...) -> Executable
@@ -97,11 +112,44 @@ class PackageManager(object):
        # this is helpful when we want out of context requirement installations
        PackageManager._selected_manager = self

+    @property
+    def cwd(self):
+        return self._cwd
+
+    @cwd.setter
+    def cwd(self, value):
+        self._cwd = value
+
    @classmethod
-    def out_of_scope_install_package(cls, package_name):
+    def out_of_scope_install_package(cls, package_name, *args):
        if PackageManager._selected_manager is not None:
            try:
-                return PackageManager._selected_manager._install(package_name)
+                result = PackageManager._selected_manager._install(package_name, *args)
+                if result not in (0, None, True):
+                    return False
            except Exception:
                return False
        return True

    @classmethod
    def out_of_scope_freeze(cls):
        if PackageManager._selected_manager is not None:
            try:
                return PackageManager._selected_manager.freeze()
            except Exception:
                pass
-        return
+        return []

+    @classmethod
+    def set_pip_version(cls, version):
+        if not version:
+            return
+        version = version.replace(' ', '')
+        if ('=' in version) or ('~' in version) or ('<' in version) or ('>' in version):
+            cls._pip_version = version
+        else:
+            cls._pip_version = "==" + version
+
+    @classmethod
+    def get_pip_version(cls):
+        return cls._pip_version or ''
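
# Illustrative usage (not from the original sources): a bare version is stored as
# an '==' pin, while an explicit specifier is kept as-is (spaces stripped);
# upgrade_pip() then expands it into the 'pip{spec}' install argument.
#     PackageManager.set_pip_version('20.1')    # stored as '==20.1' -> installs "pip==20.1"
#     PackageManager.set_pip_version('<20.2 ')  # stored as '<20.2'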

@@ -2,25 +2,31 @@ from __future__ import unicode_literals

import json
import re
import os
import subprocess
from collections import OrderedDict
from distutils.spawn import find_executable
from functools import partial
from itertools import chain
from typing import Text, Iterable, Union, Dict, Set, Sequence, Any

import six
import yaml
from time import time
from attr import attrs, attrib, Factory
from pathlib2 import Path
from semantic_version import Version
-from requirements import parse
+from trains_agent.external.requirements_parser import parse
+from trains_agent.external.requirements_parser.requirement import Requirement

from trains_agent.errors import CommandFailedError
-from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform
+from trains_agent.helper.base import rm_tree, NonStrictAttrs, select_for_platform, is_windows_platform, ExecutionInfo
from trains_agent.helper.process import Argv, Executable, DEVNULL, CommandSequence, PathLike
+from trains_agent.helper.package.requirements import SimpleVersion
from trains_agent.session import Session
from .base import PackageManager
from .pip_api.venv import VirtualenvPip
from .requirements import RequirementsManager, MarkerRequirement
+from ...backend_api.session.defs import ENV_CONDA_ENV_PACKAGE

package_normalize = partial(re.compile(r"""\[version=['"](.*)['"]\]""").sub, r"\1")

@@ -36,7 +42,8 @@ def _package_diff(path, packages):

class CondaPip(VirtualenvPip):
    def __init__(self, source=None, *args, **kwargs):
-        super(CondaPip, self).__init__(*args, **kwargs)
+        super(CondaPip, self).__init__(*args, interpreter=Path(kwargs.get('path'), "python.exe")
+                                       if is_windows_platform() and kwargs.get('path') else None, **kwargs)
        self.source = source

    def run_with_env(self, command, output=False, **kwargs):
@@ -54,10 +61,10 @@ class CondaAPI(PackageManager):
    A programmatic interface for controlling conda
    """

-    MINIMUM_VERSION = Version("4.3.30", partial=True)
+    MINIMUM_VERSION = "4.3.30"

-    def __init__(self, session, path, python, requirements_manager):
-        # type: (Session, PathLike, float, RequirementsManager) -> None
+    def __init__(self, session, path, python, requirements_manager, execution_info=None, **kwargs):
+        # type: (Session, PathLike, float, RequirementsManager, ExecutionInfo, Any) -> None
        """
        :param python: base python version to use (e.g python3.6)
        :param path: path of env
@@ -67,7 +74,15 @@ class CondaAPI(PackageManager):
        self.source = None
        self.requirements_manager = requirements_manager
        self.path = path
+        self.env_read_only = False
        self.extra_channels = self.session.config.get('agent.package_manager.conda_channels', [])
+        self.conda_env_as_base_docker = \
+            self.session.config.get('agent.package_manager.conda_env_as_base_docker', None) or \
+            bool(ENV_CONDA_ENV_PACKAGE.get())
+        if ENV_CONDA_ENV_PACKAGE.get():
+            self.conda_pre_build_env_path = ENV_CONDA_ENV_PACKAGE.get()
+        else:
+            self.conda_pre_build_env_path = execution_info.docker_cmd if execution_info else None
        self.pip = CondaPip(
            session=self.session,
            source=self.source,
@@ -75,12 +90,17 @@ class CondaAPI(PackageManager):
            requirements_manager=self.requirements_manager,
            path=self.path,
        )
-        self.conda = (
-            find_executable("conda")
-            or Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(shell=True).strip()
-        )
-        try:
-            output = Argv(self.conda, "--version").get_output()
+        try:
+            self.conda = (
+                find_executable("conda") or
+                Argv(select_for_platform(windows="where", linux="which"), "conda").get_output(
+                    shell=select_for_platform(windows=True, linux=False)).strip()
+            )
+        except Exception:
+            raise ValueError("ERROR: package manager \"conda\" selected, "
+                             "but \'conda\' executable could not be located")
+        try:
+            output = Argv(self.conda, "--version").get_output(stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as ex:
            raise CommandFailedError(
                "Unable to determine conda version: {ex}, output={ex.output}".format(
@@ -88,7 +108,7 @@ class CondaAPI(PackageManager):
                )
            )
        self.conda_version = self.get_conda_version(output)
-        if Version(self.conda_version, partial=True) < self.MINIMUM_VERSION:
+        if SimpleVersion.compare_versions(self.conda_version, '<', self.MINIMUM_VERSION):
            raise CommandFailedError(
                "conda version '{}' is smaller than minimum supported conda version '{}'".format(
                    self.conda_version, self.MINIMUM_VERSION
@@ -106,13 +126,58 @@ class CondaAPI(PackageManager):
    def bin(self):
        return self.pip.bin

+    # noinspection SpellCheckingInspection
    def upgrade_pip(self):
-        return self.pip.upgrade_pip()
+        # do not change the pip version if a pre-built environment is used
+        if self.env_read_only:
+            print('Conda environment in read-only mode, skipping pip upgrade.')
+            return ''
+        return self._install(select_for_platform(windows='"pip{}"', linux='pip{}').format(self.pip.get_pip_version()))

    def create(self):
        """
        Create a new environment
        """
+        if self.conda_env_as_base_docker and self.conda_pre_build_env_path:
+            if Path(self.conda_pre_build_env_path).is_dir():
+                print("Using pre-existing Conda environment from {}".format(self.conda_pre_build_env_path))
+                self.path = Path(self.conda_pre_build_env_path)
+                self.source = ("conda", "activate", self.path.as_posix())
+                self.pip = CondaPip(
+                    session=self.session,
+                    source=self.source,
+                    python=self.python,
+                    requirements_manager=self.requirements_manager,
+                    path=self.path,
+                )
+                conda_env = self._get_conda_sh()
+                self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
+                self.env_read_only = True
+                return self
+            elif Path(self.conda_pre_build_env_path).is_file():
+                print("Restoring Conda environment from {}".format(self.conda_pre_build_env_path))
+                tar_path = find_executable("tar")
+                self.path.mkdir(parents=True, exist_ok=True)
+                output = Argv(
+                    tar_path,
+                    "-xzf",
+                    self.conda_pre_build_env_path,
+                    "-C",
+                    self.path,
+                ).get_output()
+
+                self.source = self.pip.source = ("conda", "activate", self.path.as_posix())
+                conda_env = self._get_conda_sh()
+                self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)
+                # unpack cleanup
+                print("Fixing prefix in Conda environment {}".format(self.path))
+                CommandSequence(('source', conda_env.as_posix()),
+                                ((self.path / 'bin' / 'conda-unpack').as_posix(), )).get_output()
+                return self
+            else:
+                raise ValueError("Could not restore Conda environment, cannot find {}".format(
+                    self.conda_pre_build_env_path))
+
        output = Argv(
            self.conda,
            "create",
@@ -128,13 +193,15 @@ class CondaAPI(PackageManager):
        self.source = self.pip.source = (
            tuple(match.group(1).split()) + (match.group(2),)
            if match
-            else ("activate", self.path)
+            else ("conda", "activate", self.path.as_posix())
        )
-        conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
-        if conda_env.is_file():
+
+        conda_env = self._get_conda_sh()
+        if conda_env.is_file() and not is_windows_platform():
            self.source = self.pip.source = CommandSequence(('source', conda_env.as_posix()), self.source)

        # install cuda toolkit
        # noinspection PyBroadException
        try:
            cuda_version = float(int(self.session.config['agent.cuda_version'])) / 10.0
            if cuda_version > 0:
@@ -161,6 +228,12 @@ class CondaAPI(PackageManager):
        except Exception:
            pass
        rm_tree(self.path)
+        # if we failed removing the path, change its name
+        if is_windows_platform() and Path(self.path).exists():
+            try:
+                Path(self.path).rename(Path(self.path).as_posix() + '_' + str(time()))
+            except Exception:
+                pass

    def _install_from_file(self, path):
        """
@@ -170,6 +243,10 @@ class CondaAPI(PackageManager):

    def _install(self, *args):
        # type: (*PathLike) -> ()
+        # if we are in read-only mode, do not install anything
+        if self.env_read_only:
+            print('Conda environment in read-only mode, skipping package installation: {}'.format(args))
+            return
        channels_args = tuple(
            chain.from_iterable(("-c", channel) for channel in self.extra_channels)
        )
@@ -197,6 +274,10 @@ class CondaAPI(PackageManager):
        return self._install(*packages)

    def uninstall_packages(self, *packages):
+        # if we are in read-only mode, do not uninstall anything
+        if self.env_read_only:
+            print('Conda environment in read-only mode, skipping package uninstall: {}'.format(packages))
+            return ''
        return self._run_command(("uninstall", "-p", self.path))

    def install_from_file(self, path):
@@ -215,31 +296,281 @@ class CondaAPI(PackageManager):
        with self.temp_file("pip_reqs", pip_packages) as reqs:
            self.pip.install_from_file(reqs)

-    def freeze(self):
-        # result = yaml.load(
-        #     self._run_command((self.conda, "env", "export", "-p", self.path), raw=True)
-        # )
-        # for key in "name", "prefix":
-        #     result.pop(key, None)
-        # freeze = {"conda": result}
-        # try:
-        #     freeze["pip"] = result["dependencies"][-1]["pip"]
-        # except (TypeError, KeyError):
-        #     freeze["pip"] = []
-        # else:
-        #     del result["dependencies"][-1]
-        # return freeze
-        return self.pip.freeze()
+    def freeze(self, freeze_full_environment=False):
+        requirements = self.pip.freeze()
+        req_lines = []
+        conda_lines = []
+
+        # noinspection PyBroadException
+        try:
+            pip_lines = requirements['pip']
+            conda_packages_json = json.loads(
+                self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
+            for r in conda_packages_json:
+                # check if this is a pypi package; if it is, leave it outside
+                if not r.get('channel') or r.get('channel') == 'pypi':
+                    name = (r['name'].replace('-', '_'), r['name'])
+                    pip_req_line = [l for l in pip_lines
+                                    if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
+                    if pip_req_line and \
+                            ('@' not in pip_req_line[0] or
+                             not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
+                        req_lines.append(pip_req_line[0])
+                        continue
+
+                    req_lines.append(
+                        '{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
+                    continue
+
+                # check if we have it in our required packages
+                name = r['name']
+                # hack: support the different pytorch/torch naming conventions
+                if name == 'pytorch':
+                    name = 'torch'
+                # skip over packages with a leading _
+                if name.startswith('_'):
+                    continue
+                conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
+            # make sure we see the conda packages, put them into the pip list as well
+            if conda_lines:
+                req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
+
+            requirements['pip'] = req_lines
+            requirements['conda'] = conda_lines
+        except Exception:
+            pass
+
+        if freeze_full_environment:
+            # noinspection PyBroadException
+            try:
+                conda_env_json = json.loads(
+                    self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
+                conda_env_json.pop('name', None)
+                conda_env_json.pop('prefix', None)
+                conda_env_json.pop('channels', None)
+                requirements['conda_env_json'] = json.dumps(conda_env_json)
+            except Exception:
+                pass
+
+        return requirements
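
+        # Shape of the dict the new freeze() returns (illustrative values, not from
+        # the original sources; the banner lines appear only when conda-managed
+        # packages exist):
+        #     {
+        #         'pip': ['# Conda Packages', '', 'numpy==1.18.1', '',
+        #                 '# pip Packages', '', 'furl==2.1.0'],
+        #         'conda': ['numpy==1.18.1'],
+        #         'conda_env_json': '{...}',   # only when freeze_full_environment=True
+        #     }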

+    def _load_conda_full_env(self, conda_env_dict, requirements):
+        # noinspection PyBroadException
+        try:
+            cuda_version = int(self.session.config.get('agent.cuda_version', 0))
+        except Exception:
+            cuda_version = 0
+
+        conda_env_dict['channels'] = self.extra_channels
+        if 'dependencies' not in conda_env_dict:
+            conda_env_dict['dependencies'] = []
+        new_dependencies = OrderedDict()
+        pip_requirements = None
+        for line in conda_env_dict['dependencies']:
+            if isinstance(line, dict):
+                pip_requirements = line.pop('pip', None)
+                continue
+            name = line.strip().split('=', 1)[0].lower()
+            if name == 'pip':
+                continue
+            elif name == 'python':
+                line = 'python={}'.format('.'.join(line.split('=')[1].split('.')[:2]))
+            elif name == 'tensorflow-gpu' and cuda_version == 0:
+                line = 'tensorflow={}'.format(line.split('=')[1])
+            elif name == 'tensorflow' and cuda_version > 0:
+                line = 'tensorflow-gpu={}'.format(line.split('=')[1])
+            elif name in ('cupti', 'cudnn'):
+                # cudatoolkit should pull them based on the cudatoolkit version
+                continue
+            elif name.startswith('_'):
+                continue
+            new_dependencies[line.split('=', 1)[0].strip()] = line
+
+        # fix packages:
+        conda_env_dict['dependencies'] = list(new_dependencies.values())
+
+        with self.temp_file("conda_env", yaml.dump(conda_env_dict), suffix=".yml") as name:
+            print('Conda: Trying to install requirements:\n{}'.format(conda_env_dict['dependencies']))
+            result = self._run_command(
+                ("env", "update", "-p", self.path, "--file", name)
+            )
+
+        # check if we need to remove specific packages
+        bad_req = self._parse_conda_result_bad_packges(result)
+        if bad_req:
+            print('failed installing the following conda packages: {}'.format(bad_req))
+            return False
+
+        if pip_requirements:
+            # create a list of vcs packages that we need to replace in the pip section
+            vcs_reqs = {}
+            if 'pip' in requirements:
+                pip_lines = requirements['pip'].splitlines() \
+                    if isinstance(requirements['pip'], six.string_types) else requirements['pip']
+                for line in pip_lines:
+                    try:
+                        marker = list(parse(line))
+                    except Exception:
+                        marker = None
+                    if not marker:
+                        continue
+
+                    m = MarkerRequirement(marker[0])
+                    if m.vcs:
+                        vcs_reqs[m.name] = m
+            try:
+                pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
+                               if not r.startswith('pip=') and not r.startswith('virtualenv=')]
+                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
+                PackageManager._selected_manager = self.pip
+                self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
+            except Exception as e:
+                print(e)
+                raise e
+            finally:
+                PackageManager._selected_manager = self
+
+        self.requirements_manager.post_install(self.session)

    def load_requirements(self, requirements):
+        # if we are in read-only mode, do not install anything
+        if self.env_read_only:
+            print('Conda environment in read-only mode, skipping requirements installation.')
+            return None
+
+        # if we have a full conda environment, use it and pass the pip requirements to pip
+        if requirements.get('conda_env_json'):
+            # noinspection PyBroadException
+            try:
+                conda_env_json = json.loads(requirements.get('conda_env_json'))
+                print('Conda restoring full yaml environment')
+                return self._load_conda_full_env(conda_env_json, requirements)
+            except Exception:
+                print('Could not load fully stored conda environment, falling back to requirements')
+
        # create a new environment file
        conda_env = dict()
        conda_env['channels'] = self.extra_channels
-        reqs = [MarkerRequirement(next(parse(r))) for r in requirements['pip']]
+        reqs = []
+        if isinstance(requirements['pip'], six.string_types):
+            requirements['pip'] = requirements['pip'].split('\n')
+        if isinstance(requirements.get('conda'), six.string_types):
+            requirements['conda'] = requirements['conda'].split('\n')
        has_torch = False
        has_matplotlib = False
        try:
            cuda_version = int(self.session.config.get('agent.cuda_version', 0))
        except:
            cuda_version = 0

+        # notice: a 'conda' entry holding an empty string is a valid conda requirements list, meaning pip only;
+        # this can happen if the experiment was executed on a non-conda machine or with an old trains client
+        conda_supported_req = requirements['pip'] if requirements.get('conda', None) is None else requirements['conda']
+        conda_supported_req_names = []
+        pip_requirements = []
+        for r in conda_supported_req:
            try:
                marker = list(parse(r))
            except:
                marker = None
            if not marker:
                continue

            m = MarkerRequirement(marker[0])
+            # conda does not support version control links
+            if m.vcs:
+                pip_requirements.append(m)
+                continue
            # Skip over pip
            if m.name in ('pip', 'virtualenv', ):
                continue
            # python version, only major.minor
            if m.name == 'python' and m.specs:
                m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
                if '.' not in m.specs[0][1]:
                    continue

+            conda_supported_req_names.append(m.name.lower())
            if m.req.name.lower() == 'matplotlib':
                has_matplotlib = True
            elif m.req.name.lower().startswith('torch'):
                has_torch = True

            if m.req.name.lower() in ('torch', 'pytorch'):
                has_torch = True
                m.req.name = 'pytorch'

            if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'):
                has_torch = True
                m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'

            reqs.append(m)

+        # if we have a conda list, the rest should be installed with pip
+        if requirements.get('conda', None) is not None:
+            for r in requirements['pip']:
+                try:
+                    marker = list(parse(r))
+                except:
+                    marker = None
+                if not marker:
+                    continue
+
+                m = MarkerRequirement(marker[0])
+                # skip over local files (we cannot change the version to a local file)
+                if m.local_file:
+                    continue
+                m_name = m.name.lower()
+                if m_name in conda_supported_req_names:
+                    # this package is in the conda list;
+                    # make sure that if we changed the version, we also match it in conda
+                    ## conda_supported_req_names.remove(m_name)
+                    for cr in reqs:
+                        if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
+                            # match versions
+                            cr.specs = m.specs
+                            # # conda always likes "-" not "_" but only on pypi packages
+                            # cr.name = cr.name.lower().replace('_', '-')
+                            break
+                else:
+                    # not in conda, so it is a pip package
+                    pip_requirements.append(m)
+                    if m_name == 'matplotlib':
+                        has_matplotlib = True

        # Conda requirements hacks:
        if has_matplotlib:
            reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
            reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
            reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))

+        # remove a specific cudatoolkit, it should have been preinstalled;
+        # allow overriding the default cudatoolkit, but not the derivative packages - cudatoolkit should pull them
+        reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]
+
+        if has_torch and cuda_version == 0:
+            reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
+
+        # make sure we have no double entries
+        reqs = list(OrderedDict((r.name, r) for r in reqs).values())
+
+        # conform conda packages (version/name)
+        for r in reqs:
+            # change _ to - in the name, but keep a leading _ (as this is a conda prefix)
+            if not r.name.startswith('_') and not requirements.get('conda', None):
+                r.name = r.name.replace('_', '-')
+            # remove .post from version numbers (it fails with ~= versions), and change == to ~=
+            if r.specs and r.specs[0]:
+                r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]

        while reqs:
-            conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs]
+            # notice: we give conda more freedom in version selection, to help it choose the best combination
+            def clean_ver(ar):
+                if not ar.specs:
+                    return ar.tostr()
+                ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
+                return ar.tostr()
+            conda_env['dependencies'] = [clean_ver(r) for r in reqs]
            with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
                print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
                result = self._run_command(
@@ -252,7 +583,7 @@ class CondaAPI(PackageManager):

        solved = False
        for bad_r in bad_req:
-            name = bad_r.split('[')[0].split('=')[0]
+            name = bad_r.split('[')[0].split('=')[0].split('~')[0].split('<')[0].split('>')[0]
            # look for the name in the requirements
            for r in reqs:
                if r.name.lower() == name.lower():
@@ -269,14 +600,17 @@ class CondaAPI(PackageManager):

        if pip_requirements:
            try:
-                pip_req_str = [r.tostr() for r in pip_requirements]
+                pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
                print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
-                self.pip.load_requirements('\n'.join(pip_req_str))
+                PackageManager._selected_manager = self.pip
+                self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
            except Exception as e:
                print(e)
                raise e
+            finally:
+                PackageManager._selected_manager = self

-        self.requirements_manager.post_install()
+        self.requirements_manager.post_install(self.session)
        return True

    def _parse_conda_result_bad_packges(self, result_dict):
@@ -293,7 +627,7 @@ class CondaAPI(PackageManager):
        if len(empty_lines) >= 2:
            deps = error_lines[empty_lines[0]+1:empty_lines[1]]
        try:
-            return yaml.load('\n'.join(deps))
+            return yaml.load('\n'.join(deps), Loader=yaml.SafeLoader)
        except:
            return None
        return None
@@ -308,19 +642,28 @@ class CondaAPI(PackageManager):
        :param kwargs: kwargs for Argv.get_output()
        :return: JSON output or text output
        """
+        def escape_ansi(line):
+            ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
+            return ansi_escape.sub('', line)
+
        command = Argv(*command)  # type: Executable
        if not raw:
            command = (self.conda,) + command + ("--quiet", "--json")
        try:
+            print('Executing Conda: {}'.format(command.serialize()))
            result = command.get_output(stdin=DEVNULL, **kwargs)
+            if self.session.debug_mode:
+                print(result)
        except Exception as e:
            result = e.output if hasattr(e, 'output') else ''
+            if self.session.debug_mode:
+                print(result)
            if raw:
                raise
+        if raw:
+            return result
-        result = json.loads(result) if result else {}
+        result = json.loads(escape_ansi(result)) if result else {}
        if result.get('success', False):
            print('Pass')
        elif result.get('error'):
@@ -330,8 +673,22 @@ class CondaAPI(PackageManager):
    def get_python_command(self, extra=()):
        return CommandSequence(self.source, self.pip.get_python_command(extra=extra))

+    def _get_conda_sh(self):
+        # type: () -> Path
+        base_conda_env = Path(self.conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
+        if base_conda_env.is_file():
+            return base_conda_env
+        for path in os.environ.get('PATH', '').split(select_for_platform(windows=';', linux=':')):
+            conda = find_executable("conda", path=path)
+            if not conda:
+                continue
+            conda_env = Path(conda).parent.parent / 'etc' / 'profile.d' / 'conda.sh'
+            if conda_env.is_file():
+                return conda_env
+        return base_conda_env


-# enable hashing with cmp=False because pdb fails on unhashable exceptions
+# enable hashing with cmp=False because pdb fails on un-hashable exceptions
exception = attrs(str=True, cmp=False)


@@ -362,4 +719,4 @@ class PackageNotFoundError(CondaException):
    as a singleton YAML list.
    """

-    pkg = attrib(default="", converter=lambda val: yaml.load(val)[0].replace(" ", ""))
+    pkg = attrib(default="", converter=lambda val: yaml.load(val, Loader=yaml.SafeLoader)[0].replace(" ", ""))

@@ -1,25 +0,0 @@ (removed file)
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution


class CythonRequirement(SimpleSubstitution):

    name = "cython"

    def __init__(self, *args, **kwargs):
        super(CythonRequirement, self).__init__(*args, **kwargs)

    def match(self, req):
        # match both Cython & cython
        return self.name == req.name.lower()

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # install Cython before
        PackageManager.out_of_scope_install_package(str(req))
        return Text(req)
trains_agent/helper/package/external_req.py   (new file, 106 lines)
@@ -0,0 +1,106 @@
import re
from collections import OrderedDict
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution
from ..base import safe_furl as furl


class ExternalRequirements(SimpleSubstitution):

    name = "external_link"

    def __init__(self, *args, **kwargs):
        super(ExternalRequirements, self).__init__(*args, **kwargs)
        self.post_install_req = []
        self.post_install_req_lookup = OrderedDict()

    def match(self, req):
        # match editable, VCS/code, or unparsed requirements
        if not (not req.name or req.req and (req.req.editable or req.req.vcs)):
            return False
        if not req.req or not req.req.line or not req.req.line.strip() or req.req.line.strip().startswith('#'):
            return False
        if req.pip_new_version and not (req.req.editable or req.req.vcs):
            return False
        return True

    def post_install(self, session):
        post_install_req = self.post_install_req
        self.post_install_req = []
        for req in post_install_req:
            try:
                freeze_base = PackageManager.out_of_scope_freeze() or ''
            except:
                freeze_base = ''

            req_line = req.tostr(markers=False)
            if req_line.strip().startswith('-e ') or req_line.strip().startswith('--editable'):
                req_line = re.sub(r'^(-e|--editable=?)\s*', '', req_line, count=1)

            if req.req.vcs and req_line.startswith('git+'):
                try:
                    url_no_frag = furl(req_line)
                    url_no_frag.set(fragment=None)
                    # reverse replace (strip the fragment from the end of the line)
                    fragment = req_line[::-1].replace(url_no_frag.url[::-1], '', 1)[::-1]
                    vcs_url = req_line[4:]
                    # reverse replace
                    vcs_url = vcs_url[::-1].replace(fragment[::-1], '', 1)[::-1]
                    from ..repo import Git
                    vcs = Git(session=session, url=vcs_url, location=None, revision=None)
                    vcs._set_ssh_url()
                    new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
                    if new_req_line != req_line:
                        furl_line = furl(new_req_line)
                        print('Replacing original pip vcs \'{}\' with \'{}\''.format(
                            req_line,
                            furl_line.set(password='xxxxxx').tostr() if furl_line.password else new_req_line))
                        req_line = new_req_line
                except Exception:
                    print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))

            # if we have an older pip version we have to make sure we replace the package name back with the
            # git repository link. In new pip versions this is supported and we get "package @ git+https://..."
            if not req.pip_new_version:
                PackageManager.out_of_scope_install_package(req_line, "--no-deps")
                # noinspection PyBroadException
                try:
                    freeze_post = PackageManager.out_of_scope_freeze() or ''
                    package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
                    if package_name and package_name[0] not in self.post_install_req_lookup:
                        self.post_install_req_lookup[package_name[0]] = req.req.line
                except Exception:
                    pass

            # no need to force reinstall; pip will always rebuild if the package comes from git,
            # and it makes sure the required packages are installed (if they are not, it installs them)
            if not PackageManager.out_of_scope_install_package(req_line):
                raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Store in the post-install requirements, and return nothing
        self.post_install_req.append(req)
        # mark the package as skipped; we will install it in the post-install hook
        return Text('')

    def replace_back(self, list_of_requirements):
        if not list_of_requirements:
            return list_of_requirements

        for k in list_of_requirements:
            # k is either pip or conda
            if k not in ('pip', 'conda'):
                continue

            original_requirements = list_of_requirements[k]
            list_of_requirements[k] = [r for r in original_requirements
                                       if r not in self.post_install_req_lookup]
            list_of_requirements[k] += [self.post_install_req_lookup.get(r, '')
                                        for r in self.post_install_req_lookup.keys() if r in original_requirements]
        return list_of_requirements
@@ -1,32 +0,0 @@ (removed file)
from typing import Text

from .base import PackageManager
from .requirements import SimpleSubstitution


class HorovodRequirement(SimpleSubstitution):

    name = "horovod"

    def __init__(self, *args, **kwargs):
        super(HorovodRequirement, self).__init__(*args, **kwargs)
        self.post_install_req = None

    def match(self, req):
        # match both horovod
        return self.name == req.name.lower()

    def post_install(self):
        if self.post_install_req:
            PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False))
            self.post_install_req = None

    def replace(self, req):
        """
        Replace a requirement
        :raises: ValueError if version is pre-release
        """
        # Store in post req install, and return nothing
        self.post_install_req = req
        # mark skip package, we will install it in post install hook
        return Text('')
@@ -1,22 +1,24 @@
import sys
from itertools import chain
-from typing import Text
+from typing import Text, Optional

from trains_agent.definitions import PIP_EXTRA_INDICES, PROGRAM_NAME
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, DEVNULL
+from trains_agent.session import Session


class SystemPip(PackageManager):

    indices_args = None

-    def __init__(self, interpreter=None):
-        # type: (Text) -> ()
+    def __init__(self, interpreter=None, session=None):
+        # type: (Optional[Text], Optional[Session]) -> ()
        """
        Program interface to the system pip.
        """
        self._bin = interpreter or sys.executable
+        self.session = session

    @property
    def bin(self):
@@ -29,13 +31,13 @@ class SystemPip(PackageManager):
        pass

    def install_from_file(self, path):
-        self.run_with_env(('install', '-r', path) + self.install_flags())
+        self.run_with_env(('install', '-r', path) + self.install_flags(), cwd=self.cwd)

    def install_packages(self, *packages):
        self._install(*(packages + self.install_flags()))

    def _install(self, *args):
-        self.run_with_env(('install',) + args)
+        self.run_with_env(('install',) + args, cwd=self.cwd)

    def uninstall_packages(self, *packages):
        self.run_with_env(('uninstall', '-y') + packages)
@@ -82,7 +84,7 @@ class SystemPip(PackageManager):
        return (command.get_output if output else command.check_call)(stdin=DEVNULL, **kwargs)

    def _make_command(self, command):
-        return Argv(self.bin, '-m', 'pip', *command)
+        return Argv(self.bin, '-m', 'pip', '--disable-pip-version-check', *command)

    def install_flags(self):
        if self.indices_args is None:

@@ -1,6 +1,8 @@
from typing import Any

from pathlib2 import Path

-from trains_agent.helper.base import select_for_platform, rm_tree
+from trains_agent.helper.base import select_for_platform, rm_tree, ExecutionInfo
from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.process import Argv, PathLike
from trains_agent.session import Session
@@ -9,37 +11,35 @@ from ..requirements import RequirementsManager


class VirtualenvPip(SystemPip, PackageManager):
-    def __init__(self, session, python, requirements_manager, path, interpreter=None):
-        # type: (Session, float, RequirementsManager, PathLike, PathLike) -> ()
+    def __init__(self, session, python, requirements_manager, path, interpreter=None, execution_info=None, **kwargs):
+        # type: (Session, float, RequirementsManager, PathLike, PathLike, ExecutionInfo, Any) -> ()
        """
        Program interface to virtualenv pip.
        Must be given either a path to a virtualenv or a source command.
        Either way, ``self.source`` is exposed.
        :param session: a Session object for communication
-        :param python: interpreter path
        :param path: path of virtual environment to create/manipulate
+        :param python: python version
        :param interpreter: path of python interpreter
        """
-        super(VirtualenvPip, self).__init__(
-            interpreter
-            or Path(
-                path,
-                select_for_platform(linux="bin/python", windows="scripts/python.exe"),
-            )
-        )
+        super(VirtualenvPip, self).__init__(
+            session=session,
+            interpreter=interpreter or Path(
+                path, select_for_platform(linux="bin/python", windows="scripts/python.exe"))
+        )
        self.session = session
        self.path = path
        self.requirements_manager = requirements_manager
-        self.python = "python{}".format(python)
+        self.python = python

    def _make_command(self, command):
-        return self.session.command(self.bin, "-m", "pip", *command)
+        return self.session.command(self.bin, "-m", "pip", "--disable-pip-version-check", *command)

    def load_requirements(self, requirements):
        if isinstance(requirements, dict) and requirements.get("pip"):
            requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
        super(VirtualenvPip, self).load_requirements(requirements)
-        self.requirements_manager.post_install()
+        self.requirements_manager.post_install(self.session)

    def create_flags(self):
        """

@@ -1,8 +1,11 @@
from copy import deepcopy
from functools import wraps

import attr
+import sys
+import os
from pathlib2 import Path
-from trains_agent.helper.process import Argv, DEVNULL
+from trains_agent.helper.process import Argv, DEVNULL, check_if_command_exists
from trains_agent.session import Session, POETRY


@@ -35,10 +38,12 @@ def prop_guard(prop, log_prop=None):

class PoetryConfig:

-    def __init__(self, session):
-        # type: (Session) -> ()
+    def __init__(self, session, interpreter=None):
+        # type: (Session, str) -> ()
        self.session = session
        self._log = session.get_logger(__name__)
+        self._python = interpreter or sys.executable
+        self._initialized = False

    @property
    def log(self):
@@ -53,7 +58,20 @@ class PoetryConfig:
    def run(self, *args, **kwargs):
        func = kwargs.pop("func", Argv.get_output)
        kwargs.setdefault("stdin", DEVNULL)
-        argv = Argv("poetry", "-n", *args)
+        kwargs['env'] = deepcopy(os.environ)
+        if 'VIRTUAL_ENV' in kwargs['env'] or 'CONDA_PREFIX' in kwargs['env']:
+            kwargs['env'].pop('VIRTUAL_ENV', None)
+            kwargs['env'].pop('CONDA_PREFIX', None)
+            kwargs['env'].pop('PYTHONPATH', None)
+            if hasattr(sys, "real_prefix") and hasattr(sys, "base_prefix"):
+                path = ':' + kwargs['env']['PATH']
+                path = path.replace(':' + sys.base_prefix, ':' + sys.real_prefix, 1)
+                kwargs['env']['PATH'] = path
+
+        if check_if_command_exists("poetry"):
+            argv = Argv("poetry", *args)
+        else:
+            argv = Argv(self._python, "-m", "poetry", *args)
        self.log.debug("running: %s", argv)
        return func(argv, **kwargs)

@@ -61,10 +79,16 @@ class PoetryConfig:
        return self.run("config", *args, **kwargs)

    @_guard_enabled
-    def initialize(self):
-        self._config("settings.virtualenvs.in-project", "true")
-        # self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
-        # self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)
+    def initialize(self, cwd=None):
+        if not self._initialized:
+            self._initialized = True
+            try:
+                self._config("--local", "virtualenvs.in-project", "true", cwd=cwd)
+                # self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
+                # self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)
+            except Exception as ex:
+                print("Exception: {}\nError: Failed configuring Poetry virtualenvs.in-project".format(ex))
+                raise

    def get_api(self, path):
        # type: (Path) -> PoetryAPI
@@ -81,7 +105,7 @@ class PoetryAPI(object):
    def install(self):
        # type: () -> bool
        if self.enabled:
-            self.config.run("install", cwd=str(self.path), func=Argv.check_call)
+            self.config.run("install", "-n", cwd=str(self.path), func=Argv.check_call)
            return True
        return False

@@ -92,7 +116,24 @@ class PoetryAPI(object):
        )

    def freeze(self):
-        return {"poetry": self.config.run("show", cwd=str(self.path)).splitlines()}
+        lines = self.config.run("show", cwd=str(self.path)).splitlines()
+        lines = [[p for p in line.split(' ') if p] for line in lines]
+        return {"pip": [parts[0] + '==' + parts[1] + ' # ' + ' '.join(parts[2:]) for parts in lines]}
|
||||
|
||||
def get_python_command(self, extra):
|
||||
return Argv("poetry", "run", "python", *extra)
|
||||
if check_if_command_exists("poetry"):
|
||||
return Argv("poetry", "run", "python", *extra)
|
||||
else:
|
||||
return Argv(self.config._python, "-m", "poetry", "run", "python", *extra)
|
||||
|
||||
def upgrade_pip(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def set_selected_package_manager(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def out_of_scope_install_package(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def install_from_file(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
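A rough standalone sketch of the two ideas in the `run()` change above: strip the agent's own virtualenv/conda context before invoking poetry, and fall back to `python -m poetry` when the executable is missing. `poetry_argv` and `sanitized_env` are hypothetical names, assuming only the standard library:

    import os
    import sys
    from copy import deepcopy
    from shutil import which  # stand-in for the agent's check_if_command_exists()

    def poetry_argv(*args, python=sys.executable):
        # fall back to "python -m poetry" when the poetry executable is not on PATH
        if which("poetry"):
            return ["poetry"] + list(args)
        return [python, "-m", "poetry"] + list(args)

    def sanitized_env():
        # poetry must not inherit the agent's virtualenv/conda context,
        # otherwise it would resolve into the wrong environment
        env = deepcopy(os.environ)
        for key in ("VIRTUAL_ENV", "CONDA_PREFIX", "PYTHONPATH"):
            env.pop(key, None)
        return env

    # subprocess.check_call(poetry_argv("install", "-n"), env=sanitized_env())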
48  trains_agent/helper/package/post_req.py  Normal file
@@ -0,0 +1,48 @@
+from typing import Text
+
+from .base import PackageManager
+from .requirements import SimpleSubstitution
+
+
+class PostRequirement(SimpleSubstitution):
+
+    name = ("horovod", )
+    optional_package_names = tuple()
+
+    def __init__(self, *args, **kwargs):
+        super(PostRequirement, self).__init__(*args, **kwargs)
+        self.post_install_req = []
+        # check if we need to replace the packages:
+        post_packages = self.config.get('agent.package_manager.post_packages', None)
+        if post_packages:
+            self.__class__.name = post_packages
+        post_optional_packages = self.config.get('agent.package_manager.post_optional_packages', None)
+        if post_optional_packages:
+            self.__class__.optional_package_names = post_optional_packages
+
+    def match(self, req):
+        # match horovod (or any configured post-install package)
+        return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
+
+    def post_install(self, session):
+        for req in self.post_install_req:
+            if req.name in self.optional_package_names:
+                # noinspection PyBroadException
+                try:
+                    PackageManager.out_of_scope_install_package(req.tostr(markers=False))
+                except Exception:
+                    pass
+            else:
+                PackageManager.out_of_scope_install_package(req.tostr(markers=False))
+
+        self.post_install_req = []
+
+    def replace(self, req):
+        """
+        Replace a requirement
+        :raises: ValueError if version is pre-release
+        """
+        # Store in post req install, and return nothing
+        self.post_install_req.append(req)
+        # mark skip package, we will install it in the post install hook
+        return Text('')
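A toy sketch of the deferral logic this new file implements: packages listed under the `agent.package_manager.post_packages` key are pulled out of the first install pass and installed afterwards (here simulated with plain lists; the config dict and package lines are made up for the example):

    # hypothetical stand-ins for the agent config and a requirements list
    config = {"agent.package_manager.post_packages": ("horovod",)}
    requirements = ["numpy==1.19.2", "horovod==0.19.5"]

    post_packages = config.get("agent.package_manager.post_packages") or ("horovod",)
    deferred, first_pass = [], []
    for line in requirements:
        name = line.split("==")[0].lower()
        (deferred if name in post_packages else first_pass).append(line)

    print(first_pass)  # installed normally: ['numpy==1.19.2']
    print(deferred)    # installed by the post-install hook: ['horovod==0.19.5']

Deferring horovod matters because its build step needs the already-installed framework (e.g. torch) to be importable.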
75  trains_agent/helper/package/priority_req.py  Normal file
@@ -0,0 +1,75 @@
+from typing import Text
+
+from .base import PackageManager
+from .requirements import SimpleSubstitution
+
+
+class PriorityPackageRequirement(SimpleSubstitution):
+
+    name = ("cython", "numpy", "setuptools", )
+    optional_package_names = tuple()
+
+    def __init__(self, *args, **kwargs):
+        super(PriorityPackageRequirement, self).__init__(*args, **kwargs)
+        # check if we need to replace the packages:
+        priority_packages = self.config.get('agent.package_manager.priority_packages', None)
+        if priority_packages:
+            self.__class__.name = priority_packages
+        priority_optional_packages = self.config.get('agent.package_manager.priority_optional_packages', None)
+        if priority_optional_packages:
+            self.__class__.optional_package_names = priority_optional_packages
+
+    def match(self, req):
+        # match both Cython & cython
+        return req.name and (req.name.lower() in self.name or req.name.lower() in self.optional_package_names)
+
+    def replace(self, req):
+        """
+        Replace a requirement
+        :raises: ValueError if version is pre-release
+        """
+        if req.name in self.optional_package_names:
+            # noinspection PyBroadException
+            try:
+                if PackageManager.out_of_scope_install_package(str(req)):
+                    return Text(req)
+            except Exception:
+                pass
+            return Text('')
+        PackageManager.out_of_scope_install_package(str(req))
+        return Text(req)
+
+
+class PackageCollectorRequirement(SimpleSubstitution):
+    """
+    This RequirementSubstitution class allows multiple instances of the same
+    package; only the last one (by order) is actually emitted and used.
+    """
+    name = tuple()
+
+    def __init__(self, session, collect_package):
+        super(PackageCollectorRequirement, self).__init__(session)
+        self._collect_packages = collect_package or tuple()
+        self._last_req = None
+
+    def match(self, req):
+        # match package names
+        return req.name and req.name.lower() in self._collect_packages
+
+    def replace(self, req):
+        """
+        Replace a requirement
+        :raises: ValueError if version is pre-release
+        """
+        self._last_req = req.clone()
+        return ''
+
+    def post_scan_add_req(self):
+        """
+        Allows the RequirementSubstitution to add an extra line/requirement after
+        the initial requirements scan is completed.
+        Called only once per requirements.txt object
+        """
+        last_req = self._last_req
+        self._last_req = None
+        return last_req
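A minimal sketch of the collector's dedup behaviour, assuming simple `name==version` lines (the `collect_last` helper is hypothetical):

    def collect_last(requirements, collect=("torch",)):
        # drop duplicate entries of a collected package, re-adding only the last one seen
        last, kept = None, []
        for line in requirements:
            name = line.split("==")[0].lower()
            if name in collect:
                last = line          # remember the latest occurrence, emit nothing yet
            else:
                kept.append(line)
        if last is not None:
            kept.append(last)        # the post-scan hook appends the surviving entry
        return kept

    print(collect_last(["torch==1.4.0", "numpy", "torch==1.5.0"]))
    # ['numpy', 'torch==1.5.0']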
@@ -10,10 +10,9 @@ from typing import Text

 import attr
 import requests
-from semantic_version import Version, Spec

 import six
-from .requirements import SimpleSubstitution, FatalSpecsResolutionError
+from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion

 OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
@@ -75,6 +74,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
     packages = ("torch", "torchvision", "torchaudio")

     page_lookup_template = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'
+    nightly_page_lookup_template = 'https://download.pytorch.org/whl/nightly/cu{}/torch_nightly.html'
     torch_page_lookup = {
         0: 'https://download.pytorch.org/whl/cpu/torch_stable.html',
         80: 'https://download.pytorch.org/whl/cu80/torch_stable.html',
@@ -82,6 +82,8 @@ class SimplePytorchRequirement(SimpleSubstitution):
         92: 'https://download.pytorch.org/whl/cu92/torch_stable.html',
         100: 'https://download.pytorch.org/whl/cu100/torch_stable.html',
         101: 'https://download.pytorch.org/whl/cu101/torch_stable.html',
+        102: 'https://download.pytorch.org/whl/cu102/torch_stable.html',
+        110: 'https://download.pytorch.org/whl/cu110/torch_stable.html',
     }

     def __init__(self, *args, **kwargs):
@@ -116,23 +118,42 @@ class SimplePytorchRequirement(SimpleSubstitution):
             package_manager.add_extra_install_flags(('-f', extra_url))

     @classmethod
-    def get_torch_page(cls, cuda_version):
+    def get_torch_page(cls, cuda_version, nightly=False):
         # noinspection PyBroadException
         try:
             cuda = int(cuda_version)
-        except:
+        except Exception:
             cuda = 0

+        if nightly:
+            for c in range(cuda, max(-1, cuda-15), -1):
+                # then try the nightly builds, it might be there...
+                torch_url = cls.nightly_page_lookup_template.format(c)
+                # noinspection PyBroadException
+                try:
+                    if requests.get(torch_url, timeout=10).ok:
+                        print('Torch nightly CUDA {} download page found'.format(c))
+                        cls.torch_page_lookup[c] = torch_url
+                        return cls.torch_page_lookup[c], c
+                except Exception:
+                    pass
+            return
+
         # first check if key is valid
         if cuda in cls.torch_page_lookup:
             return cls.torch_page_lookup[cuda], cuda

         # then try a new cuda version page
-        torch_url = cls.page_lookup_template.format(cuda)
-        try:
-            if requests.get(torch_url, timeout=10).ok:
-                cls.torch_page_lookup[cuda] = torch_url
-                return cls.torch_page_lookup[cuda], cuda
-        except Exception:
-            pass
+        for c in range(cuda, max(-1, cuda-15), -1):
+            torch_url = cls.page_lookup_template.format(c)
+            # noinspection PyBroadException
+            try:
+                if requests.get(torch_url, timeout=10).ok:
+                    print('Torch CUDA {} download page found'.format(c))
+                    cls.torch_page_lookup[c] = torch_url
+                    return cls.torch_page_lookup[c], c
+            except Exception:
+                pass

         keys = sorted(cls.torch_page_lookup.keys(), reverse=True)
         for k in keys:
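A self-contained sketch of the descending-CUDA probe the loop above performs (walking cu102 -> cu101 -> ... until a wheel index page answers); `find_torch_page` is a hypothetical name, and only the stable page template from the diff is assumed:

    import requests

    PAGE = 'https://download.pytorch.org/whl/cu{}/torch_stable.html'

    def find_torch_page(cuda, lookup=None, max_steps=15):
        # walk the CUDA version downwards until a wheel page responds
        lookup = lookup if lookup is not None else {}
        for c in range(cuda, max(-1, cuda - max_steps), -1):
            if c in lookup:
                return lookup[c], c
            url = PAGE.format(c)
            try:
                if requests.get(url, timeout=10).ok:
                    lookup[c] = url  # cache the hit for the next resolution
                    return url, c
            except requests.RequestException:
                pass
        return None, None

    # find_torch_page(102) -> ('https://download.pytorch.org/whl/cu102/torch_stable.html', 102)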
@@ -145,7 +166,7 @@ class SimplePytorchRequirement(SimpleSubstitution):
 class PytorchRequirement(SimpleSubstitution):

     name = "torch"
-    packages = ("torch", "torchvision", "torchaudio")
+    packages = ("torch", "torchvision", "torchaudio", "torchcsprng", "torchtext")

     def __init__(self, *args, **kwargs):
         os_name = kwargs.pop("os_override", None)
@@ -155,10 +176,15 @@ class PytorchRequirement(SimpleSubstitution):
         self.os = os_name or self.get_platform()
         self.cuda = "cuda{}".format(self.cuda_version).lower()
         self.python_version_string = str(self.config["agent.default_python"])
-        self.python_semantic_version = Version.coerce(
-            self.python_version_string, partial=True
-        )
-        self.python = "python{}.{}".format(self.python_semantic_version.major, self.python_semantic_version.minor)
+        self.python_major_minor_str = '.'.join(self.python_version_string.split('.')[:2])
+        if '.' not in self.python_major_minor_str:
+            raise PytorchResolutionError(
+                "invalid python version {!r} defined in configuration file, key 'agent.default_python': "
+                "must have both major and minor parts of the version (for example: '3.7')".format(
+                    self.python_version_string
+                )
+            )
+        self.python = "python{}".format(self.python_major_minor_str)

         self.exceptions = [
             PytorchResolutionError(message)
@@ -176,6 +202,8 @@ class PytorchRequirement(SimpleSubstitution):
         except PytorchResolutionError as e:
             self.log.warn("will not be able to install pytorch wheels: %s", e.args[0])

+        self._original_req = []
+
     @property
     def is_conda(self):
         return self.package_manager == "conda"
@@ -188,9 +216,7 @@ class PytorchRequirement(SimpleSubstitution):
         """
         Make sure python version has both major and minor versions as required for choosing pytorch wheel
         """
-        if self.is_pip and not (
-            self.python_semantic_version.major and self.python_semantic_version.minor
-        ):
+        if self.is_pip and not self.python_major_minor_str:
             raise PytorchResolutionError(
                 "invalid python version {!r} defined in configuration file, key 'agent.default_python': "
                 "must have both major and minor parts of the version (for example: '3.7')".format(
@@ -215,8 +241,10 @@ class PytorchRequirement(SimpleSubstitution):
         links_parser = LinksHTMLParser()
         links_parser.feed(requests.get(torch_url, timeout=10).text)
         platform_wheel = "win" if self.get_platform() == "windows" else self.get_platform()
-        py_ver = "{0.major}{0.minor}".format(self.python_semantic_version)
+        py_ver = self.python_major_minor_str.replace('.', '')
         url = None
+        last_v = None
+        closest_v = None
         # search for our package
         for l in links_parser.links:
             parts = l.split('/')[-1].split('-')
@@ -225,54 +253,108 @@ class PytorchRequirement(SimpleSubstitution):
             if parts[0] != req.name:
                 continue
-            # version (ignore +cpu +cu92 etc. + is %2B in the file link)
-            if parts[1].split('%')[0].split('+')[0] != req.specs[0][1]:
+            # version ignore .postX suffix (treat as regular version)
+            # noinspection PyBroadException
+            try:
+                v = str(parts[1].split('%')[0].split('+')[0])
+            except Exception:
                 continue
-            if not parts[2].endswith(py_ver):
+            if len(parts) < 3 or not parts[2].endswith(py_ver):
                 continue
-            if platform_wheel not in parts[4]:
+            if len(parts) < 5 or platform_wheel not in parts[4]:
                 continue
+            # update the closest matched version (from above)
+            if not closest_v:
+                closest_v = v
+            elif SimpleVersion.compare_versions(
+                    version_a=closest_v, op='>=', version_b=v, num_parts=3) and \
+                    SimpleVersion.compare_versions(
+                        version_a=v, op='>=', version_b=req.specs[0][1], num_parts=3):
+                closest_v = v
+            # check if this is an actual match
+            if not req.compare_version(v) or \
+                    (last_v and SimpleVersion.compare_versions(last_v, '>', v, ignore_sub_versions=False)):
+                continue
-            url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
-            break
-
-        return url
+            url = '/'.join(torch_url.split('/')[:-1] + l.split('/'))
+            last_v = v
+            # if we found an exact match, use it
+            # noinspection PyBroadException
+            try:
+                if req.specs[0][0] == '==' and \
+                        SimpleVersion.compare_versions(req.specs[0][1], '==', v, ignore_sub_versions=False):
+                    break
+            except Exception:
+                pass
+
+        return url, last_v or closest_v

     def get_url_for_platform(self, req):
         assert self.package_manager == "pip"
         assert self.os != "mac"
-        assert req.specs
+        # check if package is already installed with system packages
+        # noinspection PyBroadException
+        try:
+            if self.config.get("agent.package_manager.system_site_packages", None):
+                from pip._internal.commands.show import search_packages_info
+                installed_torch = list(search_packages_info([req.name]))
+                # notice the comparison order, the first part will make sure we have a valid installed package
+                if installed_torch and installed_torch[0]['version'] and \
+                        req.compare_version(installed_torch[0]['version']):
+                    print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
+                        req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
+                    # package already installed, do nothing
+                    req.specs = [('==', str(installed_torch[0]['version']))]
+                    return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
+        except Exception:
+            pass
+
+        # make sure we have a specific version to retrieve
+        if not req.specs:
+            req.specs = [('>', '0')]
+
         # noinspection PyBroadException
         try:
             req.specs[0] = (req.specs[0][0], req.specs[0][1].split('+')[0])
-        except:
+        except Exception:
             pass
         op, version = req.specs[0]
         # assert op == "=="

         torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
-        url = self._get_link_from_torch_page(req, torch_url)
+        url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
+        if not url and self.config.get("agent.package_manager.torch_nightly", None):
+            torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
+            url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
         # try one more time, with a lower cuda version (never fallback to CPU):
         while not url and torch_url_key > 0:
             previous_cuda_key = torch_url_key
+            print('Warning, could not locate PyTorch {} matching CUDA version {}, best candidate {}\n'.format(
+                req, previous_cuda_key, closest_matched_version))
+            url, closest_matched_version = self._get_link_from_torch_page(req, torch_url)
+            if url:
+                break
             torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(int(torch_url_key)-1)
             # never fallback to CPU
             if torch_url_key < 1:
-                print('Warning! Could not locate PyTorch version {} matching CUDA version {}'.format(
-                    req, previous_cuda_key))
-                raise ValueError('Could not locate PyTorch version {} matching CUDA version {}'.format(
-                    req, self.cuda_version))
-            print('Warning! Could not locate PyTorch version {} matching CUDA version {}, trying CUDA version {}'.format(
-                req, previous_cuda_key, torch_url_key))
-            url = self._get_link_from_torch_page(req, torch_url)
+                print(
+                    'Error! Could not locate PyTorch version {} matching CUDA version {}'.format(
+                        req, previous_cuda_key))
+                raise ValueError(
+                    'Could not locate PyTorch version {} matching CUDA version {}'.format(req, self.cuda_version))
+            else:
+                print('Trying PyTorch CUDA version {} support'.format(torch_url_key))

         if not url:
             url = PytorchWheel(
                 torch_version=fix_version(version),
-                python="{0.major}{0.minor}".format(self.python_semantic_version),
+                python=self.python_major_minor_str.replace('.', ''),
                 os_name=self.os,
                 cuda_version=self.cuda_version,
             ).make_url()
         if url:
+            # normalize url (sometimes we will get ../ which we should not)
+            url = '/'.join(url.split('/')[:3]) + urllib.parse.quote(str(furl(url).path.normalize()))
+            # print found
+            print('Found PyTorch version {} matching CUDA version {}'.format(req, torch_url_key))

         self.log.debug("checking url: %s", url)
         return url, requests.head(url, timeout=10).ok
@@ -280,20 +362,17 @@ class PytorchRequirement(SimpleSubstitution):
     @staticmethod
     def match_version(req, options):
         versioned_options = sorted(
-            ((Version(fix_version(key)), value) for key, value in options.items()),
+            ((fix_version(key), value) for key, value in options.items()),
             key=itemgetter(0),
             reverse=True,
         )
         req.specs = [(op, fix_version(version)) for op, version in req.specs]
-        if req.specs:
-            specs = Spec(req.format_specs())
-        else:
-            specs = None
-
         try:
             return next(
                 replacement
                 for version, replacement in versioned_options
-                if not specs or version in specs
+                if req.compare_version(version)
             )
         except StopIteration:
             raise PytorchResolutionError(
@@ -342,7 +421,10 @@ class PytorchRequirement(SimpleSubstitution):
     def replace(self, req):
         try:
-            return self._replace(req)
+            new_req = self._replace(req)
+            if new_req:
+                self._original_req.append((req, new_req))
+            return new_req
         except Exception as e:
             message = "Exception when trying to resolve python wheel"
             self.log.debug(message, exc_info=True)
@@ -357,17 +439,17 @@ class PytorchRequirement(SimpleSubstitution):
         except:
             pass

-        try:
-            result = self._table_lookup(req)
-        except Exception as e:
-            exc = e
-        else:
-            self.log.debug('Replacing requirement "%s" with %r', req, result)
-            return result
+        # try:
+        #     result = self._table_lookup(req)
+        # except Exception as e:
+        #     exc = e
+        # else:
+        #     self.log.debug('Replacing requirement "%s" with %r', req, result)
+        #     return result
+        # self.log.debug(
+        #     "Could not find Pytorch wheel in table, trying manually constructing URL"
+        # )

-        self.log.debug(
-            "Could not find Pytorch wheel in table, trying manually constructing URL"
-        )
         result = ok = None
         # try:
         #     result, ok = self.get_url_for_platform(req)
@@ -378,7 +460,7 @@ class PytorchRequirement(SimpleSubstitution):
         if result:
             self.log.debug("URL not found: {}".format(result))
         exc = PytorchResolutionError(
-            "Was not able to find pytorch wheel URL: {}".format(exc)
+            "Could not find pytorch wheel URL for: {} with cuda {} support".format(req, self.cuda_version)
         )
         # cancel exception chaining
         six.raise_from(exc, None)
@@ -386,6 +468,43 @@ class PytorchRequirement(SimpleSubstitution):
         self.log.debug('Replacing requirement "%s" with %r', req, result)
         return result

+    def replace_back(self, list_of_requirements):  # type: (Dict) -> Dict
+        """
+        :param list_of_requirements: {'pip': ['a==1.0', ]}
+        :return: {'pip': ['a==1.0', ]}
+        """
+        if not self._original_req:
+            return list_of_requirements
+        try:
+            for k, lines in list_of_requirements.items():
+                # k is either pip/conda
+                if k not in ('pip', 'conda'):
+                    continue
+                for i, line in enumerate(lines):
+                    if not line or line.lstrip().startswith('#'):
+                        continue
+                    parts = [p for p in re.split(r'\s|=|\.|<|>|~|!|@|#', line) if p]
+                    if not parts:
+                        continue
+                    for req, new_req in self._original_req:
+                        if req.req.name == parts[0]:
+                            # support for pip >= 20.1
+                            if '@' in line:
+                                # skip if we have nothing to add
+                                if str(req).strip() != str(new_req).strip():
+                                    # if this is a local file, use the version detection
+                                    if req.local_file:
+                                        lines[i] = '{}'.format(str(new_req))
+                                    else:
+                                        lines[i] = '{} # {}'.format(str(req), str(new_req))
+                            else:
+                                lines[i] = '{} # {}'.format(line, str(new_req))
+                            break
+        except:
+            pass
+
+        return list_of_requirements
+
     MAP = {
         "windows": {
             "cuda100": {
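A small runnable sketch of what `replace_back` does to a requirements listing: the resolved wheel line is re-attached as a trailing comment on the original requirement. The `annotate` helper and the torch version used are hypothetical:

    import re

    def annotate(lines, original_to_new):
        # re-attach the resolved wheel as a trailing comment on the original requirement line
        for i, line in enumerate(lines):
            if not line or line.lstrip().startswith('#'):
                continue
            parts = [p for p in re.split(r'\s|=|\.|<|>|~|!|@|#', line) if p]
            for name, new_req in original_to_new.items():
                if parts and parts[0] == name:
                    lines[i] = '{} # {}'.format(line, new_req)
        return lines

    print(annotate(['torch==1.5.0'], {'torch': 'torch==1.5.0+cu101'}))
    # ['torch==1.5.0 # torch==1.5.0+cu101']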
@@ -4,24 +4,23 @@ import operator
 import os
 import re
 from abc import ABCMeta, abstractmethod
-from copy import deepcopy
+from copy import deepcopy, copy
 from itertools import chain, starmap
 from operator import itemgetter
 from os import path
-from typing import Text, List, Type, Optional, Tuple
+from typing import Text, List, Type, Optional, Tuple, Dict

-import semantic_version
 from pathlib2 import Path
 from pyhocon import ConfigTree
-from requirements import parse
-# noinspection PyPackageRequirements
-from requirements.requirement import Requirement
-
 import six
 from trains_agent.definitions import PIP_EXTRA_INDICES
 from trains_agent.helper.base import warning, is_conda, which, join_lines, is_windows_platform
 from trains_agent.helper.process import Argv, PathLike
 from trains_agent.session import Session, normalize_cuda_version
+from trains_agent.external.requirements_parser import parse
+from trains_agent.external.requirements_parser.requirement import Requirement
+
 from .translator import RequirementsTranslator
@@ -36,6 +35,10 @@ class FatalSpecsResolutionError(Exception):
 @six.python_2_unicode_compatible
 class MarkerRequirement(object):

+    # if True, pip version is above 20.x, with support for "package @ scheme://link"
+    # default is True
+    pip_new_version = True
+
     def __init__(self, req):  # type: (Requirement) -> None
         self.req = req
@@ -48,14 +51,28 @@ class MarkerRequirement(object):
     def tostr(self, markers=True):
         if not self.uri:
-            parts = [self.name]
+            parts = [self.name or self.line]

             if self.extras:
                 parts.append('[{0}]'.format(','.join(sorted(self.extras))))

             if self.specifier:
                 parts.append(self.format_specs())

+        elif self.vcs:
+            # leave the line as is, let pip handle it
+            if self.line:
+                return self.line
+            else:
+                # let's build the line manually
+                parts = [
+                    self.uri,
+                    '@{}'.format(self.revision) if self.revision else '',
+                    '#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
+                ]
+        elif self.pip_new_version and self.uri and self.name and self.line and self.local_file:
+            # package @ file:///example.com/somewheel.whl
+            # leave the line as is, let pip handle it
+            return self.line
         else:
             parts = [self.uri]
@@ -64,13 +81,27 @@ class MarkerRequirement(object):

         return ''.join(parts)

+    def clone(self):
+        return MarkerRequirement(copy(self.req))
+
     __str__ = tostr

     def __repr__(self):
         return '{self.__class__.__name__}[{self}]'.format(self=self)

-    def format_specs(self):
-        return ','.join(starmap(operator.add, self.specs))
+    def format_specs(self, num_parts=None, max_num_parts=None):
+        max_num_parts = max_num_parts or num_parts
+        if max_num_parts is None or not self.specs:
+            return ','.join(starmap(operator.add, self.specs))
+
+        op, version = self.specs[0]
+        for v in self._sub_versions_pep440:
+            version = version.replace(v, '.')
+        if num_parts:
+            version = (version.strip('.').split('.') + ['0'] * num_parts)[:max_num_parts]
+        else:
+            version = version.strip('.').split('.')[:max_num_parts]
+        return op+'.'.join(version)

     def __getattr__(self, item):
         return getattr(self.req, item)
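The pad-or-truncate behaviour of the new `format_specs` is easy to miss, so here is a standalone re-implementation (same logic, hypothetical free-function name) with two worked calls:

    import operator
    from itertools import starmap

    def format_specs(specs, num_parts=None, max_num_parts=None):
        # pad or truncate the first spec's version to a fixed number of dot-parts
        max_num_parts = max_num_parts or num_parts
        if max_num_parts is None or not specs:
            return ','.join(starmap(operator.add, specs))
        op, version = specs[0]
        parts = version.strip('.').split('.')
        if num_parts:
            parts = (parts + ['0'] * num_parts)[:max_num_parts]
        else:
            parts = parts[:max_num_parts]
        return op + '.'.join(parts)

    print(format_specs([('==', '1.5')], num_parts=3))        # ==1.5.0 (padded)
    print(format_specs([('==', '1.5.0.1')], max_num_parts=3))  # ==1.5.0 (truncated)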
@@ -99,6 +130,187 @@ class MarkerRequirement(object):
         else:
             self.specs = greater + smaller

+    def compare_version(self, requested_version, op=None, num_parts=3):
+        """
+        Compare the requested version with the one we have in the spec.
+        If the requested version is 1.2.3, self.spec should be 1.2.3*
+        If the requested version is 1.2, self.spec should be 1.2*
+        etc.
+
+        :param str requested_version:
+        :param str op: '==', '>', '>=', '<=', '<', '~='
+        :param int num_parts: number of parts to compare
+        :return: True if we satisfy the requested version
+        """
+        # if we have no specific version, we cannot compare, so assume it's okay
+        if not self.specs:
+            return True
+
+        version = self.specs[0][1]
+        op = (op or self.specs[0][0]).strip()
+
+        return SimpleVersion.compare_versions(
+            version_a=requested_version, op=op, version_b=version, num_parts=num_parts)
+
+
+class SimpleVersion:
+    _sub_versions_pep440 = ['a', 'b', 'rc', '.post', '.dev', '+', ]
+    VERSION_PATTERN = r"""
+        v?
+        (?:
+            (?:(?P<epoch>[0-9]+)!)?                           # epoch
+            (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
+            (?P<pre>                                          # pre-release
+                [-_\.]?
+                (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
+                [-_\.]?
+                (?P<pre_n>[0-9]+)?
+            )?
+            (?P<post>                                         # post release
+                (?:-(?P<post_n1>[0-9]+))
+                |
+                (?:
+                    [-_\.]?
+                    (?P<post_l>post|rev|r)
+                    [-_\.]?
+                    (?P<post_n2>[0-9]+)?
+                )
+            )?
+            (?P<dev>                                          # dev release
+                [-_\.]?
+                (?P<dev_l>dev)
+                [-_\.]?
+                (?P<dev_n>[0-9]+)?
+            )?
+        )
+        (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+    """
+    _local_version_separators = re.compile(r"[\._-]")
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+
+    @classmethod
+    def compare_versions(cls, version_a, op, version_b, ignore_sub_versions=True, num_parts=3):
+        """
+        Compare two versions based on the op operator:
+        returns bool(version_a op version_b)
+        Notice: ignores a/b/rc/post/dev markers on the version
+
+        :param str version_a:
+        :param str op: '==', '===', '>', '>=', '<=', '<', '~='
+        :param str version_b:
+        :param bool ignore_sub_versions: if true compare only major.minor.patch
+            (ignore a/b/rc/post/dev in the comparison)
+        :param int num_parts: number of parts to compare, split by . (dot)
+        :return bool: version_a op version_b
+        """
+
+        if not version_b:
+            return True
+
+        if op == '~=':
+            num_parts = max(num_parts, 2)
+            op = '=='
+            ignore_sub_versions = True
+        elif op == '===':
+            op = '=='
+
+        try:
+            version_a_key = cls._get_match_key(cls._regex.search(version_a), num_parts, ignore_sub_versions)
+            version_b_key = cls._get_match_key(cls._regex.search(version_b), num_parts, ignore_sub_versions)
+        except:
+            # revert to string-based comparison
+            for v in cls._sub_versions_pep440:
+                version_a = version_a.replace(v, '.')
+                version_b = version_b.replace(v, '.')
+
+            version_a = (version_a.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
+            version_b = (version_b.strip('.').split('.') + ['0'] * num_parts)[:num_parts]
+            version_a_key = ''
+            version_b_key = ''
+            for i in range(num_parts):
+                pad = '{:0>%d}.' % max([9, 1 + len(version_a[i]), 1 + len(version_b[i])])
+                version_a_key += pad.format(version_a[i])
+                version_b_key += pad.format(version_b[i])
+
+        if op == '==':
+            return version_a_key == version_b_key
+        if op == '<=':
+            return version_a_key <= version_b_key
+        if op == '>=':
+            return version_a_key >= version_b_key
+        if op == '>':
+            return version_a_key > version_b_key
+        if op == '<':
+            return version_a_key < version_b_key
+        raise ValueError('Unrecognized comparison operator [{}]'.format(op))
+
+    @staticmethod
+    def _parse_letter_version(
+        letter,  # type: str
+        number,  # type: Union[str, bytes, SupportsInt]
+    ):
+        # type: (...) -> Optional[Tuple[str, int]]
+
+        if letter:
+            # We consider there to be an implicit 0 in a pre-release if there is
+            # not a numeral associated with it.
+            if number is None:
+                number = 0
+
+            # We normalize any letters to their lower case form
+            letter = letter.lower()
+
+            # We consider some words to be alternate spellings of other words and
+            # in those cases we want to normalize the spellings to our preferred
+            # spelling.
+            if letter == "alpha":
+                letter = "a"
+            elif letter == "beta":
+                letter = "b"
+            elif letter in ["c", "pre", "preview"]:
+                letter = "rc"
+            elif letter in ["rev", "r"]:
+                letter = "post"
+
+            return letter, int(number)
+        if not letter and number:
+            # We assume if we are given a number, but we are not given a letter
+            # then this is using the implicit post release syntax (e.g. 1.0-1)
+            letter = "post"
+
+            return letter, int(number)
+
+        return ()
+
+    @staticmethod
+    def _get_match_key(match, num_parts, ignore_sub_versions):
+        if ignore_sub_versions:
+            return (0, tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
+                    (), (), (), (),)
+        return (
+            int(match.group("epoch")) if match.group("epoch") else 0,
+            tuple(int(i) for i in match.group("release").split(".")[:num_parts]),
+            SimpleVersion._parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            SimpleVersion._parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            SimpleVersion._parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            SimpleVersion._parse_local_version(match.group("local")),
+        )
+
+    @staticmethod
+    def _parse_local_version(local):
+        # type: (str) -> Optional[LocalType]
+        """
+        Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+        """
+        if local is not None:
+            return tuple(
+                part.lower() if not part.isdigit() else int(part)
+                for part in SimpleVersion._local_version_separators.split(local)
+            )
+        return ()
+
+
 @six.add_metaclass(ABCMeta)
 class RequirementSubstitution(object):
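To make the comparison semantics concrete, here is a simplified, self-contained version of the tuple-key comparison `SimpleVersion.compare_versions` performs (padding short versions with zeros and dropping the `+local` part); `compare_versions` here is a reduced sketch, not the class above:

    def compare_versions(a, op, b, num_parts=3):
        # compare numeric version tuples, padded/truncated to num_parts
        def key(v):
            nums = [int(p) for p in v.split('+')[0].split('.') if p.isdigit()]
            return tuple((nums + [0] * num_parts)[:num_parts])
        ka, kb = key(a), key(b)
        return {'==': ka == kb, '>=': ka >= kb, '<=': ka <= kb,
                '>': ka > kb, '<': ka < kb}[op]

    print(compare_versions('1.5.0', '==', '1.5'))           # True (padded to 1.5.0)
    print(compare_versions('1.5.1', '>', '1.5.0'))          # True
    print(compare_versions('1.5.0+cu101', '==', '1.5.0'))   # True (local part ignored)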
@@ -126,7 +338,15 @@ class RequirementSubstitution(object):
         """
         pass

-    def post_install(self):
+    def post_scan_add_req(self):  # type: () -> Optional[MarkerRequirement]
+        """
+        Allows the RequirementSubstitution to add an extra line/requirement after
+        the initial requirements scan is completed.
+        Called only once per requirements.txt object
+        """
+        return None
+
+    def post_install(self, session):
         pass

     @classmethod
@@ -177,13 +397,20 @@ class SimpleSubstitution(RequirementSubstitution):

         if req.specs:
             _, version_number = req.specs[0]
-            assert semantic_version.Version(version_number, partial=True)
+            # assert packaging_version.parse(version_number)
         else:
             version_number = self.get_pip_version(self.name)

         req.specs = [('==', version_number + self.suffix)]
         return Text(req)

+    def replace_back(self, list_of_requirements):  # type: (Dict) -> Dict
+        """
+        :param list_of_requirements: {'pip': ['a==1.0', ]}
+        :return: {'pip': ['a==1.0', ]}
+        """
+        return list_of_requirements
+

 @six.add_metaclass(ABCMeta)
 class CudaSensitiveSubstitution(SimpleSubstitution):
@@ -235,15 +462,17 @@ class RequirementsManager(object):
         return None

     def replace(self, requirements):  # type: (Text) -> Text
+        def safe_parse(req_str):
+            try:
+                return next(parse(req_str))
+            except Exception as ex:
+                return Requirement(req_str)
+
         parsed_requirements = tuple(
             map(
                 MarkerRequirement,
                 filter(
                     None,
-                    parse(requirements)
-                    if isinstance(requirements, six.text_type)
-                    else (next(parse(line), None) for line in requirements)
+                    [safe_parse(line) for line in (requirements.splitlines()
+                     if isinstance(requirements, six.text_type) else requirements)]
                 )
             )
         )
         if not parsed_requirements:
@@ -258,7 +487,7 @@ class RequirementsManager(object):
             warning('could not resolve python wheel replacement for {}'.format(req))
             raise
         except Exception:
-            warning('could not resolve python wheel replacement for {}, '
+            warning('could not resolve python wheel replacement for \"{}\", '
                     'using original requirements line: {}'.format(req, i))
             return None
@@ -271,14 +500,34 @@ class RequirementsManager(object):
         )
         if not conda:
             result = map(self.translator.translate, result)

+        result = list(result)
+        # invoke the post-scan "add requirement" callbacks
+        for h in self.handlers:
+            req = h.post_scan_add_req()
+            if req:
+                result.append(req.tostr())
+
         return join_lines(result)

-    def post_install(self):
+    def post_install(self, session):
         for h in self.handlers:
             try:
-                h.post_install()
+                h.post_install(session)
             except Exception as ex:
                 print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
                 raise

+    def replace_back(self, requirements):
+        if self.translator:
+            requirements = self.translator.replace_back(requirements)
+
+        for h in self.handlers:
+            try:
+                requirements = h.replace_back(requirements)
+            except Exception:
+                pass
+        return requirements
+
     @staticmethod
     def get_cuda_version(config):  # type: (ConfigTree) -> (Text, Text)
@@ -288,10 +537,24 @@ class RequirementsManager(object):
         if cuda_version and cudnn_version:
             return normalize_cuda_version(cuda_version), normalize_cuda_version(cudnn_version)

+        if not cuda_version and is_windows_platform():
+            try:
+                cuda_vers = [int(k.replace('CUDA_PATH_V', '').replace('_', '')) for k in os.environ.keys()
+                             if k.startswith('CUDA_PATH_V')]
+                cuda_vers = max(cuda_vers)
+                if cuda_vers > 40:
+                    cuda_version = cuda_vers
+            except:
+                pass
+
         if not cuda_version:
             try:
                 try:
-                    output = Argv('nvcc', '--version').get_output()
+                    nvcc = 'nvcc.exe' if is_windows_platform() else 'nvcc'
+                    if is_windows_platform() and 'CUDA_PATH' in os.environ:
+                        nvcc = os.path.join(os.environ['CUDA_PATH'], nvcc)
+
+                    output = Argv(nvcc, '--version').get_output()
                 except OSError:
                     raise CudaNotFound('nvcc not found')
                 match = re.search(r'release (.{3})', output).group(1)
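The Windows branch above recovers the CUDA version from the toolkit's environment variables. A runnable sketch of that parsing (`cuda_version_from_env` and the sample paths are hypothetical):

    import os

    def cuda_version_from_env(environ=os.environ):
        # CUDA toolkits on Windows export e.g. CUDA_PATH_V10_1; recover "101" from the highest one
        versions = [int(k.replace('CUDA_PATH_V', '').replace('_', ''))
                    for k in environ if k.startswith('CUDA_PATH_V')]
        return max(versions) if versions else None

    print(cuda_version_from_env({'CUDA_PATH_V10_1': r'C:\CUDA\v10.1',
                                 'CUDA_PATH_V9_2': r'C:\CUDA\v9.2'}))  # 101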
@@ -22,7 +22,8 @@ class RequirementsTranslator(object):
         self.enabled = config["agent.pip_download_cache.enabled"]
         Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
         self.config = Config()
-        self.pip = SystemPip(interpreter=interpreter)
+        self.pip = SystemPip(interpreter=interpreter, session=self._session)
+        self._translate_back = {}

     def download(self, url):
         self.pip.download_package(url, cache_dir=self.cache_dir)
@@ -60,4 +61,30 @@ class RequirementsTranslator(object):
             except Exception:
                 command.error('Could not download wheel name of "{}"'.format(line))
                 return line
+
+        self._translate_back[str(downloaded)] = line
+
         return downloaded
+
+    def replace_back(self, requirements):
+        if not requirements:
+            return requirements
+
+        for k in requirements:
+            # k is either pip/conda
+            if k not in ('pip', 'conda'):
+                continue
+
+            original_requirements = requirements[k]
+            new_requirements = []
+            for line in original_requirements:
+                local_file = [d for d in self._translate_back.keys() if d in line]
+                if local_file:
+                    local_file = local_file[0]
+                    new_requirements.append(line.replace(local_file, self._translate_back[local_file]))
+                else:
+                    new_requirements.append(line)
+
+            requirements[k] = new_requirements
+
+        return requirements
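The translator's `replace_back` undoes the download-cache substitution: cached local wheel paths are swapped back to their original URLs so the stored requirements stay reproducible. A sketch with a made-up cache path and URL:

    translate_back = {
        '/opt/cache/torch-1.5.0-cp36-none-linux_x86_64.whl':
            'https://download.pytorch.org/whl/cu101/torch-1.5.0-cp36-none-linux_x86_64.whl',
    }

    def replace_back(lines):
        # swap cached local wheel paths back to their original URLs for reporting
        out = []
        for line in lines:
            for local, original in translate_back.items():
                if local in line:
                    line = line.replace(local, original)
                    break
            out.append(line)
        return out

    print(replace_back(['/opt/cache/torch-1.5.0-cp36-none-linux_x86_64.whl']))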
@@ -11,6 +11,7 @@ from copy import deepcopy
 from distutils.spawn import find_executable
 from itertools import chain, repeat, islice
 from os.path import devnull
+from time import sleep
 from typing import Union, Text, Sequence, Any, TypeVar, Callable

 import psutil
@@ -41,6 +42,30 @@ def get_bash_output(cmd, strip=False, stderr=subprocess.STDOUT, stdin=False):
     return output if not strip or not output else output.strip()


+def terminate_process(pid, timeout=10.):
+    # noinspection PyBroadException
+    try:
+        proc = psutil.Process(pid)
+        proc.terminate()
+        cnt = 0
+        while proc.is_running() and cnt < timeout:
+            sleep(1.)
+            cnt += 1
+        proc.terminate()
+        cnt = 0
+        while proc.is_running() and cnt < timeout:
+            sleep(1.)
+            cnt += 1
+        proc.kill()
+    except Exception:
+        pass
+    # noinspection PyBroadException
+    try:
+        return not psutil.Process(pid).is_running()
+    except Exception:
+        return True
+
+
 def kill_all_child_processes(pid=None):
     # get current process if pid not provided
     include_parent = True
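The new `terminate_process` follows a terminate-then-kill escalation ladder. A compact usage sketch of the same idea with psutil's built-in timeout (the `sleep 60` child is just for demonstration, POSIX only):

    import subprocess
    import psutil

    proc = subprocess.Popen(['sleep', '60'])
    p = psutil.Process(proc.pid)
    p.terminate()               # polite SIGTERM first
    try:
        p.wait(timeout=10)      # give it up to 10 seconds to exit
    except psutil.TimeoutExpired:
        p.kill()                # escalate, mirroring the helper's terminate/kill ladder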
@@ -59,6 +84,58 @@ def kill_all_child_processes(pid=None):
         parent.kill()


+def get_docker_id(docker_cmd_contains):
+    try:
+        containers_running = get_bash_output(cmd='docker ps --no-trunc --format \"{{.ID}}: {{.Command}}\"')
+        for docker_line in containers_running.split('\n'):
+            parts = docker_line.split(':')
+            if docker_cmd_contains in parts[-1]:
+                # we found our docker, return it
+                return parts[0]
+    except Exception:
+        pass
+    return None
+
+
+def shutdown_docker_process(docker_cmd_contains=None, docker_id=None):
+    try:
+        if not docker_id:
+            docker_id = get_docker_id(docker_cmd_contains=docker_cmd_contains)
+        if docker_id:
+            # we found our docker, stop it
+            get_bash_output(cmd='docker stop -t 1 {}'.format(docker_id))
+    except Exception:
+        pass
+
+
+def commit_docker(container_name, docker_cmd_contains=None, docker_id=None, apply_change=None):
+    """
+    Commit a docker container into a new image
+    :param str container_name: name for the new image
+    :param docker_cmd_contains: partial container id to be committed
+    :param str docker_id: id of the container to be committed
+    :param str apply_change: apply Dockerfile instructions to the image that is created
+        (see docker commit documentation for '--change').
+    """
+    try:
+        if not docker_id:
+            docker_id = get_docker_id(docker_cmd_contains=docker_cmd_contains)
+        if not docker_id:
+            print("Failed locating requested docker")
+            return False
+
+        if docker_id:
+            # we found our docker, commit it
+            apply_change = '--change=\'{}\''.format(apply_change) if apply_change else ''
+            output = get_bash_output(cmd='docker commit {} {} {}'.format(apply_change, docker_id, container_name))
+            return output
+    except Exception:
+        pass
+
+    print("Failed storing requested docker")
+    return False
+
+
 def check_if_command_exists(cmd):
     return bool(find_executable(cmd))
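The docker helpers above grep `docker ps` output for a command substring. The same trick, sketched with `subprocess` instead of the bash helper (`docker_ids` is a hypothetical name; the 'trains-agent' filter string is just an example):

    import subprocess

    def docker_ids(cmd_contains):
        # same "docker ps" trick as get_docker_id(), via subprocess instead of a bash helper
        out = subprocess.check_output(
            ['docker', 'ps', '--no-trunc', '--format', '{{.ID}}: {{.Command}}'])
        for line in out.decode().splitlines():
            container_id, _, command = line.partition(':')
            if cmd_contains in command:
                yield container_id

    # for cid in docker_ids('trains-agent'):
    #     subprocess.check_call(['docker', 'stop', '-t', '1', cid])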
@@ -124,6 +201,8 @@ class Argv(Executable):
         """
         Returns a string of the shell command
         """
+        if is_windows_platform():
+            return self.ARGV_SEPARATOR.join(map(double_quote, self))
         return self.ARGV_SEPARATOR.join(map(quote, self))

     def call_subprocess(self, func, censor_password=False, *args, **kwargs):
@@ -144,6 +223,9 @@ class Argv(Executable):
         return "Executing: {}".format(self.argv)

     def __iter__(self):
+        if is_windows_platform():
+            return (word.as_posix().replace('/', '\\') if isinstance(word, Path) else six.text_type(word)
+                    for word in self.argv)
         return (six.text_type(word) for word in self.argv)

     def __getitem__(self, item):
@@ -224,7 +306,8 @@ class CommandSequence(Executable):
             return islice(chain.from_iterable(zip(repeat(delimiter), seq)), 1, None)

         def normalize(command):
-            return list(command) if is_windows_platform() else command.serialize()
+            # return list(command) if is_windows_platform() else command.serialize()
+            return command.serialize()

         return ' '.join(list(intersperse(self.JOIN_COMMAND_OPERATOR, map(normalize, self.commands))))
@@ -266,8 +349,6 @@ class CommandSequence(Executable):

     def pretty(self):
         serialized = self.serialize()
-        if is_windows_platform():
-            return " ".join(serialized)
         return serialized

@@ -361,3 +442,18 @@ def quote(s):
     # use single quotes, and put single quotes into double quotes
     # the string $'b is then quoted as '$'"'"'b'
     return "'" + s.replace("'", "'\"'\"'") + "'"
+
+
+def double_quote(s):
+    """
+    Backport of shlex.quote():
+    Return a shell-escaped version of the string *s*.
+    """
+    if not s:
+        return "''"
+    if _find_unsafe(s) is None:
+        return s
+
+    # use double quotes, and put double quotes into single quotes
+    # the string $"b is then quoted as "$"""b"
+    return '"' + s.replace('"', '"\'\"\'"') + '"'
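For reference, the POSIX single-quote strategy that `quote()` mirrors is exactly what the standard library does; the `double_quote` variant above exists because Windows cmd has no single-quote grouping:

    import shlex

    # quote() above mirrors shlex.quote(): wrap in single quotes, splice embedded ones
    print(shlex.quote("it's here"))    # 'it'"'"'s here'
    print(shlex.quote("safe-string"))  # safe-string (no unsafe chars, returned as-is)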
@@ -5,13 +5,15 @@ import subprocess
 from distutils.spawn import find_executable
 from hashlib import md5
 from os import environ, getenv
-from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple
+from typing import Text, Sequence, Mapping, Iterable, TypeVar, Callable, Tuple, Optional

 import attr
 from furl import furl
 from pathlib2 import Path

 import six

+from trains_agent.definitions import ENV_AGENT_GIT_USER, ENV_AGENT_GIT_PASS, ENV_AGENT_GIT_HOST
+from trains_agent.helper.console import ensure_text, ensure_binary
 from trains_agent.errors import CommandFailedError
 from trains_agent.helper.base import (
@@ -42,7 +44,9 @@ class VcsFactory(object):
         :param location: (desired) clone location
         """
         url = execution_info.repository
-        is_git = url.endswith(cls.GIT_SUFFIX)
+        # we only support git; hg is deprecated
+        is_git = True
+        # is_git = url.endswith(cls.GIT_SUFFIX)
         vcs_cls = Git if is_git else Hg
         revision = (
             execution_info.version_num
@@ -93,7 +97,7 @@ class VCS(object):
         :param session: program session
         :param url: repository url
         :param location: (desired) clone location
-        :param: desired clone revision
+        :param revision: desired clone revision
         """
         self.session = session
         self.log = self.session.get_logger(
@@ -146,12 +150,23 @@ class VCS(object):
         """
         Apply patch to repository at `location`
         """
-        self.log.info("applying diff to %s", location)
+        self.log.info("applying diff to %s" % location)

-        for match in filter(
-            None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
-        ):
-            create_file_if_not_exists(normalize_path(location, match.group("path")))
+        # noinspection PyBroadException
+        try:
+            for match in filter(
+                None, map(self.PATCH_ADDED_FILE_RE.match, patch_content.splitlines())
+            ):
+                file_path = None
+                # noinspection PyBroadException
+                try:
+                    file_path = normalize_path(location, match.group("path"))
+                    create_file_if_not_exists(file_path)
+                except Exception:
+                    if file_path:
+                        self.log.warning("failed creating file for git diff (%s)" % file_path)
+        except Exception:
+            pass

         return_code, errors = self.call_with_stdin(
             patch_content, *self.patch_base, cwd=location
@@ -204,7 +219,7 @@ class VCS(object):
         )

     @classmethod
-    def resolve_ssh_url(cls, url):
+    def replace_ssh_url(cls, url):
         # type: (Text) -> Text
         """
         Replace SSH URL with HTTPS URL when applicable
@@ -238,18 +253,58 @@ class VCS(object):
         ).url
         return url

+    @classmethod
+    def replace_http_url(cls, url, port=None):
+        # type: (Text, Optional[int]) -> Text
+        """
+        Replace HTTPS URL with SSH URL when applicable
+        """
+        parsed_url = furl(url)
+        if parsed_url.scheme == "https":
+            parsed_url.scheme = "ssh"
+            parsed_url.username = "git"
+            parsed_url.password = None
+            # make sure there is no port in the final url (safe_furl support)
+            # the original port was an https port, and we do not know if there is a different ssh port,
+            # so we have to clear the original port specified (https) and use the default ssh schema port.
+            parsed_url.port = port or None
+            url = parsed_url.url
+        return url
+
     def _set_ssh_url(self):
         """
         Replace instance URL with SSH substitution result and report to log.
         According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work.
         """
+        if self.session.config.get('agent.force_git_ssh_protocol', None) and self.url:
+            parsed_url = furl(self.url)
+            # only apply to a specific domain (if requested)
+            config_domain = \
+                ENV_AGENT_GIT_HOST.get() or self.session.config.get("agent.git_host", None)
+            if config_domain and config_domain != parsed_url.host:
+                return
+            if parsed_url.scheme == "https":
+                new_url = self.replace_http_url(
+                    self.url, port=self.session.config.get('agent.force_git_ssh_port', None))
+                if new_url != self.url:
+                    print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format(
+                        self.url, new_url))
+                    self.url = new_url
+                return
+
         if not self.session.config.agent.translate_ssh:
             return

-        ssh_agent_variable = "SSH_AUTH_SOCK"
-        if not getenv(ssh_agent_variable) and (self.session.config.get('agent.git_user', None) and
-                                               self.session.config.get('agent.git_pass', None)):
-            new_url = self.resolve_ssh_url(self.url)
+        # if we have git_user / git_pass replace ssh credentials with https authentication
+        if (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and \
+                (ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None)):
+            # only apply to a specific domain (if requested)
+            config_domain = \
+                ENV_AGENT_GIT_HOST.get() or self.session.config.get("git_host", None)
+            if config_domain and config_domain != furl(self.url).host:
+                return
+
+            new_url = self.replace_ssh_url(self.url)
             if new_url != self.url:
                 print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
                     self.url, new_url))
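A runnable sketch of the https-to-ssh rewrite that `replace_http_url` performs, using the same `furl` library the agent imports (the helper name `https_to_ssh` is hypothetical):

    from furl import furl  # same URL library the agent uses

    def https_to_ssh(url, port=None):
        # https -> ssh rewrite: git user, no password, default ssh port unless forced
        parsed = furl(url)
        if parsed.scheme == "https":
            parsed.scheme = "ssh"
            parsed.username = "git"
            parsed.password = None
            parsed.port = port or None
        return parsed.url

    print(https_to_ssh("https://github.com/clearml/clearml-agent.git"))
    # ssh://git@github.com/clearml/clearml-agent.git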
@@ -263,8 +318,9 @@ class VCS(object):
         """
         self._set_ssh_url()
         clone_command = ("clone", self.url_with_auth, self.location) + self.clone_flags
-        if branch:
-            clone_command += ("-b", branch)
+        # clone all branches regardless of which branch we later checkout
+        # if branch:
+        #     clone_command += ("-b", branch)
         if self.session.debug_mode:
             self.call(*clone_command)
             return
@@ -325,9 +381,10 @@ class VCS(object):
         """
         Run command with `input_` as stdin
         """
-        input_ = input_.encode("latin1")
-        if not input_.endswith(b"\n"):
-            input_ += b"\n"
+        input_ = input_.encode("utf-8")
+        # always add an extra empty line
+        # (there is no downside, and it fixes the corrupt-patch error caused by empty lines at the end of a patch)
+        input_ += b"\n"
         process = self._call_subprocess(
             subprocess.Popen,
             argv,
@@ -389,15 +446,20 @@ class VCS(object):
         Add username and password to URL if missing from URL and present in config.
         Does not modify ssh URLs.
         """
-        parsed_url = furl(url)
+        try:
+            parsed_url = furl(url)
+        except ValueError:
+            return url
         if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"):
             return parsed_url.url
-        config_user = config.get("agent.{}_user".format(cls.executable_name), None)
-        config_pass = config.get("agent.{}_pass".format(cls.executable_name), None)
+        config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)
+        config_pass = ENV_AGENT_GIT_PASS.get() or config.get("agent.{}_pass".format(cls.executable_name), None)
+        config_domain = ENV_AGENT_GIT_HOST.get() or config.get("agent.{}_host".format(cls.executable_name), None)
         if (
             (not (parsed_url.username and parsed_url.password))
             and config_user
             and config_pass
+            and (not config_domain or config_domain.lower() == parsed_url.host)
         ):
             parsed_url.set(username=config_user, password=config_pass)
         return parsed_url.url
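The credential-injection logic above, condensed into a standalone sketch (hypothetical helper name and credentials, same `furl` API):

    from furl import furl

    def add_auth(url, user, password, domain=None):
        # inject credentials only for http(s) URLs, and only for the configured host
        parsed = furl(url)
        if parsed.scheme in ("", "ssh") or parsed.scheme.startswith("git"):
            return parsed.url
        if not (parsed.username and parsed.password) and \
                (not domain or domain.lower() == parsed.host):
            parsed.set(username=user, password=password)
        return parsed.url

    print(add_auth("https://github.com/clearml/clearml-agent.git", "bot", "token123"))
    # https://bot:token123@github.com/clearml/clearml-agent.git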
@@ -453,16 +515,26 @@ class Git(VCS):
         )

     def pull(self):
-        self.call("fetch", "origin", cwd=self.location)
+        self.call("fetch", "--all", "--recurse-submodules", cwd=self.location)

     def checkout(self):  # type: () -> None
         """
         Checkout repository at specified revision
         """
         self.call("checkout", self.revision, *self.checkout_flags, cwd=self.location)
+        try:
+            self.call("submodule", "update", "--recursive", cwd=self.location)
+        except:
+            pass

     info_commands = dict(
-        url=Argv("git", "remote", "get-url", "origin"),
-        branch=Argv("git", "rev-parse", "--abbrev-ref", "HEAD"),
-        commit=Argv("git", "rev-parse", "HEAD"),
-        root=Argv("git", "rev-parse", "--show-toplevel"),
+        url=Argv(executable_name, "ls-remote", "--get-url", "origin"),
+        branch=Argv(executable_name, "rev-parse", "--abbrev-ref", "HEAD"),
+        commit=Argv(executable_name, "rev-parse", "HEAD"),
+        root=Argv(executable_name, "rev-parse", "--show-toplevel"),
     )

-    patch_base = ("apply",)
+    patch_base = ("apply", "--unidiff-zero", )


 class Hg(VCS):
@@ -493,10 +565,10 @@ class Hg(VCS):
     )

     info_commands = dict(
-        url=Argv("hg", "paths", "--verbose"),
-        branch=Argv("hg", "--debug", "id", "-b"),
-        commit=Argv("hg", "--debug", "id", "-i"),
-        root=Argv("hg", "root"),
+        url=Argv(executable_name, "paths", "--verbose"),
+        branch=Argv(executable_name, "--debug", "id", "-b"),
+        commit=Argv(executable_name, "--debug", "id", "-i"),
+        root=Argv(executable_name, "root"),
     )

@@ -516,11 +588,16 @@ def clone_repository_cached(session, execution, destination):

     clone_folder_name = Path(str(furl(repo_url).path)).name  # type: str
     clone_folder = Path(destination) / clone_folder_name
-    cached_repo_path = (
-        Path(session.config["agent.vcs_cache.path"]).expanduser()
-        / "{}.{}".format(clone_folder_name, md5(ensure_binary(repo_url)).hexdigest())
-        / clone_folder_name
-    )  # type: Path
+
+    standalone_mode = session.config.get("agent.standalone_mode", False)
+    if standalone_mode:
+        cached_repo_path = clone_folder
+    else:
+        cached_repo_path = (
+            Path(session.config["agent.vcs_cache.path"]).expanduser()
+            / "{}.{}".format(clone_folder_name, md5(ensure_binary(repo_url)).hexdigest())
+            / clone_folder_name
+        )  # type: Path

     vcs = VcsFactory.create(
         session, execution_info=execution, location=cached_repo_path

@@ -528,24 +605,29 @@ def clone_repository_cached(session, execution, destination):
     if not find_executable(vcs.executable_name):
         raise CommandFailedError(vcs.executable_not_found_error_help())
 
-    if session.config["agent.vcs_cache.enabled"] and cached_repo_path.exists():
-        print('Using cached repository in "{}"'.format(cached_repo_path))
-    else:
-        print("cloning: {}".format(no_password_url))
-        rm_tree(cached_repo_path)
-        vcs.clone(branch=execution.branch)
+    if not standalone_mode:
+        if session.config["agent.vcs_cache.enabled"] and cached_repo_path.exists():
+            print('Using cached repository in "{}"'.format(cached_repo_path))
+        else:
+            print("cloning: {}".format(no_password_url))
+            rm_tree(cached_repo_path)
+            # We clone the entire repository, not a specific branch
+            vcs.clone()  # branch=execution.branch)
 
-    vcs.pull()
-    vcs.checkout()
-
-    rm_tree(destination)
-    shutil.copytree(Text(cached_repo_path), Text(clone_folder))
-    if not clone_folder.is_dir():
-        raise CommandFailedError(
-            "copying of repository failed: from {} to {}".format(
-                cached_repo_path, clone_folder
-            )
-        )
+        vcs.pull()
+        rm_tree(destination)
+        shutil.copytree(Text(cached_repo_path), Text(clone_folder))
+        if not clone_folder.is_dir():
+            raise CommandFailedError(
+                "copying of repository failed: from {} to {}".format(
+                    cached_repo_path, clone_folder
+                )
+            )
+
+    # checkout in the newly copied destination
+    vcs.location = Text(clone_folder)
+    vcs.checkout()
 
     repo_info = vcs.get_repository_copy_info(clone_folder)
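
To make the cache layout above concrete, a hedged sketch of how the cached path is derived (the repository URL and cache root are hypothetical, and hashlib is used directly instead of the module's ensure_binary helper):

    from hashlib import md5
    from pathlib import Path

    repo_url = "https://example.com/user/project.git"   # hypothetical
    clone_folder_name = Path("/user/project.git").name  # furl(repo_url).path -> "project.git"
    cached_repo_path = (
        Path("~/.trains/vcs-cache").expanduser()        # assumed agent.vcs_cache.path value
        / "{}.{}".format(clone_folder_name, md5(repo_url.encode("utf-8")).hexdigest())
        / clone_folder_name
    )
    # e.g. ~/.trains/vcs-cache/project.git.<md5-of-url>/project.git, so distinct URLs never collide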

@@ -75,9 +75,15 @@ class ResourceMonitor(object):
         self._exit_event = Event()
         self._gpustat_fail = 0
         self._gpustat = gpustat
-        if not self._gpustat:
+        self._active_gpus = None
+        if os.environ.get('NVIDIA_VISIBLE_DEVICES') == 'none':
+            # NVIDIA_VISIBLE_DEVICES set to none, marks cpu_only flag
+            # active_gpus == False means no GPU reporting
+            self._active_gpus = False
+        elif not self._gpustat:
             log.warning('Trains-Agent Resource Monitor: GPU monitoring is not available')
         else:
             # None means no filtering, report all gpus
-            self._active_gpus = None
             try:
                 active_gpus = os.environ.get('NVIDIA_VISIBLE_DEVICES', '') or \

@@ -244,8 +250,8 @@ class ResourceMonitor(object):
         stats["io_read_mbs"] = BytesSizes.megabytes(io_stats.read_bytes)
         stats["io_write_mbs"] = BytesSizes.megabytes(io_stats.write_bytes)
 
-        # check if we can access the gpu statistics
-        if self._gpustat:
+        # check if we need to monitor gpus and if we can access the gpu statistics
+        if self._active_gpus is not False and self._gpustat:
             try:
                 gpu_stat = self._gpustat.new_query()
                 for i, g in enumerate(gpu_stat.gpus):
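
The reporting rule these two hunks establish can be summarized in a small self-contained sketch (the helper name is hypothetical):

    # _active_gpus semantics after this change:
    #   False -> NVIDIA_VISIBLE_DEVICES=none (cpu-only), report no GPUs
    #   None  -> no filtering, report every GPU
    #   list  -> report only the listed device indices
    def should_report(active_gpus, gpu_index):
        if active_gpus is False:
            return False
        if active_gpus is None:
            return True
        return str(gpu_index) in active_gpus

    assert should_report(False, 0) is False
    assert should_report(None, 3) is True
    assert should_report(['0', '2'], 2) is True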

trains_agent/helper/runtime_verification.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import re
from datetime import datetime, timedelta

from typing import List, Tuple, Optional

from trains_agent.backend_config.defs import UptimeConf

DAYS = ["SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"]
PATTERN = re.compile(r"^(?P<hours>[^\s]+)\s(?P<days>[^\s]+)")


def check_runtime(ranges_list, is_uptime=True):
    # type: (List[str], bool) -> bool
    for entry in ranges_list:
        days_list = get_days_list(entry)
        if not check_day(days_list):
            continue

        hours_list = get_hours_list(entry)
        if check_hour(hours_list):
            return is_uptime
    return not is_uptime
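
A short usage sketch: with the spec "17-20 TUE", get_hours_list() yields range(17, 20), so the agent counts as "up" only on Tuesdays between 17:00 and 19:59:

    from trains_agent.helper.runtime_verification import check_runtime

    is_up = check_runtime(["17-20 TUE"], is_uptime=True)
    # True only when the current local time is a Tuesday between 17:00 and 19:59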

def check_hour(hours):
    # type: (List[int]) -> bool
    return datetime.now().hour in hours


def check_day(days):
    # type: (List[str]) -> bool
    return datetime.now().strftime("%a").upper() in days


def get_days_list(entry):
    # type: (str) -> List[str]
    days_intervals = PATTERN.match(entry)["days"].split(",")
    days_total = []
    for days in days_intervals:
        start, end = days.split("-") if "-" in days else (days, days)
        try:
            days_total.extend(
                [*range(DAYS.index(start.upper()), DAYS.index(end.upper()) + 1)]
            )
        except ValueError:
            print(
                "Warning: days interval '{}' is invalid, use intervals of the format <start>-<end>."
                " Make sure to use the abbreviated format SUN-SAT".format(days)
            )
            continue
    return [DAYS[day] for day in days_total]


def get_hours_list(entry):
    # type: (str) -> List[int]
    hours_intervals = PATTERN.match(entry)["hours"].split(",")
    hours_total = []
    for hours in hours_intervals:
        start, end = get_start_end_hours(hours)
        if not (start and end):
            continue
        hours_total.extend([*range(start, end)])
    return hours_total


def get_start_end_hours(hours):
    # type: (str) -> Tuple[int, int]
    try:
        start, end = (
            tuple(map(int, hours.split("-"))) if "-" in hours else (int(hours), 24)
        )
    except Exception as ex:
        print(
            "Warning: hours interval '{}' is invalid, use intervals of the format <start>-<end>".format(
                hours, ex
            )
        )
        start, end = (None, None)
    if end == 0:
        end = 24
    return start, end
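
A few input/output cases implied by the code above (the single-number fallback and the end == 0 normalization follow directly from the int(hours), 24 branch):

    assert get_start_end_hours("17-20") == (17, 20)
    assert get_start_end_hours("18") == (18, 24)       # open-ended until midnight
    assert get_start_end_hours("20-0") == (20, 24)     # end == 0 is treated as 24
    assert get_start_end_hours("bad") == (None, None)  # prints a warning, interval skipped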


def print_uptime_properties(
    ranges_list, queues_info, runtime_properties, is_uptime=True
):
    # type: (List[str], List[dict], List[dict], bool) -> None
    if ranges_list:
        uptime_string = ["Working hours {} configurations".format("uptime" if is_uptime else "downtime")]
        uptime_string.extend(get_uptime_string(entry) for entry in ranges_list)
    else:
        uptime_string = ["No uptime/downtime configurations found"]

    is_server_forced, server_string = get_runtime_properties_string(runtime_properties)
    is_queue_forced, queues_string = get_queues_tags_string(queues_info)

    res = list(
        filter(
            len,
            [
                "\n".join(uptime_string),
                "\nCurrently forced {}".format(is_queue_forced or is_server_forced)
                if is_queue_forced or is_server_forced
                else " ",
                server_string,
                queues_string,
            ],
        )
    )
    print("\n".join(res))


def get_uptime_string(entry):
    # type: (str) -> str
    res = []
    days_list = get_days_list(entry)
    hours_intervals = PATTERN.match(entry)["hours"].split(",")
    for hours in hours_intervals:
        start, end = get_start_end_hours(hours)
        if not (start and end):
            continue
        res.append(
            " - {}:00-{}:59 on {}".format(start, end - 1, ' and '.join(days_list))
            if not (start == end)
            else ""
        )
    return "\n".join(res)


def get_runtime_properties_string(runtime_properties):
    # type: (List[dict]) -> Tuple[Optional[str], str]
    server_string = []
    force_flag = next(
        (prop for prop in runtime_properties if prop["key"] == UptimeConf.worker_key),
        None,
    )
    is_server_forced = None
    if force_flag:
        is_server_forced = force_flag["value"].upper()
        expiry_hour = (
            (datetime.now() + timedelta(seconds=force_flag["expiry"])).strftime("%H:%M")
            if force_flag["expiry"]
            else None
        )
        expires = " expires at {}".format(expiry_hour) if expiry_hour else ""
        server_string.append(
            " - Server runtime property '{}: {}'{}".format(force_flag['key'], force_flag['value'], expires)
        )
    return is_server_forced, "\n".join(server_string)


def get_queues_tags_string(queues_info):
    # type: (List[dict]) -> Tuple[Optional[str], str]
    queues_string = []
    is_queue_forced = None
    for queue in queues_info:
        if "force_workers:off" in queue.get("tags", []):
            tag = "force_workers:off"
            is_queue_forced = is_queue_forced or "OFF"
        elif "force_workers:on" in queue.get("tags", []):
            tag = "force_workers:on"
            is_queue_forced = "ON"
        else:
            tag = None
        tagged = " (tagged '{}')".format(tag) if tag else ""
        queues_string.append(
            " - Listening to queue '{}'{}".format(queue.get('name', ''), tagged)
        )
    return is_queue_forced, "\n".join(queues_string)

@@ -4,11 +4,14 @@ from time import sleep
 from glob import glob
 from tempfile import gettempdir, NamedTemporaryFile
 
+from typing import List, Tuple, Optional
+
+from trains_agent.definitions import ENV_DOCKER_HOST_MOUNT
 from trains_agent.helper.base import warning
 
 
 class Singleton(object):
-    prefix = 'trainsagent'
+    prefix = '.trainsagent'
     sep = '_'
     ext = '.tmp'
     worker_id = None

@@ -17,9 +20,31 @@ class Singleton(object):
     _pid_file = None
     _lock_file_name = sep+prefix+sep+'global.lock'
     _lock_timeout = 10
+    _pid = None
 
     @classmethod
-    def register_instance(cls, unique_worker_id=None, worker_name=None):
+    def update_pid_file(cls):
+        new_pid = str(os.getpid())
+        if not cls._pid_file or cls._pid == new_pid:
+            return
+        old_name = cls._pid_file.name
+        parts = cls._pid_file.name.split(os.path.sep)
+        parts[-1] = parts[-1].replace(cls.sep + cls._pid + cls.sep, cls.sep + new_pid + cls.sep)
+        new_pid_file = os.path.sep.join(parts)
+        cls._pid = new_pid
+        cls._pid_file.name = new_pid_file
+        # we need to rename to match the new pid
+        try:
+            os.rename(old_name, new_pid_file)
+        except:
+            pass
+
+    @classmethod
+    def get_lock_filename(cls):
+        return os.path.join(cls._get_temp_folder(), cls._lock_file_name)
+
+    @classmethod
+    def register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
         """
         # Exit the process if another instance of us is using the same worker_id

@@ -28,7 +53,7 @@ class Singleton(object):
         :return: (str worker_id, int slot_number) Return None value on instance already running
         """
         # try to lock file
-        lock_file = os.path.join(gettempdir(), cls._lock_file_name)
+        lock_file = cls.get_lock_filename()
         timeout = 0
         while os.path.exists(lock_file):
             if timeout > cls._lock_timeout:

@@ -46,7 +71,9 @@ class Singleton(object):
             f.write(bytes(os.getpid()))
             f.flush()
             try:
-                ret = cls._register_instance(unique_worker_id=unique_worker_id, worker_name=worker_name)
+                ret = cls._register_instance(
+                    unique_worker_id=unique_worker_id, worker_name=worker_name,
+                    api_client=api_client, allow_double=allow_double)
             except:
                 ret = None, None

@@ -58,23 +85,51 @@ class Singleton(object):
         return ret
 
     @classmethod
-    def _register_instance(cls, unique_worker_id=None, worker_name=None):
+    def get_running_pids(cls):
+        # type: () -> List[Tuple[int, Optional[str], Optional[int], str]]
+        temp_folder = cls._get_temp_folder()
+        files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
+        pids = []
+        for file in files:
+            parts = os.path.basename(file).split(cls.sep)
+            # noinspection PyBroadException
+            try:
+                pid = int(parts[1])
+                if not psutil.pid_exists(pid):
+                    pid = -1
+            except Exception:
+                # something is wrong, use a non-existing pid and delete the file
+                pid = -1
+
+            uid, slot = None, None
+            # noinspection PyBroadException
+            try:
+                with open(file, 'r') as f:
+                    uid, slot = str(f.read()).split('\n')
+                    slot = int(slot)
+            except Exception:
+                pass
+            pids.append((pid, uid, slot, file))
+
+        return pids
+
+    @classmethod
+    def _register_instance(cls, unique_worker_id=None, worker_name=None, api_client=None, allow_double=False):
         if cls.worker_id:
             return cls.worker_id, cls.instance_slot
         # make sure we have a unique name
         instance_num = 0
-        temp_folder = gettempdir()
-        files = glob(os.path.join(temp_folder, cls.prefix + cls.sep + '*' + cls.ext))
         slots = {}
-        for file in files:
-            parts = file.split(cls.sep)
-            try:
-                pid = int(parts[1])
-            except Exception:
-                # something is wrong, use non existing pid and delete the file
-                pid = -1
+        for pid, uid, slot, file in cls.get_running_pids():
+            worker = None
+            if api_client and ENV_DOCKER_HOST_MOUNT.get() and uid:
+                try:
+                    worker = [w for w in api_client.workers.get_all() if w.id == uid]
+                except Exception:
+                    worker = None
 
+            # count active instances and delete dead files
-            if not psutil.pid_exists(pid):
+            if not worker and pid < 0:
                 # delete the file
                 try:
                     os.remove(os.path.join(file))

@@ -83,15 +138,15 @@ class Singleton(object):
                     continue
 
             instance_num += 1
-            try:
-                with open(file, 'r') as f:
-                    uid, slot = str(f.read()).split('\n')
-                    slot = int(slot)
-            except Exception:
+            if slot is None:
                 continue
 
             if uid == unique_worker_id:
-                return None, None
+                if allow_double:
+                    warning('Instance with the same WORKER_ID [{}] was found on this machine. '
+                            'We are ignoring it, make sure this is not a mistake.'.format(unique_worker_id))
+                else:
+                    return None, None
 
             slots[slot] = uid

@@ -110,10 +165,27 @@ class Singleton(object):
             unique_worker_id = worker_name + cls.worker_name_sep + str(cls.instance_slot)
 
         # create lock
-        cls._pid_file = NamedTemporaryFile(dir=gettempdir(), prefix=cls.prefix + cls.sep + str(os.getpid()) + cls.sep,
-                                           suffix=cls.ext)
+        cls._pid = str(os.getpid())
+        cls._pid_file = NamedTemporaryFile(
+            dir=cls._get_temp_folder(), prefix=cls.prefix + cls.sep + cls._pid + cls.sep, suffix=cls.ext)
         cls._pid_file.write(('{}\n{}'.format(unique_worker_id, cls.instance_slot)).encode())
         cls._pid_file.flush()
         cls.worker_id = unique_worker_id
 
         return cls.worker_id, cls.instance_slot
 
+    @classmethod
+    def _get_temp_folder(cls):
+        if ENV_DOCKER_HOST_MOUNT.get():
+            return ENV_DOCKER_HOST_MOUNT.get().split(':')[-1]
+        return gettempdir()
+
     @classmethod
     def get_slot(cls):
         return cls.instance_slot or 0
 
+    @classmethod
+    def get_pid_file(cls):
+        if not cls._pid_file:
+            return None
+        return cls._pid_file.name
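
The pid files managed above follow a <prefix><sep><pid><sep><random><ext> pattern, which is what lets update_pid_file() swap the pid component after a fork. A hedged sketch with hypothetical values:

    sep, pid, new_pid = '_', '4242', '9999'
    old_name = '/tmp/.trainsagent_4242_k3x9q1.tmp'  # hypothetical NamedTemporaryFile name
    new_name = old_name.replace(sep + pid + sep, sep + new_pid + sep)
    # '/tmp/.trainsagent_9999_k3x9q1.tmp'; os.rename(old_name, new_name) completes the move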

@@ -1,3 +1,4 @@
+import itertools
 from functools import partial
 from importlib import import_module
 import argparse
 
@@ -24,8 +25,17 @@ def get_parser():
     from .worker import COMMANDS
     subparsers = top_parser.add_subparsers(dest='command')
     for c in COMMANDS:
-        parser = subparsers.add_parser(name=c, help=COMMANDS[c]['help'])
-        for a in COMMANDS[c].get('args', {}).keys():
-            parser.add_argument(a, **COMMANDS[c]['args'][a])
+        parser = subparsers.add_parser(name=c, help=COMMANDS[c]["help"])
+        groups = itertools.groupby(
+            sorted(
+                COMMANDS[c].get("args", {}).items(), key=lambda x: x[1].get("group", "")
+            ),
+            key=lambda x: x[1].pop("group", ""),
+        )
+        for group_name, group in groups:
+            p = parser if not group_name else parser.add_argument_group(group_name)
+            for key, value in group:
+                aliases = value.pop("aliases", [])
+                p.add_argument(key, *aliases, **value)
 
     return top_parser
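
A sketch of a COMMANDS args entry exercising the two new keys consumed above; 'aliases' becomes extra option strings and 'group' routes the flag into an argparse argument group (this particular combination is illustrative):

    COMMANDS = {
        'daemon': {
            'help': 'Run the agent daemon',
            'args': {
                '--detached': {
                    'help': 'Detached mode, run agent in the background',
                    'action': 'store_true',
                    'aliases': ['-d'],          # popped and passed as extra option strings
                    'group': 'Docker support',  # popped and used for add_argument_group()
                },
            },
        },
    }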

@@ -30,6 +30,17 @@ WORKER_ARGS = {
         'type': lambda x: x.upper(),
         'default': 'INFO',
     },
+    '--gpus': {
+        'help': 'Specify active GPUs for the daemon to use (docker / virtual environment), '
+                'Equivalent to setting NVIDIA_VISIBLE_DEVICES. '
+                'Examples: --gpus 0 or --gpus 0,1,2 or --gpus all',
+        'group': 'Docker support',
+    },
+    '--cpu-only': {
+        'help': 'Disable GPU access for the daemon, only use CPU in either docker or virtual environment',
+        'action': 'store_true',
+        'group': 'Docker support',
+    },
 }
 
 DAEMON_ARGS = dict({

@@ -40,9 +51,15 @@ DAEMON_ARGS = dict({
     '--docker': {
         'help': 'Run execution task inside a docker (v19.03 and above). Optional args <image> <arguments> or '
                 'specify default docker image in agent.default_docker.image / agent.default_docker.arguments. '
-                'Set NVIDIA_VISIBLE_DEVICES to limit gpu visibility for docker',
+                'Use --gpus/--cpu-only (or set NVIDIA_VISIBLE_DEVICES) to limit gpu visibility for docker',
         'nargs': '*',
         'default': False,
+        'group': 'Docker support',
     },
+    '--force-current-version': {
+        'help': 'Force trains-agent to use the current trains-agent version when running in the docker',
+        'action': 'store_true',
+        'group': 'Docker support',
+    },
     '--queue': {
         'help': 'Queue ID(s)/Name(s) to pull tasks from (\'default\' queue)',

@@ -51,6 +68,49 @@ DAEMON_ARGS = dict({
         'dest': 'queues',
         'type': foreign_object_id('queues'),
     },
     '--order-fairness': {
         'help': 'Pull from each queue in a round-robin order, instead of priority order.',
         'action': 'store_true',
     },
+    '--standalone-mode': {
+        'help': 'Do not use any network connections, assume everything is pre-installed',
+        'action': 'store_true',
+    },
+    '--services-mode': {
+        'help': 'Launch multiple long-term docker services. Implies docker & cpu-only flags.',
+        'action': 'store_true',
+    },
+    '--create-queue': {
+        'help': 'Create requested queue if it does not exist already.',
+        'action': 'store_true',
+    },
+    '--detached': {
+        'help': 'Detached mode, run agent in the background',
+        'action': 'store_true',
+        'aliases': ['-d'],
+    },
+    '--stop': {
+        'help': 'Stop the running agent (based on the same set of arguments)',
+        'action': 'store_true',
+    },
+    '--uptime': {
+        'help': 'Specify uptime for trains-agent in "<hours> <days>" format. For example, use "17-20 TUE" to set '
+                'Tuesday\'s uptime to 17-20. '
+                'Note: Make sure to have only one of uptime/downtime configuration and not both.',
+        'nargs': '*',
+        'default': None,
+    },
+    '--downtime': {
+        'help': 'Specify downtime for trains-agent in "<hours> <days>" format. For example, use "09-13 TUE" to set '
+                'Tuesday\'s downtime to 09-13. '
+                'Note: Make sure to have only one of uptime/downtime configuration and not both.',
+        'nargs': '*',
+        'default': None,
+    },
+    '--status': {
+        'help': 'Print the worker\'s schedule (uptime properties, server\'s runtime properties and listening queues)',
+        'action': 'store_true',
+    },
 }, **WORKER_ARGS)
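
A hedged parsing sketch for the new daemon flags, assuming the get_parser() function from the command-line interface module patched above (note that --queue resolves names through the API, so a configured session may be required):

    args = get_parser().parse_args(
        ['daemon', '--queue', 'default', '--uptime', '17-20 TUE', '--detached'])
    # args.uptime == ['17-20 TUE'], args.detached is True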

COMMANDS = {

@@ -74,6 +134,26 @@ COMMANDS = {
         'help': 'Full environment setup log & task logging & monitoring (stdout is still visible)',
         'action': 'store_true',
     },
+    '--require-queue': {
+        'help': 'If the specified task is not queued (in any Queue), the execution will fail. '
+                '(Used for 3rd party scheduler integration, e.g. K8s, SLURM, etc.)',
+        'action': 'store_true',
+    },
+    '--standalone-mode': {
+        'help': 'Do not use any network connections, assume everything is pre-installed',
+        'action': 'store_true',
+    },
+    '--docker': {
+        'help': 'Run execution task inside a docker (v19.03 and above). Optional args <image> <arguments> or '
+                'specify default docker image in agent.default_docker.image / agent.default_docker.arguments. '
+                'Use --gpus/--cpu-only (or set NVIDIA_VISIBLE_DEVICES) to limit gpu visibility for docker',
+        'nargs': '*',
+        'default': False,
+    },
+    '--clone': {
+        'help': 'Clone the experiment before execution, and execute the cloned experiment',
+        'action': 'store_true',
+    },
 }, **WORKER_ARGS),
 },
 'build': {
 
@@ -87,12 +167,32 @@ COMMANDS = {
         'dest': 'task_id',
         'type': foreign_object_id('tasks'),
     },
-    '--target-folder': {
-        'help': 'Where to build the task\'s virtual environment and source code',
+    '--target': {
+        'help': 'Where to build the task\'s virtual environment and source code. '
+                'When used with --docker, target docker image name to create',
     },
+    '--install-globally': {
+        'help': 'Install required python packages before creating the virtual environment used to execute an '
+                'experiment, and use the \'agent.package_manager.system_site_packages\' virtual env option. '
+                'Note: when --docker is used, install-globally is always true',
+        'action': 'store_true',
+    },
+    '--docker': {
+        'help': 'Build the experiment inside a docker (v19.03 and above). Optional args <image> <arguments> or '
+                'specify default docker image in agent.default_docker.image / agent.default_docker.arguments. '
+                'Use --gpus/--cpu-only (or set NVIDIA_VISIBLE_DEVICES) to limit gpu visibility for docker',
+        'nargs': '*',
+        'default': False,
+    },
     '--python-version': {
         'help': 'Virtual environment python version to use',
     },
+    '--entry-point': {
+        'help': 'Run the task in the new docker. There are two options:\nEither add "reuse_task" to run the '
+                'given task in the docker, or "clone_task" to first clone the given task and then run it in the docker',
+        'default': False,
+        'choices': ['reuse_task', 'clone_task'],
+    }
 }, **WORKER_ARGS),
 },
 'list': {

@@ -15,7 +15,7 @@ from pyhocon import ConfigFactory, HOCONConverter, ConfigTree
 from trains_agent.backend_api.session import Session as _Session, Request
 from trains_agent.backend_api.session.client import APIClient
 from trains_agent.backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR, LOCAL_CONFIG_FILES
-from trains_agent.definitions import ENVIRONMENT_CONFIG
+from trains_agent.definitions import ENVIRONMENT_CONFIG, ENV_TASK_EXECUTE_AS_USER, ENVIRONMENT_BACKWARD_COMPATIBLE
 from trains_agent.errors import APIError
 from trains_agent.helper.base import HOCONEncoder
 from trains_agent.helper.process import Argv

@@ -63,6 +63,7 @@ def tree(*args):
 
 class Session(_Session):
     version = __version__
+    force_debug = False
 
     def __init__(self, *args, **kwargs):
         # make sure we set the environment variable so the parent session opens the correct file

@@ -72,16 +73,35 @@ class Session(_Session):
         os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file
         if not Path(config_file).is_file():
             raise ValueError("Could not open configuration file: {}".format(config_file))
 
+        cpu_only = kwargs.get('cpu_only')
+        if cpu_only:
+            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none'
+
+        if kwargs.get('gpus') and not os.environ.get('KUBERNETES_SERVICE_HOST') \
+                and not os.environ.get('KUBERNETES_PORT'):
+            # CUDA_VISIBLE_DEVICES does not support 'all'
+            if kwargs.get('gpus') == 'all':
+                os.environ.pop('CUDA_VISIBLE_DEVICES', None)
+                os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')
+            else:
+                os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')
+
+        if kwargs.get('only_load_config'):
+            from trains_agent.backend_api.config import load
+            self.config = load()
+        else:
+            super(Session, self).__init__(*args, **kwargs)
+
+        # set force debug mode, if it's on:
+        if Session.force_debug:
+            self.config["agent"]["debug"] = True
+
         self.log = self.get_logger(__name__)
         self.trace = kwargs.get('trace', False)
         self._config_file = kwargs.get('config_file') or \
             os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR) or LOCAL_CONFIG_FILES[0]
-        self.api_client = APIClient(session=self, api_version="2.4")
+        self.api_client = APIClient(session=self, api_version="2.5")
         # HACK make sure we have python version to execute,
         # if nothing was specified, use the one that runs us
         def_python = ConfigValue(self.config, "agent.default_python")
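
Condensed into a standalone helper for clarity, this is the GPU-visibility mapping the constructor applies (a sketch; the real code also skips the mapping when Kubernetes environment variables are present):

    import os

    def apply_gpu_flags(gpus=None, cpu_only=False):
        if cpu_only:
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none'
        elif gpus == 'all':
            os.environ.pop('CUDA_VISIBLE_DEVICES', None)  # CUDA_VISIBLE_DEVICES has no 'all'
            os.environ['NVIDIA_VISIBLE_DEVICES'] = gpus
        elif gpus:
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = gpus

    apply_gpu_flags(gpus='0,1')  # both variables become "0,1"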

@@ -89,21 +109,36 @@ class Session(_Session):
             def_python.set("{version.major}.{version.minor}".format(version=sys.version_info))
 
         # HACK: backwards compatibility
-        os.environ['ALG_CONFIG_FILE'] = self._config_file
-        os.environ['SM_CONFIG_FILE'] = self._config_file
+        if ENVIRONMENT_BACKWARD_COMPATIBLE.get():
+            os.environ['ALG_CONFIG_FILE'] = self._config_file
+            os.environ['SM_CONFIG_FILE'] = self._config_file
 
         if not self.config.get('api.host', None) and self.config.get('api.api_server', None):
             self.config['api']['host'] = self.config.get('api.api_server')
 
-        # initialize nvidia visibility variable
+        # initialize nvidia visibility variables
         os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
         if os.environ.get('NVIDIA_VISIBLE_DEVICES') and not os.environ.get('CUDA_VISIBLE_DEVICES'):
-            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
+            # do not create CUDA_VISIBLE_DEVICES if it doesn't exist, it breaks TF/PyTorch CUDA detection
+            # os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
+            pass
         elif os.environ.get('CUDA_VISIBLE_DEVICES') and not os.environ.get('NVIDIA_VISIBLE_DEVICES'):
             os.environ['NVIDIA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES')
 
         # override with environment variables
+        # cuda_version & cudnn_version are overridden with os.environ here, and normalized in the next section
         for config_key, env_config in ENVIRONMENT_CONFIG.items():
+            # check if the property is a list:
+            if config_key.endswith('.0'):
+                if all(not i.get() for i in env_config.values()):
+                    continue
+                parent = config_key.partition('.0')[0]
+                if not self.config[parent]:
+                    self.config.put(parent, [])
+
+                self.config.put(parent, self.config[parent] + [ConfigTree((k, v.get()) for k, v in env_config.items())])
+                continue
+
             value = env_config.get()
             if not value:
                 continue

@@ -115,7 +150,7 @@ class Session(_Session):
             from trains_agent.helper.package.requirements import RequirementsManager
             agent = self.config['agent']
             agent['cuda_version'], agent['cudnn_version'] = \
-                RequirementsManager.get_cuda_version(self.config)
+                RequirementsManager.get_cuda_version(self.config) if not cpu_only else ('0', '0')
         except Exception:
             pass

@@ -133,9 +168,16 @@ class Session(_Session):
         logger.propagate = True
         return TrainsAgentLogger(logger)
 
+    @staticmethod
+    def set_debug_mode(enable):
+        if enable:
+            import logging
+            logging.basicConfig(level=logging.DEBUG)
+        Session.force_debug = enable
+
     @property
     def debug_mode(self):
-        return self.config.get("agent.debug", False)
+        return Session.force_debug or self.config.get("agent.debug", False)
 
     @property
     def config_file(self):

@@ -158,7 +200,11 @@ class Session(_Session):
         folder_keys = ('agent.venvs_dir', 'agent.vcs_cache.path',
                        'agent.pip_download_cache.path',
                        'agent.docker_pip_cache', 'agent.docker_apt_cache')
-        singleton_folders = ('agent.venvs_dir', 'agent.vcs_cache.path',)
+        singleton_folders = ('agent.venvs_dir', 'agent.vcs_cache.path', 'agent.docker_apt_cache')
+
+        if os.environ.get(ENV_TASK_EXECUTE_AS_USER):
+            folder_keys = tuple(list(folder_keys) + ['sdk.storage.cache.default_base_dir'])
+            singleton_folders = tuple(list(singleton_folders) + ['sdk.storage.cache.default_base_dir'])
 
         for key in folder_keys:
             folder_key = ConfigValue(self.config, key)

@@ -199,6 +245,15 @@ class Session(_Session):
         config.pop('env', None)
         if remove_secret_keys:
             recursive_remove_secrets(config, secret_keys=remove_secret_keys)
+        # remove logging.loggers.urllib3.level from the print
+        try:
+            config['logging']['loggers']['urllib3'].pop('level', None)
+        except (KeyError, TypeError, AttributeError):
+            pass
+        try:
+            config['logging'].pop('version', None)
+        except (KeyError, TypeError, AttributeError):
+            pass
         config = ConfigFactory.from_dict(config)
         self.log.debug("Run by interpreter: %s", sys.executable)
         print(

@@ -1 +1 @@
-__version__ = '0.12.0'
+__version__ = '0.16.2'