mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
194 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bdbcfcd8d | ||
|
|
a2e2052b30 | ||
|
|
0146ded4f4 | ||
|
|
dccf9dd8f8 | ||
|
|
7816b402bb | ||
|
|
cd4ce30f7c | ||
|
|
8c7e230898 | ||
|
|
42ba696518 | ||
|
|
3f84e60a1f | ||
|
|
baba8b5b73 | ||
|
|
77397c4f21 | ||
|
|
8678091d8f | ||
|
|
aa22170ab4 | ||
|
|
901ec37290 | ||
|
|
21f2ea8b17 | ||
|
|
8219e3d4e2 | ||
|
|
3ed71a61d5 | ||
|
|
18a88a8e8f | ||
|
|
318a72987c | ||
|
|
5ce202cc99 | ||
|
|
d09528bc26 | ||
|
|
42d2a41dbe | ||
|
|
82be1840b0 | ||
|
|
27352c5cb6 | ||
|
|
1ea6408d41 | ||
|
|
5e095af3aa | ||
|
|
ab3dceed92 | ||
|
|
3bf5126d84 | ||
|
|
ab2ab7b23a | ||
|
|
c9184d125b | ||
|
|
0c0fdb72b9 | ||
|
|
86378053d4 | ||
|
|
b1cbba0cf1 | ||
|
|
f31526042d | ||
|
|
3f8d5bc346 | ||
|
|
11d76e7d8c | ||
|
|
e76c0fbc63 | ||
|
|
fdc9956da3 | ||
|
|
f4addaa653 | ||
|
|
667964cc82 | ||
|
|
e1309e30b7 | ||
|
|
9403942ef7 | ||
|
|
84a75d9e70 | ||
|
|
c85ab66ae6 | ||
|
|
bf7f0f646b | ||
|
|
dcdf2a3d58 | ||
|
|
f8d8fc40a6 | ||
|
|
45d434a123 | ||
|
|
1834abe5bc | ||
|
|
d6321588f3 | ||
|
|
c17b10ff1d | ||
|
|
b125a56f86 | ||
|
|
c43ce3a17b | ||
|
|
b0b09616a8 | ||
|
|
ede5586ccc | ||
|
|
a1dcdffa53 | ||
|
|
35a11db58e | ||
|
|
d9bdebefc7 | ||
|
|
f29884f05a | ||
|
|
0f72d662f8 | ||
|
|
6202219034 | ||
|
|
bb3218f65d | ||
|
|
cbcaa7c789 | ||
|
|
427322a424 | ||
|
|
0e7d7d36a9 | ||
|
|
06032a6d66 | ||
|
|
b48f4eb2eb | ||
|
|
383b2666c4 | ||
|
|
50c373cf0d | ||
|
|
394a9de5fa | ||
|
|
fb5c06e9c3 | ||
|
|
1a9bbc9420 | ||
|
|
294da32401 | ||
|
|
7f00672010 | ||
|
|
99bf89a360 | ||
|
|
6c8508eb7f | ||
|
|
69714d5b5c | ||
|
|
f9516ec7d3 | ||
|
|
6fdde93dee | ||
|
|
7afc71ec91 | ||
|
|
4595117d91 | ||
|
|
8630cc1021 | ||
|
|
135885b609 | ||
|
|
eb0865662c | ||
|
|
b7b94e7ae5 | ||
|
|
72be8bee19 | ||
|
|
0722b20c1c | ||
|
|
a392a0e6ff | ||
|
|
e22fa2f478 | ||
|
|
8b49c1ac06 | ||
|
|
da1182a405 | ||
|
|
53e995ee8c | ||
|
|
4732dc1a88 | ||
|
|
e325bcaf67 | ||
|
|
a7c30453db | ||
|
|
dedac3b2fe | ||
|
|
7d10bbdf8e | ||
|
|
72213dffa4 | ||
|
|
f778837d4b | ||
|
|
153ed6a7b7 | ||
|
|
5d279c8c5a | ||
|
|
ed910d5f6a | ||
|
|
87d2b6fa15 | ||
|
|
94cfb17291 | ||
|
|
3f641d37b7 | ||
|
|
551be12f01 | ||
|
|
b536020058 | ||
|
|
fb6fbc0a06 | ||
|
|
5ae64fd791 | ||
|
|
f9776e4319 | ||
|
|
75e736e7d5 | ||
|
|
1e4756aa1d | ||
|
|
52529d3c55 | ||
|
|
53296e8891 | ||
|
|
1c87ebc900 | ||
|
|
14d9924ea0 | ||
|
|
69f9b424c7 | ||
|
|
1a6da301a8 | ||
|
|
2728b3ed14 | ||
|
|
38284eef1f | ||
|
|
9debe1adcd | ||
|
|
cc93c15f8a | ||
|
|
2c3f0e4ba3 | ||
|
|
c48eb34d8d | ||
|
|
49515e06e1 | ||
|
|
4a1d97c02f | ||
|
|
6c6c1c3f41 | ||
|
|
0ad687008c | ||
|
|
fe3dbc92dc | ||
|
|
dc53970ff0 | ||
|
|
73592b991b | ||
|
|
47b981a993 | ||
|
|
b500bcab0b | ||
|
|
59e910db1a | ||
|
|
2ecb430f02 | ||
|
|
a08722e394 | ||
|
|
67c210d9d7 | ||
|
|
101ba540f4 | ||
|
|
82fc28d477 | ||
|
|
7b73f699d2 | ||
|
|
a7e5380f67 | ||
|
|
bcade31786 | ||
|
|
6b902f85f4 | ||
|
|
6d4c974045 | ||
|
|
2346c6f3f5 | ||
|
|
82e51b4d36 | ||
|
|
e63599254e | ||
|
|
8e7e234161 | ||
|
|
17d94b26c3 | ||
|
|
1e701becd3 | ||
|
|
18c8dd449d | ||
|
|
50031c4d6d | ||
|
|
6101dc4f11 | ||
|
|
5d17059cbe | ||
|
|
b93e843143 | ||
|
|
1a732ccd8e | ||
|
|
2ea25e498f | ||
|
|
1b1cdb34ad | ||
|
|
e171a8b523 | ||
|
|
539b76d362 | ||
|
|
64b5e1f1f0 | ||
|
|
6a1eb9cea0 | ||
|
|
24907b4eaa | ||
|
|
efc540b837 | ||
|
|
96ffc89c64 | ||
|
|
4f2564d33a | ||
|
|
70ae090cc0 | ||
|
|
4f01778961 | ||
|
|
596bdd06ec | ||
|
|
6c56d0fc33 | ||
|
|
5f0213d2de | ||
|
|
15eb00a931 | ||
|
|
becc4fb6a2 | ||
|
|
32476a216a | ||
|
|
a9ba1580dc | ||
|
|
cfcd0b22a0 | ||
|
|
780355250c | ||
|
|
fd65ad38bc | ||
|
|
e29973a0b2 | ||
|
|
c259d0883e | ||
|
|
9eab017a31 | ||
|
|
68c7f307a2 | ||
|
|
0aa5694b58 | ||
|
|
639d72c5d6 | ||
|
|
70708ecdcc | ||
|
|
dacdd5e965 | ||
|
|
c199976f70 | ||
|
|
c3e2bc5ad7 | ||
|
|
f0c900c174 | ||
|
|
1bdbc44720 | ||
|
|
c6e765bd07 | ||
|
|
c037ddd044 | ||
|
|
ffe4764f20 | ||
|
|
1681fd6bf4 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
syntax: glob
|
||||
.idea
|
||||
apierrors/errors
|
||||
static/build.json
|
||||
@@ -18,3 +19,4 @@ build
|
||||
dist
|
||||
code.tar.gz
|
||||
server/schema/services/_cache.json
|
||||
server/apierrors/errors/*
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,7 +1,7 @@
|
||||
Server Side Public License
|
||||
VERSION 1, OCTOBER 16, 2018
|
||||
|
||||
Copyright © 2018 MongoDB, Inc.
|
||||
Copyright © 2019 allegro.ai, Inc.
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
437
README.md
437
README.md
@@ -1,317 +1,216 @@
|
||||
# TRAINS Server
|
||||
# Trains Server
|
||||
|
||||
## Auto-Magical Experiment Manager & Version Control for AI
|
||||
## Auto-Magical Experiment Manager & Version Control for AI - ε Devops Included!
|
||||
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://img.shields.io/badge/status-beta-yellow.svg)
|
||||
|
||||
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
|
||||
|
||||
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
|
||||
|
||||
## Introduction
|
||||
|
||||
The **trains-server** is the backend service infrastructure for [TRAINS](https://github.com/allegroai/trains).
|
||||
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
By default, TRAINS is set up to work with the TRAINS demo server, which is open to anyone and resets periodically.
|
||||
In order to host your own server, you will need to install **trains-server** and point TRAINS to it.
|
||||
By default, **Trains** is set up to work with the **Trains** demo server, which is open to anyone and resets periodically.
|
||||
In order to host your own server, you will need to launch **trains-server** and point **Trains** to it.
|
||||
|
||||
**trains-server** contains the following components:
|
||||
|
||||
* The TRAINS Web-App, a single-page UI for experiment management and browsing
|
||||
* The **Trains** Web-App, a single-page UI for experiment management and browsing
|
||||
* RESTful API for:
|
||||
* Documenting and logging experiment information, statistics and results
|
||||
* Querying experiments history, logs and results
|
||||
* Locally-hosted file server for storing images and models making them easily accessible using the Web-App
|
||||
|
||||
You can quickly setup your **trains-server** using a pre-built Docker image (see [Installation](#installation)).
|
||||
You can quickly [deploy](#launching-trains-server) your **trains-server** using Docker, AWS EC2 AMI, or Kubernetes.
|
||||
|
||||
When new releases are available, you can upgrade your pre-built Docker image (see [Upgrade](#upgrade)).
|
||||
|
||||
## System diagram
|
||||
## System design
|
||||
|
||||
|
||||

|
||||
|
||||
**trains-server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
- Web application on port 8080
|
||||
- API service on port 8008
|
||||
- File storage service on port 8081
|
||||
|
||||
## Install / Upgrade - AWS
|
||||
|
||||
Use our pre-installed Amazon Machine Image for easy deployment in AWS.
|
||||
|
||||
Details and instructions can be found [here](docs/install_aws.md).
|
||||
|
||||
## Installation - Docker
|
||||
|
||||
This section contains the instructions to setup and launch a pre-built Docker image for the **trains-server**.
|
||||
This is the quickest way to get started with your own server.
|
||||
Alternatively, you can build the entire trains-server architecture using the code available in our repositories.
|
||||
|
||||
**Please Note**:
|
||||
* This Docker image was tested with Linux, only. For Windows users, we recommend running the server
|
||||
on a Linux virtual machine.
|
||||
|
||||
* All command-line instructions below assume you're using `bash`.
|
||||
- Sub-Domain configuration with default http/s ports (80 or 443)
|
||||
- Web application on sub-domain: app.\*.\*
|
||||
- API service on sub-domain: api.\*.\*
|
||||
- File storage service on sub-domain: files.\*.\*
|
||||
|
||||
## Launching trains-server
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Make sure you are logged in as a user with sudo privileges.
|
||||
The ports 8080/8081/8008 must be available for the **trains-server** services.
|
||||
|
||||
For example, to see if port `8080` is in use:
|
||||
|
||||
### Setup
|
||||
* Linux or macOS:
|
||||
|
||||
sudo lsof -Pn -i4 | grep :8080 | grep LISTEN
|
||||
|
||||
#### Step 1: Install Docker CE
|
||||
* Windows:
|
||||
|
||||
In order to run the pre-packaged **trains-server**, install Docker.
|
||||
|
||||
* See [Supported platforms](https://docs.docker.com/install//#support) in the Docker documentation for instructions
|
||||
|
||||
* For example, to install in [Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/) / Mint (x86_64/amd64):
|
||||
|
||||
```bash
|
||||
sudo apt-get install -y apt-transport-https ca-certificates curl software-properties-common
|
||||
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
|
||||
. /etc/os-release
|
||||
sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $UBUNTU_CODENAME stable"
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y docker-ce
|
||||
```
|
||||
|
||||
#### Step 2: Setup the Docker daemon
|
||||
|
||||
To run the ElasticSearch Docker container, setup the Docker daemon by modifying the default
|
||||
values required by Elastic in your Docker configuration file (see [Notes for production use and defaults](https://www.elastic.co/guide/en/elasticsearch/reference/master/docker.html#_notes_for_production_use_and_defaults)). We provide instructions for the most common Docker configuration files.
|
||||
|
||||
Edit or create the Docker configuration file:
|
||||
|
||||
* If your system contains a `/etc/sysconfig/docker` Docker configuration file, edit it.
|
||||
|
||||
Add the options in quotes to the available arguments in the `OPTIONS` section:
|
||||
|
||||
```bash
|
||||
OPTIONS="--default-ulimit nofile=1024:65536 --default-ulimit memlock=-1:-1"
|
||||
```
|
||||
|
||||
* Otherwise, edit `/etc/docker/daemon.json` (if it exists) or create it (if it does not exist).
|
||||
|
||||
Add or modify the `defaults-ulimits` section as shown below. Be sure the `defaults-ulimits` section contains the `nofile` and `memlock` sub-sections and values shown.
|
||||
|
||||
**Note**: Your configuration file may contain other sections. If so, confirm that the sections are separated by commas (valid JSON format). For more information about Docker configuration files, see [Daemon configuration file](https://docs.docker.com/engine/reference/commandline/dockerd/#daemon-configuration-file) in the Docker documentation.
|
||||
|
||||
The **trains-server** required defaults values are:
|
||||
|
||||
```json
|
||||
{
|
||||
"default-ulimits": {
|
||||
"nofile": {
|
||||
"name": "nofile",
|
||||
"hard": 65536,
|
||||
"soft": 1024
|
||||
},
|
||||
"memlock":
|
||||
{
|
||||
"name": "memlock",
|
||||
"soft": -1,
|
||||
"hard": -1
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Step 3: Restart the Docker daemon
|
||||
|
||||
After modifying the configuration file, restart the Docker daemon:
|
||||
|
||||
```bash
|
||||
sudo service docker stop
|
||||
sudo service docker start
|
||||
```
|
||||
|
||||
#### Step 4: Set the Maximum Number of Memory Map Areas
|
||||
|
||||
The maximum number of memory map areas a process can use is defined
|
||||
using the `vm.max_map_count` kernel setting.
|
||||
|
||||
Elastic requires that `vm.max_map_count` is at least 262144 (see [Production mode](https://www.elastic.co/guide/en/elasticsearch/reference/master/docker.html#docker-cli-run-prod-mode)).
|
||||
|
||||
* For CentOS 7, Ubuntu 16.04, Mint 18.3, Ubuntu 18.04 and Mint 19 users, we tested the following commands to set
|
||||
`vm.max_map_count`:
|
||||
|
||||
```bash
|
||||
sudo echo "vm.max_map_count=262144" > /tmp/99-trains.conf
|
||||
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
|
||||
sudo sysctl -w vm.max_map_count=262144
|
||||
```
|
||||
|
||||
* For information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
|
||||
|
||||
#### Step 5: Choose a Data Directory
|
||||
|
||||
Choose a directory on your system in which all data maintained by the **trains-server** is stored.
|
||||
Create this directory, and set its owner and group to `uid` 1000. The data stored in this directory will include the database, uploaded files and logs.
|
||||
|
||||
For example, if your data directory is `/opt/trains`, then use the following command:
|
||||
|
||||
```bash
|
||||
sudo mkdir -p /opt/trains/data/elastic && sudo chown -R 1000:1000 /opt/trains
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
The **trains-server** default configuration can be easily overridden using external configuration files. By default, the server will look for these files in `/opt/trains/config`.
|
||||
|
||||
If the configuration is changed while the server is running, the server should be restarted for changes to take effect.
|
||||
|
||||
|
||||
#### Fixed users mode (basic users management)
|
||||
|
||||
In this mode, the server authenticates users based only on a pre-configured users list.
|
||||
|
||||
Enable this feature by placing `apiserver.conf` file under `/opt/trains/config`.
|
||||
|
||||
**Notice**: In order for the changes to take effect, the *trains-apiserver* (docker) will need to be restarted.
|
||||
|
||||
Sample fixed user configuration file `/opt/trains/config/apiserver.conf`:
|
||||
|
||||
auth {
|
||||
# Fixed users login credetials
|
||||
# No other user will be able to login
|
||||
fixed_users {
|
||||
enabled: true
|
||||
users: [
|
||||
{
|
||||
username: "jane"
|
||||
password: "12345678"
|
||||
name: "Jane Doe"
|
||||
},
|
||||
{
|
||||
username: "john"
|
||||
password: "12345678"
|
||||
name: "John Doe"
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#### Non-responsive experiments watchdog
|
||||
|
||||
This watchdog monitors experiments that were not updated for a given period of time, and marks them as `stopped`. The watchdog is always active.
|
||||
|
||||
To change the watchdog's timeouts, place a `services.conf` file under `/opt/trains/config`, containing for example:
|
||||
|
||||
tasks {
|
||||
non_responsive_tasks_watchdog {
|
||||
# In-progress tasks that haven't been updated for at least 'value' seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
# Watchdog will sleep for this number of seconds after each cycle
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
}
|
||||
|
||||
### Launching Docker Containers
|
||||
|
||||
**Note**:
|
||||
* If your data directory is not `/opt/trains`, please find and replace `/opt/trains` in the following commands with your data directory path
|
||||
|
||||
* Make sure ports `8008`, `8080` and `8081` are not in use before starting the docker containers, as the containers will fail to initialize if these ports are already taken. If the following commands shows no output, the ports are available:
|
||||
```bash
|
||||
sudo netstat -tplna | egrep "8008|8080|8081"
|
||||
```
|
||||
|
||||
To launch the Docker containers, use the following commands:
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-elastic" -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/config:/opt/trains/config allegroai/trains:latest apiserver
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-webserver" -p 8080:80 allegroai/trains:latest webserver
|
||||
```
|
||||
|
||||
After the **trains-server** Dockers are up, the following are available:
|
||||
|
||||
* API server on port `8008`
|
||||
* Web server on port `8080`
|
||||
* File server on port `8081`
|
||||
|
||||
### Configuring **TRAINS** client
|
||||
|
||||
Once you have installed the **trains-server**, make sure to configure **TRAINS** [client](https://github.com/allegroai/trains) to use your locally installed server (and not the demo server).
|
||||
|
||||
If you have already installed **TRAINS**, run the `trains-init` command for an interactive setup or edit your `trains.conf` file and make sure the `api.host` value is configured as follows:
|
||||
|
||||
api {
|
||||
host: "http://localhost:8008"
|
||||
}
|
||||
|
||||
See [Installing and Configuring TRAINS](https://github.com/allegroai/trains#installing-and-configuring-trains) for more details.
|
||||
|
||||
## What next?
|
||||
|
||||
Now that the **trains-server** is installed, and TRAINS is configured to use it,
|
||||
you can [use](https://github.com/allegroai/trains#using-trains) TRAINS in your experiments and view them in the web server,
|
||||
for example http://localhost:8080
|
||||
|
||||
## Upgrade
|
||||
|
||||
We are constantly updating, improving and adding to the **trains-server**.
|
||||
New releases will include new pre-built Docker images.
|
||||
When we release a new version and include a new pre-built Docker image for it, upgrade as follows:
|
||||
|
||||
1. Shut down and remove each of your Docker instances using the following commands:
|
||||
|
||||
sudo docker stop <docker-name>
|
||||
sudo docker rm -v <docker-name>
|
||||
|
||||
The Docker names are (see [Launching Docker Containers](#launching-docker-containers)):
|
||||
|
||||
* `trains-elastic`
|
||||
* `trains-mongo`
|
||||
* `trains-fileserver`
|
||||
* `trains-apiserver`
|
||||
* `trains-webserver`
|
||||
|
||||
2. Pull the new **trains-server** docker image using the following command:
|
||||
|
||||
sudo docker pull allegroai/trains:latest
|
||||
netstat -an |find /i "8080"
|
||||
|
||||
### Launching
|
||||
|
||||
If you wish to pull a different version, replace `latest` with the required version number, for example:
|
||||
Launch **trains-server** in any of the following formats:
|
||||
|
||||
sudo docker pull allegroai/trains:0.10.0
|
||||
|
||||
3. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
|
||||
For example, if your data directory is `/opt/trains`, use the following command:
|
||||
## Connecting Trains to your trains-server
|
||||
|
||||
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
|
||||
By default, the **Trains** client is set up to work with the [**Trains** demo server](https://demoapp.trains.allegro.ai/).
|
||||
To have the **Trains** client use your **trains-server** instead:
|
||||
- Run the `trains-init` command for an interactive setup.
|
||||
- Or manually edit `~/trains.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
This back ups all data to an archive in your home directory.
|
||||
api {
|
||||
# API server on port 8008
|
||||
api_server: "http://localhost:8008"
|
||||
|
||||
To restore this example backup, use the following command:
|
||||
# web_server on port 8080
|
||||
web_server: "http://localhost:8080"
|
||||
|
||||
sudo rm -R /opt/trains/data
|
||||
sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
|
||||
# file server on port 8081
|
||||
files_server: "http://localhost:8081"
|
||||
}
|
||||
|
||||
4. Launch the newly released Docker image (see [Launching Docker Containers](#launching-docker-containers)).
|
||||
**Note**: If you have set up **trains-server** in a sub-domain configuration, then there is no need to specify a port number,
|
||||
it will be inferred from the http/s scheme.
|
||||
|
||||
After launching the **trains-server** and configuring the **Trains** client to use the **trains-server**,
|
||||
you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in your experiments and view them in your **trains-server** web server,
|
||||
for example http://localhost:8080.
|
||||
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).
|
||||
|
||||
## Trains-Agent Services <a name="services"></a>
|
||||
|
||||
As of version 0.15 of **trains-server**, dockerized deployment includes a **Trains-Agent Services** container running as
|
||||
part of the docker container collection.
|
||||
|
||||
Trains-Agent Services is an extension of Trains-Agent that provides the ability to launch long-lasting jobs
|
||||
that previously had to be executed on local / dedicated machines. It allows a single agent to
|
||||
launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
|
||||
when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
|
||||
Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
|
||||
increased data transparency)
|
||||
|
||||
Trains-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
|
||||
Every task launched by Trains-Agent Services will be registered as a new node in the system,
|
||||
providing tracking and transparency capabilities.
|
||||
You can also run the Trains-Agent Services manually, see details in [trains-agent services mode](https://github.com/allegroai/trains-agent#trains-agent-services-mode-)
|
||||
|
||||
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
|
||||
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
|
||||
|
||||
## Advanced Functionality
|
||||
|
||||
**trains-server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://allegro.ai/docs/faq/faq/#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/docs/faq/faq/#watchdog)
|
||||
|
||||
## Restarting trains-server
|
||||
|
||||
To restart the **trains-server**, you must first stop the containers, and then restart them.
|
||||
|
||||
```bash
|
||||
docker-compose down
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
## Upgrading <a name="upgrade"></a>
|
||||
|
||||
**trains-server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker-compose.yml).
|
||||
We strongly encourage you to keep your **trains-server** up to date, by keeping up with the current release.
|
||||
|
||||
**Note**: The following upgrade instructions use the Linux OS as an example.
|
||||
|
||||
To upgrade your existing **trains-server** deployment:
|
||||
|
||||
1. Shut down the docker containers
|
||||
```bash
|
||||
docker-compose down
|
||||
```
|
||||
|
||||
1. We highly recommend backing up your data directory before upgrading.
|
||||
|
||||
Assuming your data directory is `/opt/trains`, to archive all data into `~/trains_backup.tgz` execute:
|
||||
|
||||
```bash
|
||||
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Restore instructions:</summary>
|
||||
|
||||
To restore this example backup, execute:
|
||||
```bash
|
||||
sudo rm -R /opt/trains/data
|
||||
sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
|
||||
```
|
||||
</details>
|
||||
|
||||
1. Download the latest `docker-compose.yml` file.
|
||||
|
||||
```bash
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
```
|
||||
|
||||
1. Configure the Trains-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, Trains-Agent Services will use the external
|
||||
public address of the **trains-server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
the Trains-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **trains-server** build
|
||||
```bash
|
||||
docker-compose -f docker-compose.yml pull
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/docs/faq/faq/#common-docker-upgrade-errors).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the Trains [FAQ](https://allegro.ai/docs/faq/faq/), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
|
||||
Additionally, you can always find us at *trains@allegro.ai*
|
||||
|
||||
## License
|
||||
|
||||
[Server Side Public License v1.0](https://github.com/mongodb/mongo/blob/master/LICENSE-Community.txt)
|
||||
|
||||
**trains-server** relies on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
|
||||
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our responsibility as a
|
||||
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our responsibility as a
|
||||
member of the community to support the projects we love and cherish.
|
||||
We believe the cause for the license change in both cases is more than just,
|
||||
We believe the cause for the license change in both cases is more than just,
|
||||
and chose [SSPL](https://www.mongodb.com/licensing/server-side-public-license) because it is the more general and flexible of the two licenses.
|
||||
|
||||
This is our way to say - we support you guys!
|
||||
|
||||
88
docker-compose-unified.yml
Normal file
88
docker-compose-unified.yml
Normal file
@@ -0,0 +1,88 @@
|
||||
version: "3.6"
|
||||
services:
|
||||
trainsserver:
|
||||
command:
|
||||
- -c
|
||||
- "echo \"#!/bin/bash\" > /opt/trains/all.sh && echo \"/opt/trains/wrapper.sh webserver&\" >> /opt/trains/all.sh && echo \"/opt/trains/wrapper.sh fileserver&\" >> /opt/trains/all.sh && echo \"/opt/trains/wrapper.sh apiserver\" >> /opt/trains/all.sh && cat /opt/trains/all.sh && chmod +x /opt/trains/all.sh && /opt/trains/all.sh"
|
||||
entrypoint: /bin/bash
|
||||
container_name: trains-server
|
||||
image: allegroai/trains:latest
|
||||
ports:
|
||||
- 8008:8008
|
||||
- 8080:80
|
||||
- 8081:8081
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
- elasticsearch
|
||||
environment:
|
||||
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
TRAINS_ELASTIC_SERVICE_PORT: 9200
|
||||
TRAINS_MONGODB_SERVICE_HOST: mongo
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
networks:
|
||||
- backend
|
||||
elasticsearch:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
bootstrap.memory_lock: "true"
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
memlock:
|
||||
soft: -1
|
||||
hard: -1
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
mongo:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-mongo
|
||||
image: mongo:3.6.5
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
- /opt/trains/data/mongo/db:/data/db
|
||||
- /opt/trains/data/mongo/configdb:/data/configdb
|
||||
ports:
|
||||
- "27017:27017"
|
||||
redis:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-redis
|
||||
image: redis:5.0
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/redis:/data
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
119
docker-compose-win10.yml
Normal file
119
docker-compose-win10.yml
Normal file
@@ -0,0 +1,119 @@
|
||||
version: "3.6"
|
||||
services:
|
||||
|
||||
apiserver:
|
||||
command:
|
||||
- apiserver
|
||||
container_name: trains-apiserver
|
||||
image: allegroai/trains:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/logs:/var/log/trains
|
||||
- c:/opt/trains/config:/opt/trains/config
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
- elasticsearch
|
||||
- fileserver
|
||||
environment:
|
||||
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
TRAINS_ELASTIC_SERVICE_PORT: 9200
|
||||
TRAINS_MONGODB_SERVICE_HOST: mongo
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
- backend
|
||||
|
||||
elasticsearch:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
bootstrap.memory_lock: "true"
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
memlock:
|
||||
soft: -1
|
||||
hard: -1
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
fileserver:
|
||||
networks:
|
||||
- backend
|
||||
command:
|
||||
- fileserver
|
||||
container_name: trains-fileserver
|
||||
image: allegroai/trains:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/logs:/var/log/trains
|
||||
- c:/opt/trains/data/fileserver:/mnt/fileserver
|
||||
- c:/opt/trains/config:/opt/trains/config
|
||||
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
mongo:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-mongo
|
||||
image: mongo:3.6.5
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
- c:/opt/trains/data/mongo/db:/data/db
|
||||
- c:/opt/trains/data/mongo/configdb:/data/configdb
|
||||
ports:
|
||||
- "27017:27017"
|
||||
|
||||
redis:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-redis
|
||||
image: redis:5.0
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/data/redis:/data
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
webserver:
|
||||
command:
|
||||
- webserver
|
||||
container_name: trains-webserver
|
||||
image: allegroai/trains:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/trains/logs:/var/log/trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
ports:
|
||||
- "8080:80"
|
||||
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
@@ -1,17 +1,40 @@
|
||||
version: "3.6"
|
||||
services:
|
||||
|
||||
apiserver:
|
||||
command:
|
||||
- apiserver
|
||||
container_name: trains-apiserver
|
||||
image: allegroai/trains:latest
|
||||
network_mode: host
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- type: bind
|
||||
source: /opt/trains/logs
|
||||
target: /var/log/trains
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
- elasticsearch
|
||||
- fileserver
|
||||
environment:
|
||||
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
TRAINS_ELASTIC_SERVICE_PORT: 9200
|
||||
TRAINS_MONGODB_SERVICE_HOST: mongo
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
|
||||
TRAINS__apiserver__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
|
||||
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
- backend
|
||||
|
||||
elasticsearch:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
@@ -19,58 +42,109 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
network_mode: host
|
||||
restart: always
|
||||
ulimits:
|
||||
memlock:
|
||||
soft: -1
|
||||
hard: -1
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- type: bind
|
||||
source: /opt/trains/data/elastic
|
||||
target: /usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
fileserver:
|
||||
networks:
|
||||
- backend
|
||||
command:
|
||||
- fileserver
|
||||
container_name: trains-fileserver
|
||||
image: allegroai/trains:latest
|
||||
network_mode: host
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- type: bind
|
||||
source: /opt/trains/logs
|
||||
target: /var/log/trains
|
||||
- type: bind
|
||||
source: /opt/trains/data/fileserver
|
||||
target: /mnt/fileserver
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
mongo:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-mongo
|
||||
image: mongo:3.6.5
|
||||
network_mode: host
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
- type: bind
|
||||
source: /opt/trains/data/mongo/db
|
||||
target: /data/db
|
||||
- type: bind
|
||||
source: /opt/trains/data/mongo/configdb
|
||||
target: /data/configdb
|
||||
- /opt/trains/data/mongo/db:/data/db
|
||||
- /opt/trains/data/mongo/configdb:/data/configdb
|
||||
ports:
|
||||
- "27017:27017"
|
||||
|
||||
redis:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-redis
|
||||
image: redis:5.0
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/redis:/data
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
webserver:
|
||||
command:
|
||||
- webserver
|
||||
container_name: trains-webserver
|
||||
image: allegroai/trains:latest
|
||||
network_mode: host
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- apiserver
|
||||
ports:
|
||||
- "8080:80"
|
||||
|
||||
agent-services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-agent-services
|
||||
image: allegroai/trains-agent-services:latest
|
||||
restart: unless-stopped
|
||||
privileged: true
|
||||
environment:
|
||||
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
|
||||
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
|
||||
TRAINS_API_HOST: http://apiserver:8008
|
||||
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
|
||||
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
|
||||
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
|
||||
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
|
||||
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
|
||||
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
|
||||
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
|
||||
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
|
||||
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
|
||||
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
|
||||
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
|
||||
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
|
||||
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
|
||||
TRAINS_WORKER_ID: "trains-services"
|
||||
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
|
||||
volumes:
|
||||
- type: bind
|
||||
source: /opt/trains/logs
|
||||
target: /var/log/trains
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /opt/trains/agent:/root/.trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
19
docs/apiserver.conf
Normal file
19
docs/apiserver.conf
Normal file
@@ -0,0 +1,19 @@
|
||||
auth {
|
||||
# Fixed users login credentials
|
||||
# No other user will be able to login
|
||||
fixed_users {
|
||||
enabled: true
|
||||
users: [
|
||||
{
|
||||
username: "jane"
|
||||
password: "12345678"
|
||||
name: "Jane Doe"
|
||||
},
|
||||
{
|
||||
username: "john"
|
||||
password: "12345678"
|
||||
name: "John Doe"
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
328
docs/faq.md
Normal file
328
docs/faq.md
Normal file
@@ -0,0 +1,328 @@
|
||||
# trains-server FAQ
|
||||
|
||||
Launching **trains-server**
|
||||
|
||||
* How do I launch **trains-server** on:
|
||||
|
||||
* [Stand alone Linux Ubuntu systems?](#ubuntu)
|
||||
|
||||
* [macOS?](#mac-osx)
|
||||
|
||||
* [Windows 10?](#docker_compose_win10)
|
||||
|
||||
* [How do I restart trains-server?](#restart)
|
||||
|
||||
Kubernetes
|
||||
|
||||
* [Can I deploy trains-server on Kubernetes clusters?](#kubernetes)
|
||||
|
||||
* [Can I create a Helm Chart for trains-server Kubernetes deployment?](#helm)
|
||||
|
||||
Configuration
|
||||
|
||||
* [How do I configure trains-server for sub-domains and load balancers?](#sub-domains)
|
||||
|
||||
* [Can I add web login authentication to trains-server?](#web-auth)
|
||||
|
||||
* [Can I modify the non-responsive experiment watchdog settings?](#watchdog)
|
||||
|
||||
Troubleshooting
|
||||
|
||||
* [How do I fix Docker upgrade errors?](#common-docker-upgrade-errors)
|
||||
|
||||
* [Why is web login authentication not working?](#port-conflict)
|
||||
|
||||
## Launching **trains-server**
|
||||
|
||||
### How do I launch trains-server on stand alone Linux Ubuntu systems? <a name="ubuntu"></a>
|
||||
|
||||
To launch **trains-server** on a stand alone Linux Ubuntu:
|
||||
|
||||
1. Install [docker for Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/).
|
||||
|
||||
1. Install `docker-compose` using the following commands (for more detailed information, see the [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
|
||||
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
|
||||
sudo chmod +x /usr/local/bin/docker-compose
|
||||
|
||||
1. Remove the previous installation of **trains-server**.
|
||||
|
||||
**WARNING**: This clears all existing **Trains** databases.
|
||||
|
||||
sudo rm -R /opt/trains/
|
||||
|
||||
1. Create local directories for the databases and storage.
|
||||
|
||||
sudo mkdir -p /opt/trains/data/elastic
|
||||
sudo mkdir -p /opt/trains/data/mongo/db
|
||||
sudo mkdir -p /opt/trains/data/mongo/configdb
|
||||
sudo mkdir -p /opt/trains/logs
|
||||
sudo mkdir -p /opt/trains/config
|
||||
sudo mkdir -p /opt/trains/data/fileserver
|
||||
sudo chown -R 1000:1000 /opt/trains
|
||||
|
||||
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
|
||||
|
||||
git clone https://github.com/allegroai/trains-server.git
|
||||
cd trains-server
|
||||
|
||||
1. Run `docker-compose`
|
||||
|
||||
/usr/local/bin/docker-compose -f docker-compose.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080)
|
||||
|
||||
### How do I launch trains-server on macOS? <a name="mac-osx"></a>
|
||||
|
||||
To launch **trains-server** on macOS:
|
||||
|
||||
1. Install [docker for macOS](https://docs.docker.com/docker-for-mac/install/).
|
||||
|
||||
1. Configure [Docker](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode).
|
||||
|
||||
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
|
||||
sysctl -w vm.max_map_count=262144
|
||||
|
||||
1. Create local directories for the databases and storage.
|
||||
|
||||
sudo mkdir -p /opt/trains/data/elastic
|
||||
sudo mkdir -p /opt/trains/data/mongo/db
|
||||
sudo mkdir -p /opt/trains/data/mongo/configdb
|
||||
sudo mkdir -p /opt/trains/data/redis
|
||||
sudo mkdir -p /opt/trains/logs
|
||||
sudo mkdir -p /opt/trains/config
|
||||
sudo mkdir -p /opt/trains/data/fileserver
|
||||
sudo chown -R $(whoami):staff /opt/trains
|
||||
|
||||
1. Open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
|
||||
|
||||
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
|
||||
|
||||
git clone https://github.com/allegroai/trains-server.git
|
||||
cd trains-server
|
||||
|
||||
1. Run `docker-compose` with the docker compose file.
|
||||
|
||||
docker-compose -f docker-compose.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080)
|
||||
|
||||
### How do I launch trains-server on Windows 10? <a name="docker_compose_win10"></a>
|
||||
|
||||
You can run **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).
|
||||
|
||||
To launch **trains-server** on Windows 10:
|
||||
|
||||
1. Install the Docker Desktop for Windows application by either:
|
||||
|
||||
* following the [Install Docker Desktop on Windows](https://docs.docker.com/docker-for-windows/install/) instructions.
|
||||
* running the Docker installation [wizard](https://hub.docker.com/?overlay=onboarding).
|
||||
|
||||
1. Increase the memory allocation in Docker Desktop to `4GB`.
|
||||
|
||||
1. In your Windows notification area (system tray), right click the Docker icon.
|
||||
|
||||
1. Click *Settings*, *Advanced*, and then set the memory to at least `4096`.
|
||||
|
||||
1. Click *Apply*.
|
||||
|
||||
1. Create local directories for data and logs. Open PowerShell and execute the following commands:
|
||||
|
||||
cd c:
|
||||
mkdir c:\opt\trains\data
|
||||
mkdir c:\opt\trains\logs
|
||||
|
||||
1. Download the **trains-server** docker-compose YAML file [docker-compose-win10.yml](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose-win10.yml) as `c:\opt\trains\docker-compose.yml`.
|
||||
|
||||
1. Run `docker-compose`. In PowerShell, execute the following commands:
|
||||
|
||||
docker-compose -f up docker-compose-win10.yml
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080)
|
||||
|
||||
### How do I restart trains-server? <a name="restart"></a>
|
||||
|
||||
Restart *trains-server* by first stopping the Docker containers and then restarting them.
|
||||
|
||||
```bash
|
||||
docker-compose down
|
||||
docker-compose up -f docker-compose.yml
|
||||
```
|
||||
|
||||
**Note**: If you are using a different docker-compose YAML file, specify that file.
|
||||
|
||||
## Kubernetes
|
||||
|
||||
### Can I deploy trains-server on Kubernetes clusters? <a name="kubernetes"></a>
|
||||
|
||||
**trains-server** supports Kubernetes. See [trains-server-k8s](https://github.com/allegroai/trains-server-k8s)
|
||||
which contains the YAML files describing the required services and detailed instructions for deploying
|
||||
**trains-server** to a Kubernetes clusters.
|
||||
|
||||
### Can I create a Helm Chart for trains-server Kubernetes deployment? <a name="helm"></a>
|
||||
|
||||
**trains-server** supports creating a Helm chart for Kubernetes deployment. See [trains-server-helm](https://github.com/allegroai/trains-server-helm)
|
||||
which you can use to create a Helm chart for **trains-server** and contains detailed instructions for deploying
|
||||
**trains-server** to a Kubernetes clusters using Helm.
|
||||
|
||||
## Configuration
|
||||
|
||||
### How do I configure trains-server for sub-domains and load balancers? <a name="sub-domains"></a>
|
||||
|
||||
You can configure **trains-server** for sub-domains and a load balancer.
|
||||
|
||||
For example, if your domain is `trains.mydomain.com` and your sub-domains are `app` and `api`, then do the following:
|
||||
|
||||
1. If you are not using the current **trains-server** version, [upgrade](https://github.com/allegroai/trains-server#upgrade) **trains-server**.
|
||||
|
||||
1. Add the following to `/opt/trains/config/apiserver.conf`:
|
||||
|
||||
auth {
|
||||
cookies {
|
||||
httponly: true
|
||||
secure: true
|
||||
domain: ".trains.mydomain.com"
|
||||
max_age: 99999999999
|
||||
}
|
||||
}
|
||||
|
||||
1. Use the following load balancer configuration:
|
||||
|
||||
* Listeners:
|
||||
* Optional: HTTP listener, that redirects all traffic to HTTPS.
|
||||
* HTTPS listener for `app.` forwarded to `AppTargetGroup`
|
||||
* HTTPS listener for `api.` forwarded to `ApiTargetGroup`
|
||||
* HTTPS listener for `files.` forwarded to `FilesTargetGroup`
|
||||
* Target groups:
|
||||
* `AppTargetGroup`: HTTP based target group, port `8080`
|
||||
* `ApiTargetGroup`: HTTP based target group, port `8008`
|
||||
* `FilesTargetGroup`: HTTP based target group, port `8081`
|
||||
* Security and routing:
|
||||
* Load balancer: make sure the load balancers are able to receive traffic from the relevant IP addresses (Security groups and Subnets definitions).
|
||||
* Instances: make sure the load balancers are able to access the instances, using the relevant ports (Security groups definitions).
|
||||
|
||||
1. Run the Docker containers with our updated `docker run` commands (see [Launching Docker Containers](#https://github.com/allegroai/trains-server#launching-docker-containers)).
|
||||
|
||||
### Can I add web login authentication to trains-server? <a name="web-auth"></a>
|
||||
|
||||
By default, anyone can login to the **trains-server** Web-App.
|
||||
You can configure the **trains-server** to allow only a specific set of users to access the system.
|
||||
|
||||
To add web login authentication to **trains-server**:
|
||||
|
||||
1. If you are not using the current **trains-server** version, then [upgrade](https://github.com/allegroai/trains-server#upgrade).
|
||||
|
||||
1. In `/opt/trains/config/apiserver.conf`, add the `auth` section and in it specify the users, for example:
|
||||
|
||||
**Note**: A sample `apiserver.conf` configuration file is also available [here](https://github.com/allegroai/trains-server/blob/master/docs/apiserver.conf).
|
||||
|
||||
auth {
|
||||
# Fixed users login credentials
|
||||
# No other user will be able to login
|
||||
fixed_users {
|
||||
enabled: true
|
||||
users: [
|
||||
{
|
||||
username: "jane"
|
||||
password: "12345678"
|
||||
name: "Jane Doe"
|
||||
},
|
||||
{
|
||||
username: "john"
|
||||
password: "12345678"
|
||||
name: "John Doe"
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
1. Restart **trains-server** (see the [Restarting trains-server](#restart) FAQ).
|
||||
|
||||
### Can I modify the experiment watchdog settings? <a name="watchdog"></a>
|
||||
|
||||
The non-responsive experiment watchdog monitors experiments that were not updated for a specified period of time
|
||||
and marks them as `aborted`. The watchdog is always active.
|
||||
|
||||
You can modify the following settings for the watchdog:
|
||||
|
||||
* the time threshold (in seconds) of experiment inactivity (default value is 7200 seconds (2 hours))
|
||||
* the time interval (in seconds) between watchdog cycles
|
||||
|
||||
To change the watchdog's settings:
|
||||
|
||||
1. In `/opt/trains/config`, add the `services.conf` file and in it specify the watchdog settings, for example:
|
||||
|
||||
**Note**: A sample watchdog `services.conf` configuration file is also available [here](https://github.com/allegroai/trains-server/blob/master/docs/services.conf).
|
||||
|
||||
tasks {
|
||||
non_responsive_tasks_watchdog {
|
||||
# In-progress tasks that haven't been updated for at least 'value' seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
# Watchdog will sleep for this number of seconds after each cycle
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
}
|
||||
|
||||
1. Restart **trains-server** (see the [Restarting trains-server](#restart) FAQ).
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### How do I fix Docker upgrade errors? <a name="common-docker-upgrade-errors"></a>
|
||||
|
||||
To resolve the Docker error "... The container name "/trains-???" is already in use by ...", try removing deprecated images:
|
||||
|
||||
docker rm -f $(docker ps -a -q)
|
||||
|
||||
### Why is web login authentication not working?
|
||||
|
||||
A port conflict between the **trains-server** MongoDB and / or Elastic instances, and other
|
||||
instances running on your system may prevent web login authentication
|
||||
from working correctly.
|
||||
|
||||
**trains-server** uses the following default ports which may be in conflict with other instances:
|
||||
|
||||
* MongoDB port `27017`
|
||||
* Elastic port `9200`
|
||||
|
||||
You can check for port conflicts in the logs in `/opt/trains/log`.
|
||||
|
||||
If a port conflict occurs, change the MongoDB and / or Elastic ports in the `docker-compose.yml`,
|
||||
and then run the Docker compose commands to restart the **trains-server** instance.
|
||||
|
||||
To change the MongoDB and / or Elastic ports for **trains-server**:
|
||||
|
||||
1. Edit the `docker-compose.yml` file.
|
||||
|
||||
1. In the `services/trainsserver/environment` section, add the following environment variable(s):
|
||||
|
||||
* For MongoDB:
|
||||
|
||||
MONGODB_SERVICE_PORT: <new-mongodb-port>
|
||||
|
||||
* For Elastic:
|
||||
|
||||
ELASTIC_SERVICE_PORT: <new-elasticsearch-port>
|
||||
|
||||
For example:
|
||||
|
||||
MONGODB_SERVICE_PORT: 27018
|
||||
ELASTIC_SERVICE_PORT: 9201
|
||||
|
||||
1. For MongoDB, in the `services/mongo/ports` section, expose the new MongoDB port:
|
||||
|
||||
<new-mongodb-port>:27017
|
||||
|
||||
For example:
|
||||
|
||||
20718:27017
|
||||
|
||||
1. For Elastic, in the `services/elasticsearch/ports` section, expose the new Elastic port:
|
||||
|
||||
<new-elsticsearch-port>:9200
|
||||
|
||||
For example:
|
||||
|
||||
9201:9200
|
||||
|
||||
2. Restart **trains-server** (see the [Restarting trains-server](#restart) FAQ).
|
||||
@@ -1,49 +1,265 @@
|
||||
# **TRAINS-server**: AWS pre-installed images
|
||||
# Deploying **trains-server** on AWS
|
||||
|
||||
In order to easily deploy **trains-server** on AWS, we created the following Amazon Machine Images (AMIs).
|
||||
To easily deploy **trains-server** on AWS, use one of our pre-built Amazon Machine Images (AMIs).
|
||||
We provide AMIs per region for each released version of **trains-server**, see [Released versions](#released-versions) below.
|
||||
|
||||
Service port numbers on these AMIs are:
|
||||
- Web: 8080
|
||||
- API: 8008
|
||||
- File Server: 8081
|
||||
Once the AMI is up and running, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
The service port numbers on our **trains-server** AMIs:
|
||||
|
||||
Persistent storage configuration:
|
||||
- MongoDB: /opt/trains/data/mongo/
|
||||
- ElasticSearch: /opt/trains/data/elastic/
|
||||
- File Server: /mnt/fileserver/
|
||||
- Web application: `8080`
|
||||
- API Server: `8008`
|
||||
- File Server: `8081`
|
||||
|
||||
Instructions on launching a custom AMI from the EC2 console can be found [here](https://aws.amazon.com/premiumsupport/knowledge-center/launch-instance-custom-ami/)
|
||||
and a detailed version [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/launching-instance.html).
|
||||
The persistent storage configuration:
|
||||
|
||||
The minimum recommended instance type is **t3a.large**
|
||||
- MongoDB: `/opt/trains/data/mongo/`
|
||||
- ElasticSearch: `/opt/trains/data/elastic/`
|
||||
- File Server: `/mnt/fileserver/`
|
||||
|
||||
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
|
||||
|
||||
For instructions on launching a custom AMI from the EC2 console, see the [AWS Knowledge Center](https://aws.amazon.com/premiumsupport/knowledge-center/launch-instance-custom-ami/) or detailed instructions in the [AWS Documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/launching-instance.html).
|
||||
|
||||
The minimum recommended amount of RAM is 8GB. For example, **t3.large** or **t3a.large** would have the minimum recommended amount of resources.
|
||||
|
||||
## Upgrading
|
||||
|
||||
In order to upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
To upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
### Note on upgrading AMIs to v0.12
|
||||
|
||||
This upgrade includes the automatically updated AMI in Version 0.12. It also includes an additional REDIS docker to the **trains-server** setup.
|
||||
|
||||
To upgrade the AMI:
|
||||
|
||||
1. SSH to the EC2 machine running one of the `Latest Version AMI's`
|
||||
2. Execute the following bash commands
|
||||
```bash
|
||||
sudo bash
|
||||
echo "" >> /usr/bin/start_or_update_server.sh
|
||||
echo "sudo mkdir -p \${datadir}/redis" >> /usr/bin/start_or_update_server.sh
|
||||
echo "sudo docker stop trains-redis || true && sudo docker rm -v trains-redis || true" >> /usr/bin/start_or_update_server.sh
|
||||
echo "echo never | sudo tee -a /sys/kernel/mm/transparent_hugepage/enabled" >> /usr/bin/start_or_update_server.sh
|
||||
echo "sudo sysctl vm.overcommit_memory=1" >> /usr/bin/start_or_update_server.sh
|
||||
echo "sudo docker run -d --restart=always --name=trains-redis -v \${datadir}/redis:/data --network=host redis:5 redis-server" >> /usr/bin/start_or_update_server.sh
|
||||
```
|
||||
3. Reboot the EC2 machine
|
||||
|
||||
|
||||
## Released versions
|
||||
|
||||
The following sections provide a list containing AMI Image ID per region for each released **trains-server** version.
|
||||
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
|
||||
|
||||
### v0.10.1
|
||||
* **eu-north-1** : ami-075179c6226cd9095
|
||||
* **ap-south-1** : ami-060f4ed126053eb04
|
||||
* **eu-west-3** : ami-064db17c89563ab4d
|
||||
* **eu-west-2** : ami-03fd7ed12335a28a7
|
||||
* **eu-west-1** : ami-0ccc70c4a240ad561
|
||||
* **ap-northeast-2** : ami-0a396c6a9aa957f5a
|
||||
* **ap-northeast-1** : ami-03e2f80d6f507dd01
|
||||
* **sa-east-1** : ami-05860b0f84623709f
|
||||
* **ca-central-1** : ami-09070d5996ab859bd
|
||||
* **ap-southeast-1** : ami-0f6837004376bcf17
|
||||
* **ap-southeast-2** : ami-0fedf36691cdbe1ae
|
||||
* **eu-central-1** : ami-075ffc9f0c89fe67f
|
||||
* **us-east-2** : ami-0e762227086ba49e4
|
||||
* **us-west-1** : ami-06845a6c0f1c9bd13
|
||||
* **us-west-2** : ami-03ddd7af1dcab9bf9
|
||||
* **us-east-1** : ami-0e599e219060f7def
|
||||
### Latest version AMI - v0.15.1 (auto update)<a name="autoupdate"></a>
|
||||
|
||||
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
|
||||
|
||||
* **eu-north-1** : ami-0f30c84b905d354b9
|
||||
* **ap-south-1** : ami-050e7acec52c8c74e
|
||||
* **eu-west-3** : ami-03911c5b5bc77ef75
|
||||
* **eu-west-2** : ami-0a5ed8aa2573ccc70
|
||||
* **eu-west-1** : ami-0a53c65e922ec0611
|
||||
* **ap-northeast-2** : ami-08cd017a37b8e8aab
|
||||
* **ap-northeast-1** : ami-056b3ca1ad5af9322
|
||||
* **sa-east-1** : ami-01ddc9325bafb400c
|
||||
* **ca-central-1** : ami-0fc3cbbd982b18b45
|
||||
* **ap-southeast-1** : ami-04c7a358df7002ef5
|
||||
* **ap-southeast-2** : ami-0eeaf54231b4ae22a
|
||||
* **eu-central-1** : ami-00b8e44041f8175fd
|
||||
* **us-east-2** : ami-0ac7deebb3f738f6d
|
||||
* **us-west-1** : ami-06bc07deb8b8c44d6
|
||||
* **us-west-2** : ami-01ba85ffe79a422f1
|
||||
* **us-east-1** : ami-04cf5a66cb4928ac3
|
||||
|
||||
### v0.15.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0cd314e267426d1b7
|
||||
* **ap-south-1** : ami-086182cbe29151f96
|
||||
* **eu-west-3** : ami-0062366012182815b
|
||||
* **eu-west-2** : ami-022b8f2e32a9d18d0
|
||||
* **eu-west-1** : ami-0d8cf60446e09aa3d
|
||||
* **ap-northeast-2** : ami-0d4c168a815b56889
|
||||
* **ap-northeast-1** : ami-0daf7887db1053ae4
|
||||
* **sa-east-1** : ami-020a759a3ba4ff22b
|
||||
* **ca-central-1** : ami-0c10b5e04b707f3e3
|
||||
* **ap-southeast-1** : ami-0f61bb3529a165fcd
|
||||
* **ap-southeast-2** : ami-032dcdc82749c66c5
|
||||
* **eu-central-1** : ami-08f364f32d2eb3bae
|
||||
* **us-east-2** : ami-0b7efc3591803eba4
|
||||
* **us-west-1** : ami-08b2df27b0ada6faf
|
||||
* **us-west-2** : ami-0693029c4bad28816
|
||||
* **us-east-1** : ami-0200954fa9c2819ff
|
||||
|
||||
### v0.15.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0bef15c03eab64c0c
|
||||
* **ap-south-1** : ami-06ac6248e583e2cd2
|
||||
* **eu-west-3** : ami-0541d86ef47a5714e
|
||||
* **eu-west-2** : ami-01381ef4c4ed22482
|
||||
* **eu-west-1** : ami-064626a0dd38b21f1
|
||||
* **ap-northeast-2** : ami-0a2490a7a3a8aa675
|
||||
* **ap-northeast-1** : ami-063f1de819a2524b8
|
||||
* **sa-east-1** : ami-07980486741b94987
|
||||
* **ca-central-1** : ami-0ced3b8b21ded839e
|
||||
* **ap-southeast-1** : ami-0c493c5093fde8741
|
||||
* **ap-southeast-2** : ami-0320a727eccb8dc6c
|
||||
* **eu-central-1** : ami-0aa85cfc78674c526
|
||||
* **us-east-2** : ami-01791485051e1880c
|
||||
* **us-west-1** : ami-0d8eade4d5888ea73
|
||||
* **us-west-2** : ami-02ceaef72cdf60f7e
|
||||
* **us-east-1** : ami-0fc3f9d1d0eba1d62
|
||||
|
||||
### v0.14.2 (static update)
|
||||
|
||||
* **eu-north-1** : ami-006d491e9e8869248
|
||||
* **ap-south-1** : ami-0e55ec221687f98e7
|
||||
* **eu-west-3** : ami-06ad9cf3c05c83e91
|
||||
* **eu-west-2** : ami-0d05839268e748cff
|
||||
* **eu-west-1** : ami-0d14c297789ce0d7a
|
||||
* **ap-northeast-2** : ami-0d7fd775f0e76cc6f
|
||||
* **ap-northeast-1** : ami-0c0a6e1daeb3f7a9c
|
||||
* **sa-east-1** : ami-01e0c5e30e94ec887
|
||||
* **ca-central-1** : ami-07a31896832734897
|
||||
* **ap-southeast-1** : ami-0886d5b2d4b7fccd5
|
||||
* **ap-southeast-2** : ami-0397d5a2db3c356fe
|
||||
* **eu-central-1** : ami-0629f26eea22f5c17
|
||||
* **us-east-2** : ami-0499c3d7bb45a1a6e
|
||||
* **us-west-1** : ami-02fa8a961a4daf9f0
|
||||
* **us-west-2** : ami-05c711cfab4342468
|
||||
* **us-east-1** : ami-0b97d99a08012c726
|
||||
|
||||
### v0.14.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-036defe1885dced2e
|
||||
* **ap-south-1** : ami-0b403aa1da6a5dc17
|
||||
* **eu-west-3** : ami-0d30c2d330d1255c4
|
||||
* **eu-west-2** : ami-06f0e8d075e50a029
|
||||
* **eu-west-1** : ami-0da721d874f282b6d
|
||||
* **ap-northeast-2** : ami-03bffe94675dd5f8c
|
||||
* **ap-northeast-1** : ami-0f96520d646423673
|
||||
* **sa-east-1** : ami-0c2f706a3b7d97282
|
||||
* **ca-central-1** : ami-0da74525dcfd74e32
|
||||
* **ap-southeast-1** : ami-066368a21cf6d232b
|
||||
* **ap-southeast-2** : ami-0bfd09170067f7318
|
||||
* **eu-central-1** : ami-06aa99b1c41492986
|
||||
* **us-east-2** : ami-065c1880f59d03272
|
||||
* **us-west-1** : ami-0b7f6b896f5058eba
|
||||
* **us-west-2** : ami-0041e10ca68eef29a
|
||||
* **us-east-1** : ami-0b7125e4305bbd7eb
|
||||
|
||||
### v0.14.0 (static update)
|
||||
* **eu-north-1** : ami-02de71586ec496e38
|
||||
* **ap-south-1** : ami-074b03849b51852e5
|
||||
* **eu-west-3** : ami-022c388835e0eeb03
|
||||
* **eu-west-2** : ami-0a151c236c6b27707
|
||||
* **eu-west-1** : ami-06de69b06b4e73312
|
||||
* **ap-northeast-2** : ami-0ee821b72d9f669b1
|
||||
* **ap-northeast-1** : ami-03687ae215e64e100
|
||||
* **sa-east-1** : ami-01eb83364b7f667af
|
||||
* **ca-central-1** : ami-02e9b35f9c90377e6
|
||||
* **ap-southeast-1** : ami-0d3ab5ab0048fea51
|
||||
* **ap-southeast-2** : ami-0bd39d908fe3a9e06
|
||||
* **eu-central-1** : ami-0b8638701311b35c4
|
||||
* **us-east-2** : ami-02ff039693fc3a614
|
||||
* **us-west-1** : ami-08634f7dfb608a9a7
|
||||
* **us-west-2** : ami-034d693ef742b9333
|
||||
* **us-east-1** : ami-0b828b05c323dde7f
|
||||
|
||||
### v0.13.0 (static update)
|
||||
* **eu-north-1** : ami-0d9c74a015e7510d8
|
||||
* **ap-south-1** : ami-02acd6dd0659bb5c1
|
||||
* **eu-west-3** : ami-0f0cc5cb6d9afd194
|
||||
* **eu-west-2** : ami-0298fdc0860206ed9
|
||||
* **eu-west-1** : ami-0cdc072e528401d5e
|
||||
* **ap-northeast-2** : ami-0055579cc95b0e53e
|
||||
* **ap-northeast-1** : ami-0ced7becb9b83b5d0
|
||||
* **sa-east-1** : ami-033345d0f16a1b5e4
|
||||
* **ca-central-1** : ami-06c63b05aed47ae67
|
||||
* **ap-southeast-1** : ami-09f0355f367f30602
|
||||
* **ap-southeast-2** : ami-0bd2314163ce0fba0
|
||||
* **eu-central-1** : ami-05fbae957df63e366
|
||||
* **us-east-2** : ami-050c51b5b4074d3fc
|
||||
* **us-west-1** : ami-06ad513073d4e5a19
|
||||
* **us-west-2** : ami-0c96e1361d1d4ca94
|
||||
* **us-east-1** : ami-07b669040d1eea213
|
||||
|
||||
### v0.12.1 (static update)
|
||||
* **eu-north-1** : ami-003118a8103286d84
|
||||
* **ap-south-1** : ami-02dfe86baa48e096f
|
||||
* **eu-west-3** : ami-0cc1f01267d2a780d
|
||||
* **eu-west-2** : ami-0e4c8332e5ce09585
|
||||
* **eu-west-1** : ami-03459a2f0b0a3b1ab
|
||||
* **ap-northeast-2** : ami-08f6c2aed3a53f24c
|
||||
* **ap-northeast-1** : ami-0b798eab95a7c5435
|
||||
* **sa-east-1** : ami-0d3ee166c09f0d1b2
|
||||
* **ca-central-1** : ami-00a758c56bd63acd5
|
||||
* **ap-southeast-1** : ami-0be64d4988cd03fbb
|
||||
* **ap-southeast-2** : ami-02087310d43a63f31
|
||||
* **eu-central-1** : ami-097bbefeac0c74225
|
||||
* **us-east-2** : ami-07eda256712b90f4d
|
||||
* **us-west-1** : ami-02ef2b55cbd01c7df
|
||||
* **us-west-2** : ami-037c6176ef4735360
|
||||
* **us-east-1** : ami-08715c20c0e3f1c15
|
||||
|
||||
### v0.12.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-03ff8ab48cd43e77e
|
||||
* **ap-south-1** : ami-079c1a41ff836487c
|
||||
* **eu-west-3** : ami-0121ef0398ae87ab0
|
||||
* **eu-west-2** : ami-09f0f97654d8c79de
|
||||
* **eu-west-1** : ami-0b7ba303f757bfcd9
|
||||
* **ap-northeast-2** : ami-053f416517b5f40a6
|
||||
* **ap-northeast-1** : ami-056dff06c698c2d9d
|
||||
* **sa-east-1** : ami-017ab655119258639
|
||||
* **ca-central-1** : ami-03bf5fa1d86ac97f6
|
||||
* **ap-southeast-1** : ami-0e667958002b0360c
|
||||
* **ap-southeast-2** : ami-091f1b69cb43b1933
|
||||
* **eu-central-1** : ami-068ec2f0e98c26541
|
||||
* **us-east-2** : ami-0524bbdc1b64ff83f
|
||||
* **us-west-1** : ami-0b4facd7534e393c9
|
||||
* **us-west-2** : ami-0018d5a7e58966848
|
||||
* **us-east-1** : ami-08f24178fc14a84d2
|
||||
|
||||
### v0.11.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0cbe338f058018c97
|
||||
* **ap-south-1** : ami-06d72ff894f7a5e5d
|
||||
* **eu-west-3** : ami-00f2a45d67df2d2f3
|
||||
* **eu-west-2** : ami-0627ae688f4533237
|
||||
* **eu-west-1** : ami-00bf924ccb0354418
|
||||
* **ap-northeast-2** : ami-0800edf1d1dec1da8
|
||||
* **ap-northeast-1** : ami-07b2ed9709cdc4b15
|
||||
* **sa-east-1** : ami-0012c1648618b812c
|
||||
* **ca-central-1** : ami-02870b965d002fc8a
|
||||
* **ap-southeast-1** : ami-068ec23abf2473192
|
||||
* **ap-southeast-2** : ami-06664624728b5e01a
|
||||
* **eu-central-1** : ami-05f2a9304f237a6f0
|
||||
* **us-east-2** : ami-0ec242e6dca2b72b9
|
||||
* **us-west-1** : ami-050b6577acf246ceb
|
||||
* **us-west-2** : ami-0e384b6f78bf96ebe
|
||||
* **us-east-1** : ami-0a7b46f907d5d9c4a
|
||||
|
||||
### v0.10.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-09937ec4d18350c32
|
||||
* **ap-south-1** : ami-089d6ba7541ec4c7f
|
||||
* **eu-west-3** : ami-0accb1a94bdd5c5c1
|
||||
* **eu-west-2** : ami-0dd2c97bc678b8570
|
||||
* **eu-west-1** : ami-07a38865cbe7ca3cb
|
||||
* **ap-northeast-2** : ami-09aa0b7fe1cf3dd55
|
||||
* **ap-northeast-1** : ami-0905e7d1543e5ed36
|
||||
* **sa-east-1** : ami-08c0627daa67d7372
|
||||
* **ca-central-1** : ami-034add081712ff648
|
||||
* **ap-southeast-1** : ami-0c6caee3689b6e066
|
||||
* **ap-southeast-2** : ami-04994afd8dae5b417
|
||||
* **eu-central-1** : ami-06b10f8c30e1434f1
|
||||
* **us-east-2** : ami-0d3abe7a1fec535cc
|
||||
* **us-west-1** : ami-02bb610b70c55018b
|
||||
* **us-west-2** : ami-0d1cb8ba7de246ff0
|
||||
* **us-east-1** : ami-049ccba6abdb40cba
|
||||
|
||||
### v0.10.0 (static update)
|
||||
|
||||
### v0.10.0
|
||||
* **eu-north-1** : ami-05ba33c763877e54e
|
||||
* **ap-south-1** : ami-0529eec569161cae5
|
||||
* **eu-west-3** : ami-03cb9396f63e26ff6
|
||||
@@ -62,7 +278,7 @@ The following sections provide a list containing AMI Image ID per region for eac
|
||||
* **us-west-2** : ami-04a522ecb2250fb44
|
||||
* **us-east-1** : ami-0a66ddbd50959f91e
|
||||
|
||||
### v0.9.0
|
||||
### v0.9.0 (static update)
|
||||
|
||||
* **us-east-1** : ami-0991ad536ecbacdac
|
||||
* **eu-north-1** : ami-07cbcdff501b14afe
|
||||
@@ -80,3 +296,4 @@ The following sections provide a list containing AMI Image ID per region for eac
|
||||
* **us-east-2** : ami-03b01914b07428488
|
||||
* **us-west-1** : ami-0cf4768e9d47ed076
|
||||
* **us-west-2** : ami-0b145f37da31eb9fb
|
||||
|
||||
|
||||
76
docs/install_gcp.md
Normal file
76
docs/install_gcp.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Deploying Trains Server on Google Cloud Platform
|
||||
|
||||
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
|
||||
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
|
||||
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
|
||||
#### Default Trains Server Service ports
|
||||
The service port numbers on our Trains Server GCP Custom Image are:
|
||||
|
||||
- Web application: `8080`
|
||||
- API Server: `8008`
|
||||
- File Server: `8081`
|
||||
|
||||
#### Default Trains Server Storage paths
|
||||
The persistent storage configuration:
|
||||
|
||||
- MongoDB: `/opt/trains/data/mongo/`
|
||||
- ElasticSearch: `/opt/trains/data/elastic/`
|
||||
- File Server: `/mnt/fileserver/`
|
||||
|
||||
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
|
||||
|
||||
## Importing the Custom Image to your GCP account
|
||||
|
||||
In order to launch an instance using the Trains Server GCP Custom Image, you'll need to import the image to your custom images list.
|
||||
|
||||
**Note:** there's **no need** to upload the image file to Google Cloud Storage - we already provide links to image files stored in Google Storage
|
||||
|
||||
To import the image to your custom images list:
|
||||
1. In the Cloud Console, go to the [Images](https://console.cloud.google.com/compute/images) page.
|
||||
1. At the top of the page, click **Create image**.
|
||||
1. In the **Name** field, specify a unique name for the image.
|
||||
1. Optionally, specify an image family for your new image, or configure specific encryption settings for the image.
|
||||
1. Click the **Source** menu and select **Cloud Storage file**.
|
||||
1. Enter the Trains Server image bucket path (see [Trains Server GCP Custom Image](#released-versions)), for example:
|
||||
`allegro-files/trains-server/trains-server.tar.gz`
|
||||
1. Click the **Create** button to import the image. The process can take several minutes depending on the size of the boot disk image.
|
||||
|
||||
For more information see [Import the image to your custom images list](https://cloud.google.com/compute/docs/import/import-existing-image#import_image) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
|
||||
|
||||
## Launching an instance with a Custom Image
|
||||
|
||||
For instructions on launching an instance using a GCP Custom Image, see the [Manually importing virtual disks](https://cloud.google.com/compute/docs/import/import-existing-image#overview) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
|
||||
For more information on Custom Images, see [Custom Images](https://cloud.google.com/compute/docs/images#custom_images) in the Compute Engine Documentation.
|
||||
|
||||
The minimum recommended requirements for Trains Server are:
|
||||
- 2 vCPUs
|
||||
- 7.5GB RAM
|
||||
|
||||
## Upgrading
|
||||
|
||||
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
## Network and Security
|
||||
|
||||
Please make sure your instance is properly secured.
|
||||
|
||||
If not specifically set, a GCP instance will use default firewall rules that allow public access to various ports.
|
||||
If your instance is open for public access, we recommend you follow best practices for access management, including:
|
||||
- Allow access only to the specific ports used by Trains Server (see [Default Trains Server Service ports](#default-trains-server-service-ports)). Remember to allow access to port `443` if `https` access is configured for your instance.
|
||||
- Configure Trains Server to use fixed user names and passwords (see [Can I add web login authentication to trains-server?](./faq.md#web-auth))
|
||||
|
||||
## Released versions
|
||||
|
||||
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
|
||||
|
||||
### Latest version image
|
||||
|
||||
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz
|
||||
|
||||
### All released images
|
||||
|
||||
- v0.15.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-1.tar.gz
|
||||
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
|
||||
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz
|
||||
97
docs/install_linux_mac.md
Normal file
97
docs/install_linux_mac.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# Launching the **trains-server** Docker in Linux or macOS
|
||||
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
|
||||
For Linux users:
|
||||
|
||||
* You must be logged in as a user with sudo privileges.
|
||||
* Use `bash` for all command-line instructions in this installation.
|
||||
|
||||
To launch **trains-server** on Linux or macOS:
|
||||
|
||||
1. Install Docker.
|
||||
|
||||
* Linux - see [Docker for Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/).
|
||||
* macOS - see [Docker for macOS](https://docs.docker.com/docker-for-mac/install/).
|
||||
|
||||
1. Verify the Docker CE installation. Execute the command:
|
||||
|
||||
docker run hello-world
|
||||
|
||||
The expected is output is:
|
||||
|
||||
Hello from Docker!
|
||||
This message shows that your installation appears to be working correctly.
|
||||
To generate this message, Docker took the following steps:
|
||||
|
||||
1. The Docker client contacted the Docker daemon.
|
||||
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
|
||||
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
|
||||
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
|
||||
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
|
||||
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
|
||||
sudo chmod +x /usr/local/bin/docker-compose
|
||||
|
||||
1. Increase `vm.max_map_count` for ElasticSearch docker.
|
||||
|
||||
Linux:
|
||||
|
||||
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
|
||||
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
|
||||
sudo sysctl -w vm.max_map_count=262144
|
||||
sudo service docker restart
|
||||
|
||||
macOS:
|
||||
|
||||
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
|
||||
sysctl -w vm.max_map_count=262144
|
||||
|
||||
|
||||
1. Remove any previous installation of **trains-server**.
|
||||
|
||||
**WARNING**: This clears all existing **Trains** databases.
|
||||
|
||||
sudo rm -R /opt/trains/
|
||||
|
||||
1. Create local directories for the databases and storage.
|
||||
|
||||
sudo mkdir -p /opt/trains/data/elastic
|
||||
sudo mkdir -p /opt/trains/data/mongo/db
|
||||
sudo mkdir -p /opt/trains/data/mongo/configdb
|
||||
sudo mkdir -p /opt/trains/data/redis
|
||||
sudo mkdir -p /opt/trains/logs
|
||||
sudo mkdir -p /opt/trains/config
|
||||
sudo mkdir -p /opt/trains/data/fileserver
|
||||
|
||||
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
|
||||
|
||||
1. Grant access to the Dockers.
|
||||
|
||||
Linux:
|
||||
|
||||
sudo chown -R 1000:1000 /opt/trains
|
||||
|
||||
macOS:
|
||||
|
||||
sudo chown -R $(whoami):staff /opt/trains
|
||||
|
||||
1. Download the **trains-server** docker-compose YAML file.
|
||||
|
||||
cd /opt/trains
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
|
||||
1. Run `docker-compose` with the downloaded configuration file.
|
||||
|
||||
docker-compose -f docker-compose.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
|
||||
|
||||
* Web server on port `8080`
|
||||
* API server on port `8008`
|
||||
* File server on port `8081`
|
||||
|
||||
## Next Step
|
||||
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
50
docs/install_win.md
Normal file
50
docs/install_win.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Launching the **trains-server** Docker in Windows 10
|
||||
|
||||
For Windows, we recommend launching our pre-built Docker image on a Linux virtual machine.
|
||||
However, you can launch **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).
|
||||
|
||||
To launch **trains-server** on Windows 10:
|
||||
|
||||
1. Install the Docker Desktop for Windows application by either:
|
||||
|
||||
* Following the [Install Docker Desktop on Windows](https://docs.docker.com/docker-for-windows/install/) instructions.
|
||||
* Running the Docker installation [wizard](https://hub.docker.com/?overlay=onboarding).
|
||||
|
||||
1. Increase the memory allocation in Docker Desktop to `4GB`.
|
||||
|
||||
1. In your Windows notification area (system tray), right click the Docker icon.
|
||||
|
||||
1. Click *Settings*, *Advanced*, and then set the memory to at least `4096`.
|
||||
|
||||
1. Click *Apply*.
|
||||
|
||||
1. Remove any previous installation of **trains-server**.
|
||||
|
||||
**WARNING**: This clears all existing **Trains** databases.
|
||||
|
||||
rmdir c:\opt\trains /s
|
||||
|
||||
1. Create local directories for data and logs. Open PowerShell and execute the following commands:
|
||||
|
||||
cd c:
|
||||
mkdir c:\opt\trains\data
|
||||
mkdir c:\opt\trains\logs
|
||||
|
||||
1. Save the **trains-server** docker-compose YAML file.
|
||||
|
||||
cd c:\opt\trains
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose-win10.yml -o docker-compose-win10.yml
|
||||
|
||||
1. Run `docker-compose`. In PowerShell, execute the following commands:
|
||||
|
||||
docker-compose -f docker-compose-win10.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
|
||||
|
||||
* Web server on port `8080`
|
||||
* API server on port `8008`
|
||||
* File server on port `8081`
|
||||
|
||||
## Next Step
|
||||
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
9
docs/services.conf
Normal file
9
docs/services.conf
Normal file
@@ -0,0 +1,9 @@
|
||||
tasks {
|
||||
non_responsive_tasks_watchdog {
|
||||
# In-progress tasks that haven't been updated for at least 'value' seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
# Watchdog will sleep for this number of seconds after each cycle
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
Server Side Public License
|
||||
VERSION 1, OCTOBER 16, 2018
|
||||
|
||||
Copyright © 2018 MongoDB, Inc.
|
||||
Copyright © 2019 allegro.ai, Inc.
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
@@ -46,7 +50,23 @@ class BasicConfig:
|
||||
path = ".".join((self.prefix, Path(name).stem))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_env_paths(self, key):
|
||||
@staticmethod
|
||||
def _read_extra_env_config_values():
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _read_env_paths(key):
|
||||
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
|
||||
if value is None:
|
||||
return
|
||||
@@ -64,12 +84,17 @@ class BasicConfig:
|
||||
|
||||
def _load(self, verbose=True):
|
||||
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
configs = [
|
||||
self._read_recursive(path, verbose=verbose)
|
||||
for path in [self.folder] + extra_config_paths
|
||||
]
|
||||
|
||||
self._config = reduce(
|
||||
lambda config, path: ConfigTree.merge_configs(
|
||||
config, self._read_recursive(path, verbose=verbose), copy_trees=True
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
[self.folder] + extra_config_paths,
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
download {
|
||||
# Add response headers requesting no caching for served files
|
||||
disable_browser_caching: false
|
||||
|
||||
# Cache timeout to be set for downloaded files
|
||||
cache_timeout_sec: 300
|
||||
}
|
||||
|
||||
cors {
|
||||
|
||||
@@ -10,10 +10,15 @@ from flask_cors import CORS
|
||||
|
||||
from config import config
|
||||
|
||||
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get("fileserver.download.cache_timeout_sec", 5 * 60)
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def upload():
|
||||
@@ -54,12 +59,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--upload-folder",
|
||||
"-u",
|
||||
default="/mnt/fileserver",
|
||||
default=DEFAULT_UPLOAD_FOLDER,
|
||||
help="Upload folder (default %(default)s)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
if app.config.get("UPLOAD_FOLDER") is None:
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
|
||||
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
Server Side Public License
|
||||
VERSION 1, OCTOBER 16, 2018
|
||||
|
||||
Copyright © 2018 MongoDB, Inc.
|
||||
Copyright © 2019 allegro.ai, Inc.
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
1
server/api_version.py
Normal file
1
server/api_version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "2.9.0"
|
||||
@@ -47,7 +47,7 @@ _error_codes = {
|
||||
128: ('invalid_task_output', 'invalid task output'),
|
||||
129: ('task_publish_in_progress', 'Task publish in progress'),
|
||||
130: ('task_not_found', 'task not found'),
|
||||
|
||||
131: ('events_not_added', 'events not added'),
|
||||
|
||||
# Models
|
||||
200: ('model_error', 'general task error'),
|
||||
@@ -70,9 +70,28 @@ _error_codes = {
|
||||
403: ('project_not_found', 'project not found'),
|
||||
405: ('project_has_models', 'project has associated models'),
|
||||
|
||||
# Queues
|
||||
701: ('invalid_queue_id', 'invalid queue id'),
|
||||
702: ('queue_not_empty', 'queue is not empty'),
|
||||
703: ('invalid_queue_or_task_not_queued', 'invalid queue id or task not in queue'),
|
||||
704: ('removed_during_reposition', 'task was removed by another party during reposition'),
|
||||
705: ('failed_adding_during_reposition', 'failed adding task back to queue during reposition'),
|
||||
706: ('task_already_queued', 'failed adding task to queue since task is already queued'),
|
||||
707: ('no_default_queue', 'no queue is tagged as the default queue for this company'),
|
||||
708: ('multiple_default_queues', 'more than one queue is tagged as the default queue for this company'),
|
||||
|
||||
# Database
|
||||
800: ('data_validation_error', 'data validation error'),
|
||||
801: ('expected_unique_data', 'value combination already exists'),
|
||||
|
||||
# Workers
|
||||
1001: ('invalid_worker_id', 'invalid worker id'),
|
||||
1002: ('worker_registration_failed', 'worker registration failed'),
|
||||
1003: ('worker_registered', 'worker is already registered'),
|
||||
1004: ('worker_not_registered', 'worker is not registered'),
|
||||
1005: ('worker_stats_not_found', 'worker stats not found'),
|
||||
|
||||
1104: ('invalid_scroll_id', 'Invalid scroll id'),
|
||||
},
|
||||
|
||||
(401, 'unauthorized'): {
|
||||
@@ -83,12 +102,12 @@ _error_codes = {
|
||||
21: ('bad_credentials', 'unauthorized (malformed credentials)'),
|
||||
22: ('invalid_credentials', 'unauthorized (invalid credentials)'),
|
||||
30: ('invalid_token', 'invalid token'),
|
||||
31: ('blocked_token', 'token is blocked')
|
||||
31: ('blocked_token', 'token is blocked'),
|
||||
40: ('invalid_fixed_user', 'fixed user ID was not found')
|
||||
},
|
||||
|
||||
(403, 'forbidden'): {
|
||||
10: ('routing_error', 'forbidden (routing error)'),
|
||||
11: ('missing_routing_header', 'forbidden (missing routing header)'),
|
||||
12: ('blocked_internal_endpoint', 'forbidden (blocked internal endpoint)'),
|
||||
20: ('role_not_allowed', 'forbidden (not allowed for role)'),
|
||||
21: ('no_write_permission', 'forbidden (modification not allowed)'),
|
||||
@@ -104,6 +123,7 @@ _error_codes = {
|
||||
100: ('data_error', 'general data error'),
|
||||
101: ('inconsistent_data', 'inconsistent data encountered in document'),
|
||||
102: ('database_unavailable', 'database is temporarily unavailable'),
|
||||
110: ('update_failed', 'update failed'),
|
||||
|
||||
# Index-related issues
|
||||
201: ('missing_index', 'missing internal index'),
|
||||
|
||||
@@ -4,16 +4,16 @@ from enum import Enum
|
||||
from typing import Union, Type, Iterable
|
||||
|
||||
import jsonmodels.errors
|
||||
import jsonmodels.validators
|
||||
import six
|
||||
import validators
|
||||
from jsonmodels import fields
|
||||
from jsonmodels.fields import _LazyType
|
||||
from jsonmodels.fields import _LazyType, NotSet
|
||||
from jsonmodels.models import Base as ModelBase
|
||||
from jsonmodels.validators import Enum as EnumValidator
|
||||
from luqum.parser import parser, ParseError
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
from apierrors import errors
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
@@ -25,6 +25,12 @@ def make_default(field_cls, default_value):
|
||||
|
||||
|
||||
class ListField(fields.ListField):
|
||||
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
|
||||
if default is not NotSet and callable(default):
|
||||
default = default()
|
||||
|
||||
super(ListField, self).__init__(items_types, *args, default=default, **kwargs)
|
||||
|
||||
def _cast_value(self, value):
|
||||
try:
|
||||
return super(ListField, self)._cast_value(value)
|
||||
@@ -61,9 +67,7 @@ class DictField(fields.BaseField):
|
||||
value_types = tuple()
|
||||
|
||||
return tuple(
|
||||
_LazyType(type_)
|
||||
if isinstance(type_, six.string_types)
|
||||
else type_
|
||||
_LazyType(type_) if isinstance(type_, six.string_types) else type_
|
||||
for type_ in value_types
|
||||
)
|
||||
|
||||
@@ -73,6 +77,9 @@ class DictField(fields.BaseField):
|
||||
if not self.value_types:
|
||||
return
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
for item in value.values():
|
||||
self.validate_single_value(item)
|
||||
|
||||
@@ -99,7 +106,7 @@ class IntField(fields.IntField):
|
||||
|
||||
|
||||
def validate_lucene_query(value):
|
||||
if value == '':
|
||||
if value == "":
|
||||
return
|
||||
try:
|
||||
parser.parse(value)
|
||||
@@ -117,6 +124,7 @@ class LuceneQueryField(fields.StringField):
|
||||
|
||||
class NullableEnumValidator(EnumValidator):
|
||||
"""Validator for enums that allows a None value."""
|
||||
|
||||
def validate(self, value):
|
||||
if value is not None:
|
||||
super(NullableEnumValidator, self).validate(value)
|
||||
@@ -144,12 +152,49 @@ class EnumField(fields.StringField):
|
||||
return super().parse_value(value)
|
||||
|
||||
|
||||
class ActualEnumField(fields.StringField):
|
||||
def __init__(
|
||||
self,
|
||||
enum_class: Type[Enum],
|
||||
*args,
|
||||
validators=None,
|
||||
required=False,
|
||||
default=None,
|
||||
**kwargs
|
||||
):
|
||||
self.__enum = enum_class
|
||||
self.types = (enum_class,)
|
||||
# noinspection PyTypeChecker
|
||||
choices = list(enum_class)
|
||||
validator_cls = EnumValidator if required else NullableEnumValidator
|
||||
validators = [*(validators or []), validator_cls(*choices)]
|
||||
super().__init__(
|
||||
default=self.parse_value(default) if default else NotSet,
|
||||
*args,
|
||||
required=required,
|
||||
validators=validators,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value is None and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
return self.__enum(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
def to_struct(self, value):
|
||||
return super().to_struct(value.value)
|
||||
|
||||
|
||||
class EmailField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if validators.email(value) is not True:
|
||||
if email_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidEmailAddress()
|
||||
|
||||
|
||||
@@ -158,5 +203,14 @@ class DomainField(fields.StringField):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if validators.domain(value) is not True:
|
||||
if domain_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
class JsonSerializableMixin:
|
||||
def to_json(self: ModelBase):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField, DateTimeField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Max, Enum
|
||||
|
||||
@@ -79,6 +79,7 @@ class Credentials(Base):
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
from mongoengine.base import BaseDocument
|
||||
|
||||
from apimodels import DictField
|
||||
from apimodels import DictField, ListField
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
'inc',
|
||||
'dec',
|
||||
'push',
|
||||
'push_all',
|
||||
'pop',
|
||||
'pull',
|
||||
'pull_all',
|
||||
'add_to_set',
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split('__')
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == 'set':
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == 'unset':
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return '.'.join(parts), cls._normalize_mongo_value(value)
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
@@ -58,3 +59,11 @@ class UpdateResponse(models.Base):
|
||||
class PagedRequest(models.Base):
|
||||
page = fields.IntField()
|
||||
page_size = fields.IntField()
|
||||
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
||||
78
server/apimodels/events.py
Normal file
78
server/apimodels/events.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
from utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1, maximum_value=10)]
|
||||
)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
asc = auto()
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
|
||||
|
||||
class MetricEvents(Base):
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
class DebugImageResponse(Base):
|
||||
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class TaskMetricsRequest(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
@@ -6,11 +6,16 @@ from apimodels.base import UpdateResponse
|
||||
from apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
projects = fields.ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
labels = DictField(value_types=string_types+(int,), required=True)
|
||||
labels = DictField(value_types=string_types+(int,))
|
||||
tags = ListField(items_types=string_types)
|
||||
system_tags = ListField(items_types=string_types)
|
||||
comment = fields.StringField()
|
||||
public = fields.BoolField(default=False)
|
||||
project = fields.StringField()
|
||||
|
||||
11
server/apimodels/organization.py
Normal file
11
server/apimodels/organization.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
tags = fields.ListField([str])
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
17
server/apimodels/projects.py
Normal file
17
server/apimodels/projects.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apimodels import ListField
|
||||
from apimodels.organization import TagsRequest
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
project = fields.StringField()
|
||||
|
||||
|
||||
class GetHyperParamReq(ProjectReq):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
60
server/apimodels/queues.py
Normal file
60
server/apimodels/queues.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import ListField
|
||||
|
||||
|
||||
class GetDefaultResp(Base):
|
||||
id = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
|
||||
|
||||
class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class MoveTaskRequest(TaskRequest):
|
||||
count = IntField(default=1)
|
||||
|
||||
|
||||
class MoveTaskResponse(Base):
|
||||
position = IntField()
|
||||
|
||||
|
||||
class GetMetricsRequest(Base):
|
||||
queue_ids = ListField([str])
|
||||
from_date = FloatField(required=True, validators=validators.Min(0))
|
||||
to_date = FloatField(required=True, validators=validators.Min(0))
|
||||
interval = IntField(required=True, validators=validators.Min(1))
|
||||
|
||||
|
||||
class QueueMetrics(Base):
|
||||
queue = StringField()
|
||||
dates = ListField(int)
|
||||
avg_waiting_times = ListField([float, int])
|
||||
queue_lengths = ListField(int)
|
||||
|
||||
|
||||
class GetMetricsResponse(Base):
|
||||
queues = ListField(QueueMetrics)
|
||||
15
server/apimodels/server.py
Normal file
15
server/apimodels/server.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from jsonmodels.fields import BoolField, DateTimeField, StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
|
||||
class ReportStatsOptionRequest(Base):
|
||||
enabled = BoolField(default=None, nullable=True)
|
||||
|
||||
|
||||
class ReportStatsOptionResponse(Base):
|
||||
supported = BoolField(default=True)
|
||||
enabled = BoolField()
|
||||
enabled_time = DateTimeField(nullable=True)
|
||||
enabled_version = StringField(nullable=True)
|
||||
enabled_user = StringField(nullable=True)
|
||||
current_version = StringField()
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField
|
||||
from jsonmodels.validators import Enum
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apimodels import DictField, ListField
|
||||
from apimodels.base import UpdateResponse
|
||||
@@ -9,12 +11,39 @@ from database.model.task.task import TaskType
|
||||
from database.utils import get_options
|
||||
|
||||
|
||||
class ArtifactTypeData(models.Base):
|
||||
preview = StringField()
|
||||
content_type = StringField()
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class Artifact(models.Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(validators=Enum("input", "output"), default="output")
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = IntField()
|
||||
timestamp = IntField()
|
||||
type_data = EmbeddedField(ArtifactTypeData)
|
||||
display_data = ListField([list])
|
||||
|
||||
|
||||
class StartedResponse(UpdateResponse):
|
||||
started = IntField()
|
||||
|
||||
|
||||
class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
dequeued = IntField()
|
||||
|
||||
|
||||
class ResetResponse(UpdateResponse):
|
||||
deleted_indices = ListField(items_types=six.string_types)
|
||||
dequeued = DictField()
|
||||
frames = DictField()
|
||||
events = DictField()
|
||||
model_deleted = IntField()
|
||||
@@ -30,6 +59,10 @@ class UpdateRequest(TaskRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
queue = StringField()
|
||||
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
|
||||
@@ -57,5 +90,104 @@ class CreateRequest(TaskData):
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskType)))
|
||||
|
||||
|
||||
class PingRequest(models.Base):
|
||||
task = StringField(required=True)
|
||||
class PingRequest(TaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
new_task_tags = ListField([str])
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_hyperparams = DictField()
|
||||
new_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
artifacts = ListField([Artifact], required=True)
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsResponse(models.Base):
|
||||
added = ListField([str])
|
||||
updated = ListField([str])
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
178
server/apimodels/workers.py
Normal file
178
server/apimodels/workers.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from enum import Enum
|
||||
|
||||
import six
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import (
|
||||
StringField,
|
||||
EmbeddedField,
|
||||
DateTimeField,
|
||||
IntField,
|
||||
FloatField,
|
||||
BoolField,
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
timeout = make_default(
|
||||
IntField, DEFAULT_TIMEOUT
|
||||
)() # registration timeout in seconds (default is 10min)
|
||||
queues = ListField(six.string_types) # list of queues this worker listens to
|
||||
|
||||
|
||||
class MachineStats(Base):
|
||||
cpu_usage = ListField(six.integer_types + (float,))
|
||||
cpu_temperature = ListField(six.integer_types + (float,))
|
||||
gpu_usage = ListField(six.integer_types + (float,))
|
||||
gpu_temperature = ListField(six.integer_types + (float,))
|
||||
gpu_memory_free = ListField(six.integer_types + (float,))
|
||||
gpu_memory_used = ListField(six.integer_types + (float,))
|
||||
memory_used = FloatField()
|
||||
memory_free = FloatField()
|
||||
network_tx = FloatField()
|
||||
network_rx = FloatField()
|
||||
disk_free_home = FloatField()
|
||||
disk_free_temp = FloatField()
|
||||
disk_read = FloatField()
|
||||
disk_write = FloatField()
|
||||
|
||||
|
||||
class StatusReportRequest(WorkerRequest):
|
||||
task = StringField() # task the worker is running on
|
||||
queue = StringField() # queue from which task was taken
|
||||
queues = ListField(
|
||||
str
|
||||
) # list of queues this worker listens to. if None, this will not update the worker's queues list.
|
||||
timestamp = IntField(required=True)
|
||||
machine_stats = EmbeddedField(MachineStats)
|
||||
|
||||
|
||||
class IdNameEntry(Base):
|
||||
id = StringField(required=True)
|
||||
name = StringField()
|
||||
|
||||
|
||||
class WorkerEntry(Base, JsonSerializableMixin):
|
||||
key = StringField() # not required due to migration issues
|
||||
id = StringField(required=True)
|
||||
user = EmbeddedField(IdNameEntry)
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
running_time = IntField()
|
||||
last_iteration = IntField()
|
||||
|
||||
|
||||
class QueueEntry(IdNameEntry):
|
||||
next_task = EmbeddedField(IdNameEntry)
|
||||
num_tasks = IntField()
|
||||
|
||||
|
||||
class WorkerResponseEntry(WorkerEntry):
|
||||
task = EmbeddedField(CurrentTaskEntry)
|
||||
queue = EmbeddedField(QueueEntry)
|
||||
queues = ListField(QueueEntry)
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
last_seen = IntField(default=3600)
|
||||
|
||||
|
||||
class GetAllResponse(Base):
|
||||
workers = ListField(WorkerResponseEntry)
|
||||
|
||||
|
||||
class StatsBase(Base):
|
||||
worker_ids = ListField(str)
|
||||
|
||||
|
||||
class StatsReportBase(StatsBase):
|
||||
from_date = FloatField(required=True, validators=validators.Min(0))
|
||||
to_date = FloatField(required=True, validators=validators.Min(0))
|
||||
interval = IntField(required=True, validators=validators.Min(1))
|
||||
|
||||
|
||||
class AggregationType(Enum):
|
||||
avg = "avg"
|
||||
min = "min"
|
||||
max = "max"
|
||||
|
||||
|
||||
class StatItem(Base):
|
||||
key = StringField(required=True)
|
||||
aggregation = EnumField(AggregationType, default=AggregationType.avg)
|
||||
|
||||
|
||||
class GetStatsRequest(StatsReportBase):
|
||||
items = ListField(
|
||||
StatItem, required=True, validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
split_by_variant = BoolField(default=False)
|
||||
|
||||
|
||||
class AggregationStats(Base):
|
||||
aggregation = EnumField(AggregationType)
|
||||
values = ListField(float)
|
||||
|
||||
|
||||
class MetricStats(Base):
|
||||
metric = StringField()
|
||||
variant = StringField()
|
||||
dates = ListField(int)
|
||||
stats = ListField(AggregationStats)
|
||||
|
||||
|
||||
class WorkerStatistics(Base):
|
||||
worker = StringField()
|
||||
metrics = ListField(MetricStats)
|
||||
|
||||
|
||||
class GetStatsResponse(Base):
|
||||
workers = ListField(WorkerStatistics)
|
||||
|
||||
|
||||
class GetMetricKeysRequest(StatsBase):
|
||||
pass
|
||||
|
||||
|
||||
class MetricCategory(Base):
|
||||
name = StringField()
|
||||
metric_keys = ListField(str)
|
||||
|
||||
|
||||
class GetMetricKeysResponse(Base):
|
||||
categories = ListField(MetricCategory)
|
||||
|
||||
|
||||
class GetActivityReportRequest(StatsReportBase):
|
||||
pass
|
||||
|
||||
|
||||
class ActivityReportSeries(Base):
|
||||
dates = ListField(int)
|
||||
counts = ListField(int)
|
||||
|
||||
|
||||
class GetActivityReportResponse(Base):
|
||||
total = EmbeddedField(ActivityReportSeries)
|
||||
active = EmbeddedField(ActivityReportSeries)
|
||||
467
server/bll/event/debug_images_iterator.py
Normal file
467
server/bll/event/debug_images_iterator.py
Normal file
@@ -0,0 +1,467 @@
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.metrics import MetricEventStats
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
name: str = StringField(required=True)
|
||||
recycle_url_marker: str = StringField()
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
name: str = StringField(required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
timestamp: int = IntField(default=0)
|
||||
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugImagesResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = "training_debug_image"
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@property
|
||||
def _max_workers(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugImagesResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
es_index=es_index,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
),
|
||||
state.metrics,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, es_index, state: DebugImageEventsScrollState
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
since their states were initialized and reinits these states
|
||||
"""
|
||||
task_ids = set(metric.task for metric in state.metrics)
|
||||
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
|
||||
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return []
|
||||
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
|
||||
)
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE in stats.event_stats_by_type
|
||||
]
|
||||
|
||||
update_times = dict(
|
||||
chain.from_iterable(
|
||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
)
|
||||
)
|
||||
outdated_metrics = [
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if (metric.task, metric.name) in update_times
|
||||
and update_times[metric.task, metric.name] > metric.timestamp
|
||||
]
|
||||
state.metrics = [
|
||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
||||
*(
|
||||
self._init_metric_states(
|
||||
es_index,
|
||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
self, es_index, metrics: Sequence[Tuple[str, str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
tasks = defaultdict(list)
|
||||
for (task, metric) in metrics:
|
||||
tasks[task].append(metric)
|
||||
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
pool.map(
|
||||
partial(self._init_metric_states_for_task, es_index=es_index),
|
||||
tasks.items(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Sequence[str]], es_index
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_scroll_state(variant: dict):
|
||||
"""
|
||||
Return new variant scroll state for the passed variant bucket
|
||||
If the image urls get recycled then fill the last_invalid_iteration field
|
||||
"""
|
||||
state = VariantScrollState(name=variant["key"])
|
||||
top_iter_url = dpath.get(variant, "urls/buckets")[0]
|
||||
iters = dpath.get(top_iter_url, "iters/hits/hits")
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricScrollState(
|
||||
task=task,
|
||||
name=metric["key"],
|
||||
variants=[
|
||||
init_variant_scroll_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
metric: MetricScrollState,
|
||||
es_index: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update metric scroll state
|
||||
"""
|
||||
if metric.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and metric.last_min_iter is not None:
|
||||
range_condition = {"lt": metric.last_min_iter}
|
||||
elif not navigate_earlier and metric.last_max_iter is not None:
|
||||
range_condition = {"gt": metric.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
if navigate_earlier:
|
||||
"""
|
||||
When navigating to earlier iterations consider only
|
||||
variants whose invalid iterations border is lower than
|
||||
our starting iteration. For these variants make sure
|
||||
that only events from the valid iterations are returned
|
||||
"""
|
||||
if not metric.last_min_iter:
|
||||
variants = metric.variants
|
||||
else:
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is None
|
||||
or v.last_invalid_iteration < metric.last_min_iter
|
||||
)
|
||||
if not variants:
|
||||
return metric.task, metric.name, []
|
||||
must_conditions.append(
|
||||
{"terms": {"variant": list(v.name for v in variants)}}
|
||||
)
|
||||
else:
|
||||
"""
|
||||
When navigating to later iterations all variants may be relevant.
|
||||
For the variants whose invalid border is higher than our starting
|
||||
iteration make sure that only events from valid iterations are returned
|
||||
"""
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is not None
|
||||
and v.last_invalid_iteration > metric.last_max_iter
|
||||
)
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
if v.last_invalid_iteration is not None
|
||||
]
|
||||
if variants_conditions:
|
||||
must_not_conditions.append({"bool": {"should": variants_conditions}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {"must": must_conditions, "must_not": must_not_conditions}
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {"sort": {"url": {"order": "desc"}}}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
|
||||
return [
|
||||
ev["_source"]
|
||||
for v in variant_buckets
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
]
|
||||
|
||||
iterations = [
|
||||
{
|
||||
"iter": it["key"],
|
||||
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
|
||||
}
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets")
|
||||
]
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
metric.last_max_iter = iterations[0]["iter"]
|
||||
metric.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
# Commented for now since the last invalid iteration is calculated in the beginning
|
||||
# if navigate_earlier and any(
|
||||
# variant.last_invalid_iteration is None for variant in variants
|
||||
# ):
|
||||
# """
|
||||
# Variants validation flags due to recycling can
|
||||
# be set only on navigation to earlier frames
|
||||
# """
|
||||
# iterations = self._update_variants_invalid_iterations(variants, iterations)
|
||||
|
||||
return metric.task, metric.name, iterations
|
||||
|
||||
@staticmethod
|
||||
def _update_variants_invalid_iterations(
|
||||
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
This code is currently not in used since the invalid iterations
|
||||
are calculated during MetricState initialization
|
||||
For variants that do not have recycle url marker set it from the
|
||||
first event
|
||||
For variants that do not have last_invalid_iteration set check if the
|
||||
recycle marker was reached on a certain iteration and set it to the
|
||||
corresponding iteration
|
||||
For variants that have a newly set last_invalid_iteration remove
|
||||
events from the invalid iterations
|
||||
Return the updated iterations list
|
||||
"""
|
||||
variants_lookup = bucketize(variants, attrgetter("name"))
|
||||
for it in iterations:
|
||||
iteration = it["iter"]
|
||||
events_to_remove = []
|
||||
for event in it["events"]:
|
||||
variant = variants_lookup[event["variant"]][0]
|
||||
if (
|
||||
variant.last_invalid_iteration
|
||||
and variant.last_invalid_iteration >= iteration
|
||||
):
|
||||
events_to_remove.append(event)
|
||||
continue
|
||||
event_url = event.get("url")
|
||||
if not variant.recycle_url_marker:
|
||||
variant.recycle_url_marker = event_url
|
||||
elif variant.recycle_url_marker == event_url:
|
||||
variant.last_invalid_iteration = iteration
|
||||
events_to_remove.append(event)
|
||||
if events_to_remove:
|
||||
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
|
||||
return [it for it in iterations if it["events"]]
|
||||
@@ -1,71 +1,106 @@
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional
|
||||
|
||||
import attr
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from bll.event.event_metrics import EventMetrics, EventType
|
||||
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from database.model.task.metrics import MetricEvent
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
metrics_scalar = "training_stats_scalar"
|
||||
metrics_vector = "training_stats_vector"
|
||||
metrics_image = "training_debug_image"
|
||||
metrics_plot = "plot"
|
||||
task_log = "log"
|
||||
|
||||
from tools import safe_get
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
|
||||
|
||||
@attr.s
|
||||
class TaskEventsResult(object):
|
||||
events = attr.ib(type=list, default=attr.Factory(list))
|
||||
total_events = attr.ib(type=int, default=0)
|
||||
next_scroll_id = attr.ib(type=str, default=None)
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ["task", "iter", "metric", "variant", "key"]
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
self.es = events_es if events_es is not None else es_factory.connect("events")
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
self._metrics = EventMetrics(self.es)
|
||||
self._skip_iteration_for_metric = set(
|
||||
config.get("services.events.ignore_iteration.metrics", [])
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
|
||||
def add_events(self, company_id, events, worker):
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
return self._metrics
|
||||
|
||||
@staticmethod
|
||||
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
|
||||
"""Verify that task exists and can be updated"""
|
||||
if not task_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
task_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_events = nested_dict(
|
||||
task_last_scalar_events = nested_dict(
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> variant_hash -> MetricEvent
|
||||
|
||||
task_last_events = nested_dict(
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
errors_per_type = defaultdict(int)
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
if "type" not in event:
|
||||
raise errors.BadRequest("Event must have a 'type' field", event=event)
|
||||
event_type = event.get("type")
|
||||
if event_type is None:
|
||||
errors_per_type["Event must have a 'type' field"] += 1
|
||||
continue
|
||||
|
||||
event_type = event["type"].replace(" ", "_")
|
||||
event_type = event_type.replace(" ", "_")
|
||||
if event_type not in EVENT_TYPES:
|
||||
raise errors.BadRequest(
|
||||
"Invalid event type {}".format(event_type),
|
||||
event=event,
|
||||
types=EVENT_TYPES,
|
||||
)
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
|
||||
@@ -94,11 +129,13 @@ class EventBLL(object):
|
||||
event["value"] = event["values"]
|
||||
del event["values"]
|
||||
|
||||
index_name = EventBLL.get_index_name(company_id, event_type)
|
||||
event["metric"] = event.get("metric") or ""
|
||||
event["variant"] = event.get("variant") or ""
|
||||
|
||||
index_name = EventMetrics.get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
"_type": "event",
|
||||
"_source": event,
|
||||
}
|
||||
|
||||
@@ -108,92 +145,81 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is not None:
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if iter is not None:
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_metric_event_for_task(
|
||||
task_last_events=task_last_events, task_id=task_id, event=event
|
||||
)
|
||||
else:
|
||||
es_action["_routing"] = task_id
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
if task_ids:
|
||||
# verify task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
res = Task.objects(id__in=task_ids, company=company_id).only("id")
|
||||
if len(res) < len(task_ids):
|
||||
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
company=company_id, ids=invalid_task_ids
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
errors_in_bulk = []
|
||||
added = 0
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_in_bulk.append(info)
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
last_metrics = {
|
||||
t.id: t.to_proper_dict().get("last_metrics", {})
|
||||
for t in Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"last_metrics"
|
||||
)
|
||||
}
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
|
||||
# who's events were successful
|
||||
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter=task_iteration.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
last_metrics=last_metrics.get(task_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
|
||||
added = min(added, len(actions))
|
||||
|
||||
return added, errors_in_bulk
|
||||
if not added:
|
||||
raise errors.bad_request.EventsNotAdded(**errors_per_type)
|
||||
|
||||
def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
|
||||
errors_count = sum(errors_per_type.values())
|
||||
return added, errors_count, errors_per_type
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
"""
|
||||
Update task_last_events structure for the provided task_id with the provided event details if this event is more
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
recent than the currently stored event for its metric/variant combination.
|
||||
|
||||
task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
|
||||
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
|
||||
key conflicts due to invalid characters and/or long field names.
|
||||
"""
|
||||
metric = event.get("metric")
|
||||
@@ -204,14 +230,49 @@ class EventBLL(object):
|
||||
metric_hash = dbutils.hash_field_name(metric)
|
||||
variant_hash = dbutils.hash_field_name(variant)
|
||||
|
||||
last_events = task_last_events[task_id]
|
||||
last_event = last_events[metric_hash][variant_hash]
|
||||
event_iter = event.get("iter", 0)
|
||||
event_timestamp = event.get("timestamp", 0)
|
||||
value = event.get("value")
|
||||
if value is not None and (
|
||||
(event_iter, event_timestamp)
|
||||
>= (
|
||||
last_event.get("iter", event_iter),
|
||||
last_event.get("timestamp", event_timestamp),
|
||||
)
|
||||
):
|
||||
event_data = {
|
||||
k: event[k]
|
||||
for k in ("value", "metric", "variant", "iter", "timestamp")
|
||||
if k in event
|
||||
}
|
||||
event_data["min_value"] = min(value, last_event.get("min_value", value))
|
||||
event_data["max_value"] = max(value, last_event.get("max_value", value))
|
||||
last_events[metric_hash][variant_hash] = event_data
|
||||
|
||||
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
|
||||
def _update_last_metric_events_for_task(self, last_events, event):
|
||||
"""
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
recent than the currently stored event for its metric/event_type combination.
|
||||
last_events contains [metric_name -> event_type -> event]
|
||||
"""
|
||||
metric = event.get("metric")
|
||||
event_type = event.get("type")
|
||||
if not (metric and event_type):
|
||||
return
|
||||
|
||||
timestamp = last_events[metric][event_type].get("timestamp", None)
|
||||
if timestamp is None or timestamp < event["timestamp"]:
|
||||
last_events[metric_hash][variant_hash] = event
|
||||
last_events[metric][event_type] = event
|
||||
|
||||
def _update_task(
|
||||
self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None
|
||||
self,
|
||||
company_id,
|
||||
task_id,
|
||||
now,
|
||||
iter_max=None,
|
||||
last_scalar_events=None,
|
||||
last_events=None,
|
||||
):
|
||||
"""
|
||||
Update task information in DB with aggregated results after handling event(s) related to this task.
|
||||
@@ -222,27 +283,26 @@ class EventBLL(object):
|
||||
"""
|
||||
fields = {}
|
||||
|
||||
if iter is not None:
|
||||
fields["last_iteration"] = iter
|
||||
if iter_max is not None:
|
||||
fields["last_iteration_max"] = iter_max
|
||||
|
||||
if last_scalar_events:
|
||||
fields["last_scalar_values"] = list(
|
||||
flatten_nested_items(
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=[
|
||||
"value",
|
||||
"min_value",
|
||||
"max_value",
|
||||
"metric",
|
||||
"variant",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if last_events:
|
||||
|
||||
def get_metric_event(ev):
|
||||
me = MetricEvent.from_dict(**ev)
|
||||
if "timestamp" in ev:
|
||||
me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000)
|
||||
return me
|
||||
|
||||
new_last_metrics = nested_dict(2, MetricEvent)
|
||||
new_last_metrics.update(last_metrics)
|
||||
|
||||
for metric_hash, variants in last_events.items():
|
||||
for variant_hash, event in variants.items():
|
||||
new_last_metrics[metric_hash][variant_hash] = get_metric_event(
|
||||
event
|
||||
)
|
||||
|
||||
fields["last_metrics"] = new_last_metrics.to_dict()
|
||||
fields["last_events"] = last_events
|
||||
|
||||
if not fields:
|
||||
return False
|
||||
@@ -251,7 +311,7 @@ class EventBLL(object):
|
||||
|
||||
def _get_event_id(self, event):
|
||||
id_values = (str(event[field]) for field in self.id_fields if field in event)
|
||||
return "-".join(id_values)
|
||||
return hashlib.md5("-".join(id_values).encode()).hexdigest()
|
||||
|
||||
def scroll_task_events(
|
||||
self,
|
||||
@@ -262,6 +322,9 @@ class EventBLL(object):
|
||||
batch_size=10000,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -270,7 +333,7 @@ class EventBLL(object):
|
||||
if event_type is None:
|
||||
event_type = "*"
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, event_type)
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [], None, 0
|
||||
@@ -284,12 +347,144 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
|
||||
):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return []
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
if event_type:
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
return [
|
||||
(metric["key"], variant["key"], iter["key"])
|
||||
for metric in es_res["aggregations"]["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
for iter in variant["iters"]["buckets"]
|
||||
]
|
||||
|
||||
def get_task_plots(
|
||||
self,
|
||||
company_id: str,
|
||||
tasks: Sequence[str],
|
||||
last_iterations_per_plot: int = None,
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
event_type = "plot"
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
must = []
|
||||
if last_iterations_per_plot is None:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
|
||||
for metric, variant, iter in last_iters:
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"term": {"iter": iter}},
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
@@ -302,6 +497,8 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
@@ -311,24 +508,20 @@ class EventBLL(object):
|
||||
if event_type is None:
|
||||
event_type = "*"
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, event_type)
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if metric or variant:
|
||||
must = query["bool"]["must"]
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
es_index, task_id, event_type, last_iter_count
|
||||
@@ -347,34 +540,30 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(task_ids)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def get_metrics_and_variants(self, company_id, task_id, event_type):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, event_type)
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
@@ -383,8 +572,18 @@ class EventBLL(object):
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": 200},
|
||||
"aggs": {"variants": {"terms": {"field": "variant", "size": 200}}},
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
@@ -393,7 +592,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@@ -405,7 +604,7 @@ class EventBLL(object):
|
||||
return metrics
|
||||
|
||||
def get_task_latest_scalar_values(self, company_id, task_id):
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
|
||||
es_index = EventMetrics.get_index_name(company_id, "training_stats_scalar")
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
@@ -424,15 +623,15 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 1000,
|
||||
"order": {"_term": "asc"},
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": 1000,
|
||||
"order": {"_term": "asc"},
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_value": {
|
||||
@@ -462,7 +661,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@@ -488,147 +687,9 @@ class EventBLL(object):
|
||||
metrics.append(metric_summary)
|
||||
return metrics, max_timestamp
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self, company_id, task_ids, allow_public=True
|
||||
):
|
||||
assert isinstance(task_ids, list)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
|
||||
"aggs": {
|
||||
"metric_and_variant": {
|
||||
"terms": {
|
||||
"script": "doc['metric'].value +'/'+ doc['variant'].value",
|
||||
"size": 10000,
|
||||
},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task"},
|
||||
"aggs": {"avg_val": {"avg": {"field": "value"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_comparison"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return
|
||||
|
||||
metrics = {}
|
||||
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
|
||||
iteration = int(iter_bucket["key"])
|
||||
for metric_bucket in iter_bucket["metric_and_variant"]["buckets"]:
|
||||
metric_name = metric_bucket["key"]
|
||||
if metrics.get(metric_name) is None:
|
||||
metrics[metric_name] = {}
|
||||
|
||||
metric_data = metrics[metric_name]
|
||||
for task_bucket in metric_bucket["tasks"]["buckets"]:
|
||||
task_id = task_bucket["key"]
|
||||
value = task_bucket["avg_val"]["value"]
|
||||
if metric_data.get(task_id) is None:
|
||||
metric_data[task_id] = {
|
||||
"x": [],
|
||||
"y": [],
|
||||
"name": task_name_by_id[
|
||||
task_id
|
||||
], # todo: lookup task name from id
|
||||
}
|
||||
metric_data[task_id]["x"].append(iteration)
|
||||
metric_data[task_id]["y"].append(value)
|
||||
|
||||
return metrics
|
||||
|
||||
def get_scalar_metrics_average_per_iter(self, company_id, task_id):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 200,
|
||||
"order": {"_term": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": 500,
|
||||
"order": {"_term": "desc"},
|
||||
},
|
||||
"aggs": {"avg_val": {"avg": {"field": "value"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
|
||||
metrics = {}
|
||||
if "aggregations" in es_res:
|
||||
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
|
||||
iteration = int(iter_bucket["key"])
|
||||
for metric_bucket in iter_bucket["metrics"]["buckets"]:
|
||||
metric_name = metric_bucket["key"]
|
||||
if metrics.get(metric_name) is None:
|
||||
metrics[metric_name] = {}
|
||||
|
||||
metric_data = metrics[metric_name]
|
||||
for variant_bucket in metric_bucket["variants"]["buckets"]:
|
||||
variant = variant_bucket["key"]
|
||||
value = variant_bucket["avg_val"]["value"]
|
||||
if metric_data.get(variant) is None:
|
||||
metric_data[variant] = {"x": [], "y": [], "name": variant}
|
||||
metric_data[variant]["x"].append(iteration)
|
||||
metric_data[variant]["y"].append(value)
|
||||
return metrics
|
||||
|
||||
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_vector")
|
||||
es_index = EventMetrics.get_index_name(company_id, "training_stats_vector")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [], []
|
||||
|
||||
@@ -647,7 +708,7 @@ class EventBLL(object):
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
@@ -668,7 +729,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -678,23 +739,28 @@ class EventBLL(object):
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
|
||||
|
||||
def delete_task_events(self, company_id, task_id):
|
||||
es_index = EventBLL.get_index_name(company_id, "*")
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False):
|
||||
with translate_errors_context():
|
||||
extra_msg = None
|
||||
query = Q(id=task_id, company=company_id)
|
||||
if not allow_locked:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
extra_msg = "or task published"
|
||||
res = Task.objects(query).only("id").first()
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
extra_msg, company=company_id, id=task_id
|
||||
)
|
||||
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
)
|
||||
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return "events-%s-%s" % (event_type, company_id)
|
||||
|
||||
406
server/bll/event/event_metrics.py
Normal file
406
server/bll/event/event_metrics.py
Normal file
@@ -0,0 +1,406 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apierrors import errors
|
||||
from bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
metrics_scalar = "training_stats_scalar"
|
||||
metrics_vector = "training_stats_vector"
|
||||
metrics_image = "training_debug_image"
|
||||
metrics_plot = "plot"
|
||||
task_log = "log"
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_METRICS_COUNT = 100
|
||||
MAX_VARIANTS_COUNT = 100
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@property
|
||||
def _max_concurrency(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id}"
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, es_index, samples, ScalarKey.resolve(key)
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
es_index=es_index, task_id=task_id, samples=samples, field=key.field
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.company for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
es_index=es_index,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return res
|
||||
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are meat:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self, es_index, task_id: str, samples: int, field: str = "iter"
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
max(1, int(max_index - min_index + 1) // samples),
|
||||
samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
es_index, aggs=aggs, task_id=task_id, metrics=metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
"name": variant["key"],
|
||||
**key.get_iterations_data(variant),
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
return {
|
||||
key: {**value, "aggs": {**value.get("aggs", {}), **average_agg}}
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
es_index: str,
|
||||
aggs: dict,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type.value)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(self._max_concurrency) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics, es_index=es_index, event_type=event_type,
|
||||
),
|
||||
task_ids,
|
||||
)
|
||||
return list(zip(task_ids, res))
|
||||
|
||||
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"type": event_type.value}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
|
||||
]
|
||||
113
server/bll/event/log_events_iterator.py
Normal file
113
server/bll/event/log_events_iterator.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = "log"
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
es_index=es_index,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
es_index,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
if not hits or len(hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
|
||||
already_present_ids = set(ev["_id"] for ev in events)
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[
|
||||
*events,
|
||||
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
|
||||
],
|
||||
hits_total,
|
||||
)
|
||||
161
server/bll/event/scalar_key.py
Normal file
161
server/bll/event/scalar_key.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Module for polymorphism over different types of X axes in scalar aggregations
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from utilities.stringenum import StringEnum
|
||||
from bll.util import extract_properties_to_lists
|
||||
from config import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ScalarKeyEnum(StringEnum):
|
||||
"""
|
||||
String enum representing X axes key
|
||||
"""
|
||||
|
||||
iter = auto()
|
||||
timestamp = auto()
|
||||
iso_time = auto()
|
||||
|
||||
|
||||
class ScalarKey(ABC):
|
||||
"""
|
||||
Abstract scalar key
|
||||
"""
|
||||
|
||||
_enum_to_key = {}
|
||||
bucket_key_key = "key"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def enum_value(self) -> ScalarKeyEnum:
|
||||
"""
|
||||
Enum value accepted in API requests
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Key name. Used as arbitrary internal key in elasticsearch queries
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def field(self) -> str:
|
||||
"""
|
||||
Event key to aggregate by
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
"""
|
||||
Get aggregation for this type of key
|
||||
:param interval: elasticsearch aggregation interval
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Save a mapping from enum values to key class
|
||||
"""
|
||||
if cls.enum_value not in ScalarKeyEnum:
|
||||
raise ValueError(f"{cls.enum_value!r} not in {ScalarKeyEnum.__name__}")
|
||||
if cls.enum_value in cls._enum_to_key:
|
||||
log.warning(
|
||||
f"'{cls.enum_value.value}' is already registered to {ScalarKey.__name__}"
|
||||
)
|
||||
cls._enum_to_key[cls.enum_value] = cls
|
||||
|
||||
@classmethod
|
||||
def resolve(cls, key: ScalarKeyEnum):
|
||||
"""
|
||||
Create a key instance from enum instance
|
||||
"""
|
||||
return cls._enum_to_key[key]()
|
||||
|
||||
def get_iterations_data(self, iter_buckets: dict) -> dict:
|
||||
"""
|
||||
Convert a list of bucket entries to `x`s array and `y`s array
|
||||
"""
|
||||
return extract_properties_to_lists(
|
||||
("x", "y"),
|
||||
iter_buckets[self.name]["buckets"],
|
||||
self._get_iterations_data_single,
|
||||
)
|
||||
|
||||
def _get_iterations_data_single(self, iter_data):
|
||||
"""
|
||||
Extract x value and y value from a single bucket item
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by timestamp in milliseconds since epoch
|
||||
"""
|
||||
|
||||
name = "timestamp"
|
||||
field = "timestamp"
|
||||
enum_value = ScalarKeyEnum.timestamp
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by iteration number
|
||||
"""
|
||||
|
||||
name = "iters"
|
||||
field = "iter"
|
||||
enum_value = ScalarKeyEnum.iter
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"histogram": {"field": "iter", "interval": interval, "min_doc_count": 1}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by time formatted as ISO strings
|
||||
"""
|
||||
|
||||
name = "iso_time"
|
||||
field = "timestamp"
|
||||
enum_value = ScalarKeyEnum.iso_time
|
||||
bucket_key_key = "key_as_string"
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _get_iterations_data_single(self, iter_data):
|
||||
return iter_data[self.bucket_key_key], iter_data["avg_val"]["value"]
|
||||
18
server/bll/model/__init__.py
Normal file
18
server/bll/model/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from database.model.model import Model
|
||||
from database.utils import get_company_or_none_constraint
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
193
server/bll/organization/__init__.py
Normal file
193
server/bll/organization/__init__.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Sequence, Union, Type, Dict
|
||||
|
||||
from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
|
||||
class _TagsCache:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
|
||||
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
|
||||
self.db_cls = db_cls
|
||||
self.redis = redis
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
|
||||
|
||||
def _get_tags_from_db(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> set:
|
||||
query = Q(company=company)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
def _get_tags_cache_key(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
):
|
||||
"""
|
||||
Project None means 'from all company projects'
|
||||
The key is built in the way that scanning company keys for 'all company projects'
|
||||
will not return the keys related to the particular company projects and vice versa.
|
||||
So that we can have a fine grain control on what redis keys to invalidate
|
||||
"""
|
||||
filter_str = None
|
||||
if filter_:
|
||||
filter_str = "_".join(
|
||||
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
|
||||
)
|
||||
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
|
||||
return "_".join(filter(None, key_parts))
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
project: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
Return the dictionary of tags per tags field name
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and re calculates any of them if missing in Redis
|
||||
"""
|
||||
fields = [self._tags_field]
|
||||
if include_system:
|
||||
fields.append(self._system_tags_field)
|
||||
redis_keys = [
|
||||
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
|
||||
for f in fields
|
||||
]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, project, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
)
|
||||
ret[field] = set(tags)
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
fields = [
|
||||
field
|
||||
for field, update in (
|
||||
(self._tags_field, tags),
|
||||
(self._system_tags_field, system_tags),
|
||||
)
|
||||
if update is not None
|
||||
]
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company, projects=[project], fields=fields)
|
||||
|
||||
def reset_tags(self, company: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
company,
|
||||
projects=projects,
|
||||
fields=(self._tags_field, self._system_tags_field),
|
||||
)
|
||||
|
||||
def _delete_redis_keys(
|
||||
self, company: str, projects: [Sequence[str]], fields: Sequence[str]
|
||||
):
|
||||
redis_keys = list(
|
||||
chain.from_iterable(
|
||||
self.redis.keys(
|
||||
self._get_tags_cache_key(company, field=f, project=p) + "*"
|
||||
)
|
||||
for f in fields
|
||||
for p in set(projects) | {None}
|
||||
)
|
||||
)
|
||||
if redis_keys:
|
||||
self.redis.delete(*redis_keys)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_, project=project
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
1
server/bll/project/__init__.py
Normal file
1
server/bll/project/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .project_bll import ProjectBLL
|
||||
33
server/bll/project/project_bll.py
Normal file
33
server/bll/project/project_bll.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
1
server/bll/query/__init__.py
Normal file
1
server/bll/query/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .builder import Builder
|
||||
36
server/bll/query/builder.py
Normal file
36
server/bll/query/builder.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, Sequence, Iterable, Union
|
||||
|
||||
from config import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
RANGE_IGNORE_VALUE = -1
|
||||
|
||||
|
||||
class Builder:
|
||||
@staticmethod
|
||||
def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
|
||||
return {
|
||||
"range": {
|
||||
"timestamp": {
|
||||
"gte": int(from_date),
|
||||
"lte": int(to_date),
|
||||
"format": "epoch_second",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def terms(field: str, values: Iterable[str]) -> dict:
|
||||
return {"terms": {field: list(values)}}
|
||||
|
||||
@staticmethod
|
||||
def normalize_range(
|
||||
range_: Sequence[Union[int, float]],
|
||||
ignore_value: Union[int, float] = RANGE_IGNORE_VALUE,
|
||||
) -> Optional[Sequence[Union[int, float]]]:
|
||||
if not range_ or set(range_) == {ignore_value}:
|
||||
return None
|
||||
if len(range_) < 2:
|
||||
return [range_[0]] * 2
|
||||
return range_
|
||||
1
server/bll/queue/__init__.py
Normal file
1
server/bll/queue/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .queue_bll import QueueBLL
|
||||
270
server/bll/queue/queue_bll.py
Normal file
270
server/bll/queue/queue_bll.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Callable, Sequence, Optional, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import database
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from bll.queue.queue_metrics import QueueMetrics
|
||||
from bll.workers import WorkerBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.queue import Queue, Entry
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class QueueBLL(object):
|
||||
def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
|
||||
self.worker_bll = worker_bll or WorkerBLL()
|
||||
self.es = es or es_factory.connect("workers")
|
||||
self._metrics = QueueMetrics(self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> QueueMetrics:
|
||||
return self._metrics
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
company_id: str,
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
now = datetime.utcnow()
|
||||
queue = Queue(
|
||||
id=database.utils.id(),
|
||||
company=company_id,
|
||||
created=now,
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
system_tags=system_tags or [],
|
||||
last_update=now,
|
||||
)
|
||||
queue.save()
|
||||
return queue
|
||||
|
||||
def get_by_id(
|
||||
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
|
||||
) -> Queue:
|
||||
"""
|
||||
Get queue by id
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
qs = Queue.objects(**query)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
queue = qs.first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
|
||||
return queue
|
||||
|
||||
@classmethod
|
||||
def get_queue_with_task(cls, company_id: str, queue_id: str, task_id: str) -> Queue:
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(entries__task=task_id, **query).first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
)
|
||||
|
||||
return queue
|
||||
|
||||
def get_default(self, company_id: str) -> Queue:
|
||||
"""
|
||||
Get the default queue
|
||||
:raise errors.bad_request.NoDefaultQueue: if the default queue not found
|
||||
:raise errors.bad_request.MultipleDefaultQueues: if more than one default queue is found
|
||||
"""
|
||||
with translate_errors_context():
|
||||
res = Queue.objects(company=company_id, system_tags="default").only(
|
||||
"id", "name"
|
||||
)
|
||||
if not res:
|
||||
raise errors.bad_request.NoDefaultQueue()
|
||||
if len(res) > 1:
|
||||
raise errors.bad_request.MultipleDefaultQueues(
|
||||
queues=tuple(r.id for r in res)
|
||||
)
|
||||
|
||||
return res.first()
|
||||
|
||||
def update(
|
||||
self, company_id: str, queue_id: str, **update_fields
|
||||
) -> Tuple[int, dict]:
|
||||
"""
|
||||
Partial update of the queue from update_fields
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
:return: number of updated objects and updated fields dictionary
|
||||
"""
|
||||
with translate_errors_context():
|
||||
# validate the queue exists
|
||||
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return Queue.safe_update(company_id, queue_id, update_fields)
|
||||
|
||||
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
|
||||
"""
|
||||
Delete the queue
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if queue.entries and not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
queue.delete()
|
||||
|
||||
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id, parameters=query_dict, query_dict=query_dict
|
||||
)
|
||||
|
||||
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
projection = Queue.get_extra_projection("entries.task.name")
|
||||
with translate_errors_context():
|
||||
res = Queue.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
override_projection=projection,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
for worker in self.worker_bll.get_all(company_id):
|
||||
for queue in worker.queues:
|
||||
queue_workers[queue].append(worker)
|
||||
|
||||
for item in res:
|
||||
item["workers"] = [
|
||||
{
|
||||
"name": w.id,
|
||||
"ip": w.ip,
|
||||
"task": w.task.to_struct() if w.task else None,
|
||||
}
|
||||
for w in queue_workers.get(item["id"], [])
|
||||
]
|
||||
|
||||
return res
|
||||
|
||||
def add_task(self, company_id: str, queue_id: str, task_id: str) -> dict:
|
||||
"""
|
||||
Add the task to the queue and return the queue update results
|
||||
:raise errors.bad_request.TaskAlreadyQueued: if the task is already in the queue
|
||||
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the queue update operation failed
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if any(e.task == task_id for e in queue.entries):
|
||||
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
entry = Entry(added=datetime.utcnow(), task=task_id)
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
push__entries=entry, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
|
||||
"""
|
||||
Atomically pop and return the first task from the queue (or None)
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
if not queue.entries:
|
||||
return
|
||||
|
||||
try:
|
||||
Queue.objects(**query).update(last_update=datetime.utcnow())
|
||||
except Exception:
|
||||
log.exception("Error while updating Queue.last_update")
|
||||
|
||||
return queue.entries[0]
|
||||
|
||||
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
|
||||
"""
|
||||
Removes the task from the queue and returns the number of removed items
|
||||
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
entries_to_remove = [e for e in queue.entries if e.task == task_id]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task=task_id, **query).update_one(
|
||||
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
return len(entries_to_remove) if res else 0
|
||||
|
||||
def reposition_task(
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
task_id: str,
|
||||
pos_func: Callable[[int], int],
|
||||
) -> int:
|
||||
"""
|
||||
Moves the task in the queue to the position calculated by pos_func
|
||||
Returns the updated task position in the queue
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
|
||||
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
|
||||
new_position = pos_func(position)
|
||||
|
||||
if new_position != position:
|
||||
entry = queue.entries[position]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
updated = Queue.objects(entries__task=task_id, **query).update_one(
|
||||
pull__entries=entry, last_update=datetime.utcnow()
|
||||
)
|
||||
if not updated:
|
||||
raise errors.bad_request.RemovedDuringReposition(
|
||||
task=task_id, **query
|
||||
)
|
||||
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
|
||||
if new_position >= 0:
|
||||
inst["$push"]["entries"]["$position"] = new_position
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
__raw__=inst
|
||||
)
|
||||
if not res:
|
||||
raise errors.bad_request.FailedAddingDuringReposition(
|
||||
task=task_id, **query
|
||||
)
|
||||
|
||||
return new_position
|
||||
262
server/bll/queue/queue_metrics.py
Normal file
262
server/bll/queue/queue_metrics.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
import elasticsearch.helpers
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import es_factory
|
||||
from apierrors.errors import bad_request
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.queue import Queue, Entry
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
QUEUE_FIELD = "queue"
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def _queue_metrics_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"queue_metrics_{company_id}_"
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
"""Get the index name suffix for storing current month data"""
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
@staticmethod
|
||||
def _calc_avg_waiting_time(entries: Sequence[Entry]) -> float:
|
||||
"""
|
||||
Calculate avg waiting time for the given tasks.
|
||||
Return 0 if the list is empty
|
||||
"""
|
||||
if not entries:
|
||||
return 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
|
||||
return total_waiting_in_secs / len(entries)
|
||||
|
||||
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
|
||||
"""
|
||||
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
|
||||
:return: True if the write to es was successful, false otherwise
|
||||
"""
|
||||
es_index = (
|
||||
self._queue_metrics_prefix_for_company(company_id)
|
||||
+ self._get_es_index_suffix()
|
||||
)
|
||||
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
|
||||
def make_doc(queue: Queue) -> dict:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
|
||||
entries
|
||||
),
|
||||
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
|
||||
},
|
||||
)
|
||||
|
||||
actions = list(map(make_doc, queues))
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
|
||||
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
|
||||
query = dict(company=company_id)
|
||||
if queue_ids:
|
||||
query["id__in"] = list(queue_ids)
|
||||
queues = Queue.objects(**query)
|
||||
self.log_queue_metrics_to_es(company_id, queues=list(queues))
|
||||
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_dates_agg(cls, interval) -> dict:
|
||||
"""
|
||||
Aggregation for building date histogram with internal grouping per queue.
|
||||
We are grouping by queue inside date histogram and not vice versa so that
|
||||
it will be easy to average between queue metrics inside each date bucket.
|
||||
Ignore empty buckets.
|
||||
"""
|
||||
return {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
"queues": {
|
||||
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
|
||||
"aggs": cls._get_top_waiting_agg(),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_top_waiting_agg(cls) -> dict:
|
||||
"""
|
||||
Aggregation for getting max waiting time and the corresponding queue length
|
||||
inside each date->queue bucket
|
||||
"""
|
||||
return {
|
||||
"top_avg_waiting": {
|
||||
"top_hits": {
|
||||
"sort": [
|
||||
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
|
||||
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
|
||||
],
|
||||
"_source": {
|
||||
"includes": [
|
||||
cls.EsKeys.WAITING_TIME_FIELD,
|
||||
cls.EsKeys.QUEUE_LENGTH_FIELD,
|
||||
]
|
||||
},
|
||||
"size": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def get_queue_metrics(
|
||||
self,
|
||||
company_id: str,
|
||||
from_date: float,
|
||||
to_date: float,
|
||||
interval: int,
|
||||
queue_ids: Sequence[str],
|
||||
) -> dict:
|
||||
"""
|
||||
Get the company queue metrics in the specified time range.
|
||||
Returned as date histograms of average values per queue and metric type.
|
||||
The from_date is extended by 'metrics_before_from_date' seconds from
|
||||
queues.conf due to possibly small amount of points. The default extension is 3600s
|
||||
In case no queue ids are specified the avg across all the
|
||||
company queues is calculated for each metric
|
||||
"""
|
||||
# self._log_current_metrics(company, queue_ids=queue_ids)
|
||||
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
||||
seconds_before = config.get("services.queues.metrics_before_from_date", 3600)
|
||||
must_terms = [QueryBuilder.dates_range(from_date - seconds_before, to_date)]
|
||||
if queue_ids:
|
||||
must_terms.append(QueryBuilder.terms("queue", queue_ids))
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_terms}},
|
||||
"aggs": self._get_dates_agg(interval),
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
|
||||
res = self._search_company_metrics(company_id, es_req)
|
||||
|
||||
if "aggregations" not in res:
|
||||
return {}
|
||||
|
||||
date_metrics = [
|
||||
dict(
|
||||
timestamp=d["key"],
|
||||
queue_metrics=self._extract_queue_metrics(d["queues"]["buckets"]),
|
||||
)
|
||||
for d in res["aggregations"]["dates"]["buckets"]
|
||||
if d["doc_count"] > 0
|
||||
]
|
||||
if queue_ids:
|
||||
return self._datetime_histogram_per_queue(date_metrics)
|
||||
|
||||
return self._average_datetime_histogram(date_metrics)
|
||||
|
||||
@classmethod
|
||||
def _datetime_histogram_per_queue(cls, date_metrics: Sequence[dict]) -> dict:
|
||||
"""
|
||||
Build datetime histogram per queue from datetime histogram where every
|
||||
bucket contains all the queues metrics
|
||||
"""
|
||||
queues_data = defaultdict(list)
|
||||
for date_data in date_metrics:
|
||||
timestamp = date_data["timestamp"]
|
||||
for queue, metrics in date_data["queue_metrics"].items():
|
||||
queues_data[queue].append({"date": timestamp, **metrics})
|
||||
|
||||
return queues_data
|
||||
|
||||
@classmethod
|
||||
def _average_datetime_histogram(cls, date_metrics: Sequence[dict]) -> dict:
|
||||
"""
|
||||
Calculate weighted averages and total count for each bucket of date_metrics histogram.
|
||||
If for any queue the data is missing then take it from the previous bucket
|
||||
The result is returned as a dictionary with one key 'total'
|
||||
"""
|
||||
queues_total = []
|
||||
last_values = {}
|
||||
for date_data in date_metrics:
|
||||
date_metrics = date_data["queue_metrics"]
|
||||
queue_metrics = {
|
||||
**date_metrics,
|
||||
**{k: v for k, v in last_values.items() if k not in date_metrics},
|
||||
}
|
||||
|
||||
total_length = sum(m["queue_length"] for m in queue_metrics.values())
|
||||
if total_length:
|
||||
total_average = sum(
|
||||
m["avg_waiting_time"] * m["queue_length"] / total_length
|
||||
for m in queue_metrics.values()
|
||||
)
|
||||
else:
|
||||
total_average = 0
|
||||
|
||||
queues_total.append(
|
||||
dict(
|
||||
date=date_data["timestamp"],
|
||||
avg_waiting_time=total_average,
|
||||
queue_length=total_length,
|
||||
)
|
||||
)
|
||||
|
||||
for k, v in date_metrics.items():
|
||||
last_values[k] = v
|
||||
|
||||
return dict(total=queues_total)
|
||||
|
||||
@classmethod
|
||||
def _extract_queue_metrics(cls, queue_buckets: Sequence[dict]) -> dict:
|
||||
"""
|
||||
Extract ES data for single date and queue bucket
|
||||
"""
|
||||
queue_metrics = dict()
|
||||
for queue_data in queue_buckets:
|
||||
if not queue_data["doc_count"]:
|
||||
continue
|
||||
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
|
||||
queue_metrics[queue_data["key"]] = {
|
||||
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
|
||||
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
|
||||
}
|
||||
return queue_metrics
|
||||
79
server/bll/redis_cache_manager.py
Normal file
79
server/bll/redis_cache_manager.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
import database
|
||||
from timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _do_nothing(_: T):
|
||||
return
|
||||
|
||||
|
||||
class RedisCacheManager(Generic[T]):
|
||||
"""
|
||||
Class for store/retrieve of state objects from redis
|
||||
|
||||
self.state_class - class of the state
|
||||
self.redis - instance of redis
|
||||
self.expiration_interval - expiration interval in seconds
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
|
||||
):
|
||||
self.state_class = state_class
|
||||
self.redis = redis
|
||||
self.expiration_interval = expiration_interval
|
||||
|
||||
def set_state(self, state: T) -> None:
|
||||
redis_key = self._get_redis_key(state.id)
|
||||
with TimingContext("redis", "cache_set_state"):
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
|
||||
def get_state(self, state_id) -> Optional[T]:
|
||||
redis_key = self._get_redis_key(state_id)
|
||||
with TimingContext("redis", "cache_get_state"):
|
||||
response = self.redis.get(redis_key)
|
||||
if response:
|
||||
return self.state_class.from_json(response)
|
||||
|
||||
def delete_state(self, state_id) -> None:
|
||||
with TimingContext("redis", "cache_delete_state"):
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
):
|
||||
"""
|
||||
Try to retrieve state with the given id from the Redis cache if yes then validates it
|
||||
If no then create a new one with randomly generated id
|
||||
Yield the state and write it back to redis once the user code block exits
|
||||
:param state_id: id of the state to retrieve
|
||||
:param init_state: user callback to init the newly created state
|
||||
If not passed then no init except for the id generation is done
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
try:
|
||||
yield state
|
||||
finally:
|
||||
self.set_state(state)
|
||||
90
server/bll/statistics/resource_monitor.py
Normal file
90
server/bll/statistics/resource_monitor.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from threading import Thread, Lock
|
||||
from time import sleep
|
||||
|
||||
import attr
|
||||
import psutil
|
||||
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
|
||||
|
||||
class ResourceMonitor(Thread):
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
|
||||
@classmethod
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
def __init__(self, sample_interval_sec=5):
|
||||
super(ResourceMonitor, self).__init__(daemon=True)
|
||||
self.sample_interval_sec = sample_interval_sec
|
||||
self._lock = Lock()
|
||||
self._clear()
|
||||
|
||||
def _clear(self):
|
||||
sample = self._get_sample()
|
||||
self._avg = sample
|
||||
self._min = sample
|
||||
self._max = sample
|
||||
self._clear_time = datetime.utcnow()
|
||||
self._count = 1
|
||||
|
||||
@classmethod
|
||||
def _get_sample(cls) -> Sample:
|
||||
return cls.Sample(
|
||||
cpu_usage=psutil.cpu_percent(),
|
||||
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
|
||||
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
|
||||
)
|
||||
|
||||
def run(self):
|
||||
while not ThreadsManager.terminating:
|
||||
sleep(self.sample_interval_sec)
|
||||
|
||||
sample = self._get_sample()
|
||||
|
||||
with self._lock:
|
||||
self._min = self._min.min(sample)
|
||||
self._max = self._max.max(sample)
|
||||
self._avg = self._avg.avg(sample, self._count)
|
||||
self._count += 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
""" Returns current resource statistics and clears internal resource statistics """
|
||||
with self._lock:
|
||||
min_ = attr.asdict(self._min)
|
||||
max_ = attr.asdict(self._max)
|
||||
avg = attr.asdict(self._avg)
|
||||
interval = datetime.utcnow() - self._clear_time
|
||||
self._clear()
|
||||
|
||||
return {
|
||||
"interval_sec": interval.total_seconds(),
|
||||
"num_cores": psutil.cpu_count(),
|
||||
**{
|
||||
k: {"min": v, "max": max_[k], "avg": avg[k]}
|
||||
for k, v in min_.items()
|
||||
}
|
||||
}
|
||||
304
server/bll/statistics/stats_reporter.py
Normal file
304
server/bll/statistics/stats_reporter.py
Normal file
@@ -0,0 +1,304 @@
|
||||
import logging
|
||||
import queue
|
||||
import random
|
||||
import time
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import dpath
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from bll.util import get_server_uuid
|
||||
from bll.workers import WorkerStats, WorkerBLL
|
||||
from config import config
|
||||
from config.info import get_deployment_type
|
||||
from database.model import Company, User
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from tools import safe_get
|
||||
from utilities.json import dumps
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
worker_bll = WorkerBLL()
|
||||
|
||||
|
||||
class StatisticsReporter:
|
||||
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
|
||||
send_queue = queue.Queue()
|
||||
supported = config.get("apiserver.statistics.supported", True)
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
cls.start_sender()
|
||||
cls.start_reporter()
|
||||
|
||||
@classmethod
|
||||
@threads.register("reporter", daemon=True)
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
Note: in trains we usually have only a single company
|
||||
"""
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
report_interval = timedelta(
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
)
|
||||
sleep(report_interval.total_seconds())
|
||||
while not ThreadsManager.terminating:
|
||||
try:
|
||||
for company in Company.objects(
|
||||
defaults__stats_option__enabled=True
|
||||
).only("id"):
|
||||
stats = cls.get_statistics(company.id)
|
||||
cls.send_queue.put(stats)
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed collecting stats: {str(ex)}")
|
||||
|
||||
sleep(report_interval.total_seconds())
|
||||
|
||||
@classmethod
|
||||
@threads.register("sender", daemon=True)
|
||||
def start_sender(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
url = config.get("apiserver.statistics.url")
|
||||
|
||||
retries = config.get("apiserver.statistics.max_retries", 5)
|
||||
max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(retries))
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
session.headers["Content-type"] = "application/json"
|
||||
|
||||
WarningFilter.attach()
|
||||
|
||||
while not ThreadsManager.terminating:
|
||||
try:
|
||||
report = cls.send_queue.get()
|
||||
|
||||
# Set a random backoff factor each time we send a report
|
||||
adapter.max_retries.backoff_factor = random.random() * max_backoff
|
||||
|
||||
session.post(url, data=dumps(report))
|
||||
|
||||
except Exception as ex:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_statistics(cls, company_id: str) -> dict:
|
||||
"""
|
||||
Returns a statistics report per company
|
||||
"""
|
||||
return {
|
||||
"time": datetime.utcnow(),
|
||||
"company_id": company_id,
|
||||
"server": {
|
||||
"version": current_version,
|
||||
"deployment": get_deployment_type(),
|
||||
"uuid": get_server_uuid(),
|
||||
"queues": {"count": Queue.objects(company=company_id).count()},
|
||||
"users": {"count": User.objects(company=company_id).count()},
|
||||
"resources": cls.threads.resource_monitor.get_stats(),
|
||||
"experiments": next(
|
||||
iter(cls._get_experiments_stats(company_id).values()), {}
|
||||
),
|
||||
},
|
||||
"agents": cls._get_agents_statistics(company_id),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]:
|
||||
result = cls._get_resource_stats_per_agent(company_id, key="resources")
|
||||
dpath.merge(
|
||||
result, cls._get_experiments_stats_per_agent(company_id, key="experiments")
|
||||
)
|
||||
return [{"uuid": agent_id, **data} for agent_id, data in result.items()]
|
||||
|
||||
@classmethod
|
||||
def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict:
|
||||
agent_resource_threshold_sec = timedelta(
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
).total_seconds()
|
||||
to_timestamp = int(time.time())
|
||||
from_timestamp = to_timestamp - int(agent_resource_threshold_sec)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {
|
||||
"categories": {
|
||||
"terms": {"field": "category"},
|
||||
"aggs": {"count": {"cardinality": {"field": "variant"}}},
|
||||
},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": {
|
||||
"min": {"min": {"field": "value"}},
|
||||
"max": {"max": {"field": "value"}},
|
||||
"avg": {"avg": {"field": "value"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
|
||||
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
|
||||
names = {"cpu": "num_cores"}
|
||||
return {
|
||||
names[c["key"]]: safe_get(c, "count/value")
|
||||
for c in categories
|
||||
if c["key"] in names
|
||||
}
|
||||
|
||||
def _get_metric_fields(metrics: Sequence[dict]) -> dict:
|
||||
names = {
|
||||
"cpu_usage": "cpu_usage",
|
||||
"memory_used": "mem_used_gb",
|
||||
"memory_free": "mem_free_gb",
|
||||
}
|
||||
return {
|
||||
names[m["key"]]: {
|
||||
"min": safe_get(m, "min/value"),
|
||||
"max": safe_get(m, "max/value"),
|
||||
"avg": safe_get(m, "avg/value"),
|
||||
}
|
||||
for m in metrics
|
||||
if m["key"] in names
|
||||
}
|
||||
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
return {
|
||||
b["key"]: {
|
||||
key: {
|
||||
"interval_sec": agent_resource_threshold_sec,
|
||||
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
|
||||
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
|
||||
}
|
||||
}
|
||||
for b in buckets
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict:
|
||||
agent_relevant_threshold = timedelta(
|
||||
days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
|
||||
)
|
||||
to_timestamp = int(time.time())
|
||||
from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds())
|
||||
workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp)
|
||||
if not workers:
|
||||
return {}
|
||||
|
||||
stats = cls._get_experiments_stats(company_id, list(workers.keys()))
|
||||
return {
|
||||
worker_id: {key: {**workers[worker_id], **stat}}
|
||||
for worker_id, stat in stats.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_active_workers(
|
||||
cls, company_id, from_timestamp: int, to_timestamp: int
|
||||
) -> dict:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {"last_activity_time": {"max": {"field": "timestamp"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
return {
|
||||
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
|
||||
for b in buckets
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_experiments_stats(
|
||||
cls, company_id, workers: Optional[Sequence] = None
|
||||
) -> dict:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"started": {"$exists": True, "$ne": None},
|
||||
"last_update": {"$exists": True, "$ne": None},
|
||||
"status": {"$nin": ["created", "queued"]},
|
||||
**({"last_worker": {"$in": workers}} if workers else {}),
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$last_worker" if workers else None,
|
||||
"count": {"$sum": 1},
|
||||
"avg_run_time_sec": {
|
||||
"$avg": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$last_update", "$started"]},
|
||||
1000,
|
||||
]
|
||||
}
|
||||
},
|
||||
"avg_iterations": {"$avg": "$last_iteration"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"count": 1,
|
||||
"avg_run_time_sec": {"$trunc": "$avg_run_time_sec"},
|
||||
"avg_iterations": {"$trunc": "$avg_iterations"},
|
||||
}
|
||||
},
|
||||
]
|
||||
return {
|
||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||
for group in Task.aggregate(pipeline)
|
||||
}
|
||||
|
||||
|
||||
class WarningFilter(logging.Filter):
|
||||
@classmethod
|
||||
def attach(cls):
|
||||
from urllib3.connectionpool import (
|
||||
ConnectionPool,
|
||||
) # required to make sure the logger is created
|
||||
|
||||
assert ConnectionPool # make sure import is not optimized out
|
||||
|
||||
logging.getLogger("urllib3.connectionpool").addFilter(cls())
|
||||
|
||||
def filter(self, record):
|
||||
if (
|
||||
record.levelno == logging.WARNING
|
||||
and len(record.args) > 2
|
||||
and record.args[2] == "/stats"
|
||||
):
|
||||
return False
|
||||
return True
|
||||
229
server/bll/task/hyperparams.py
Normal file
229
server/bll/task/hyperparams.py
Normal file
@@ -0,0 +1,229 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import (
|
||||
HyperParamKey,
|
||||
HyperParamItem,
|
||||
ReplaceHyperparams,
|
||||
Configuration,
|
||||
)
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
class HyperParams:
|
||||
_properties_section = "properties"
|
||||
|
||||
@classmethod
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_params_list(
|
||||
cls, items: Dict[str, Dict[str, ParamsItem]]
|
||||
) -> Sequence[dict]:
|
||||
ret = list(chain.from_iterable(v.values() for v in items.values()))
|
||||
return [
|
||||
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _normalize_params(cls, params: Sequence) -> bool:
|
||||
"""
|
||||
Lower case properties section and return True if it is the only section
|
||||
"""
|
||||
for p in params:
|
||||
if p.section.lower() == cls._properties_section:
|
||||
p.section = cls._properties_section
|
||||
|
||||
return all(p.section == cls._properties_section for p in params)
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{section}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[f"set__hyperparams__{section}__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
sections = iterutils.bucketize(items, key=attrgetter("section"))
|
||||
return {
|
||||
ParameterKeyEscaper.escape(section): {
|
||||
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
|
||||
for param in params
|
||||
}
|
||||
for section, params in sections.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configurations(
|
||||
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
only = ["id"]
|
||||
if names:
|
||||
only.extend(
|
||||
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
|
||||
)
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {
|
||||
"configuration": [
|
||||
c.to_proper_dict()
|
||||
for c in sorted(task.configuration.values(), key=attrgetter("name"))
|
||||
]
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
) -> Dict[str, list]:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration=Sequence[str]
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@staticmethod
|
||||
def _get_task_for_update(
|
||||
company: str, id: str, allow_all_statuses: bool = False
|
||||
) -> Task:
|
||||
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=id)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
if task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apierrors import errors
|
||||
from bll.task import ChangeStatusRequest
|
||||
from config import config
|
||||
from database.model.task.task import TaskStatus, Task
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class NonResponsiveTasksWatchdog:
|
||||
threads = ThreadsManager()
|
||||
|
||||
class _Settings:
|
||||
"""
|
||||
Retrieves watchdog settings from the config file
|
||||
The properties are not cached so that the updates in
|
||||
the config file are reflected
|
||||
"""
|
||||
|
||||
_prefix = "services.tasks.non_responsive_tasks_watchdog"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return config.get(f"{self._prefix}.enabled", True)
|
||||
|
||||
@property
|
||||
def watch_interval_sec(self):
|
||||
return config.get(f"{self._prefix}.watch_interval_sec", 900)
|
||||
|
||||
@property
|
||||
def threshold_sec(self):
|
||||
return config.get(f"{self._prefix}.threshold_sec", 7200)
|
||||
|
||||
settings = _Settings()
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
stopped = cls.cleanup_tasks(
|
||||
threshold_sec=cls.settings.threshold_sec
|
||||
)
|
||||
log.info(f"{stopped} non-responsive tasks stopped")
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
sleep(watch_interval)
|
||||
|
||||
@classmethod
|
||||
def cleanup_tasks(cls, threshold_sec):
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(seconds=threshold_sec)
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
|
||||
"id", "name", "status", "project", "last_update"
|
||||
)
|
||||
)
|
||||
log.info(f"{len(tasks)} non-responsive tasks found")
|
||||
if not tasks:
|
||||
return 0
|
||||
|
||||
err_count = 0
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
try:
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
except errors.bad_request.FailedChangingTaskStatus:
|
||||
err_count += 1
|
||||
|
||||
return len(tasks) - err_count
|
||||
201
server/bll/task/param_utils.py
Normal file
201
server/bll/task/param_utils.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import dpath
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.task.task import Task
|
||||
from tools import safe_get
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
hyperparams_default_section = "Args"
|
||||
hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
if default_section is None:
|
||||
return None, full_name
|
||||
|
||||
section, _, name = full_name.partition("/")
|
||||
if section != tf_define_section:
|
||||
return default_section, full_name
|
||||
|
||||
if not name:
|
||||
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
|
||||
return section, name
|
||||
|
||||
|
||||
def _get_full_param_name(param: dict) -> str:
|
||||
section = param.get("section")
|
||||
if section != tf_define_section:
|
||||
return param["name"]
|
||||
|
||||
return "/".join((section, param["name"]))
|
||||
|
||||
|
||||
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
removed = 0
|
||||
if not data:
|
||||
return removed
|
||||
|
||||
if with_sections:
|
||||
for section, section_data in list(data.items()):
|
||||
removed += _remove_legacy_params(section_data)
|
||||
if not section_data:
|
||||
"""If section is empty after removing legacy params then delete it"""
|
||||
del data[section]
|
||||
else:
|
||||
for key, param in list(data.items()):
|
||||
if param.get("type") == hyperparams_legacy_type:
|
||||
removed += 1
|
||||
del data[key]
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
|
||||
return [
|
||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||
]
|
||||
|
||||
|
||||
def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
"""
|
||||
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
previous_data = previous_task.to_proper_dict().get(new_params_field)
|
||||
removed = _remove_legacy_params(
|
||||
previous_data, with_sections=default_section is not None
|
||||
)
|
||||
if not legacy_params and not removed:
|
||||
# if we only need to delete legacy fields from the db
|
||||
# but they are not there then there is no point to proceed
|
||||
continue
|
||||
|
||||
fields_update = {new_params_field: previous_data}
|
||||
params_unprepare_from_saved(fields_update)
|
||||
fields.update(fields_update)
|
||||
|
||||
for full_name, value in legacy_params.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_params_field, section, name)))
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
"""
|
||||
Unescape all section and param names for hyper params and configuration
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
)
|
||||
|
||||
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
for prefix in (
|
||||
"hyperparams.",
|
||||
"-hyperparams.",
|
||||
"configuration.",
|
||||
"-configuration.",
|
||||
):
|
||||
paths = [
|
||||
_process_path(path) if path.startswith(prefix) else path for path in paths
|
||||
]
|
||||
return paths
|
||||
@@ -1,39 +1,66 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from random import random
|
||||
from time import sleep
|
||||
from typing import Mapping, Collection
|
||||
from urllib.parse import urlparse
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
|
||||
import dpath
|
||||
import pymongo.results
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from six import string_types
|
||||
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.fields import OutputDestinationField
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.metrics import MetricEvent
|
||||
from database.model.task.metrics import EventStats, MetricEventStats
|
||||
from database.model.task.output import Output
|
||||
from database.model.task.task import Task, TaskStatus, TaskStatusMessage, TaskTags
|
||||
from database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
external_task_types,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from timing_context import TimingContext
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from utilities.dicts import deep_merge
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
class TaskBLL(object):
|
||||
threads = ThreadsManager()
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
@@ -58,25 +85,24 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
required_dataset=None,
|
||||
only_fields=None,
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
only_fields = [only_fields]
|
||||
else:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
qs = Task.objects(id=task_id, company=company_id)
|
||||
if only_fields:
|
||||
qs = (
|
||||
qs.only(only_fields)
|
||||
if isinstance(only_fields, string_types)
|
||||
else qs.only(*only_fields)
|
||||
)
|
||||
qs = qs.only(
|
||||
"status", "input"
|
||||
) # make sure all fields we rely on here are also returned
|
||||
task = qs.first()
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -84,17 +110,12 @@ class TaskBLL(object):
|
||||
if required_status and not task.status == required_status:
|
||||
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
|
||||
|
||||
if required_dataset and required_dataset not in (
|
||||
entry.dataset for entry in task.input.view.entries
|
||||
):
|
||||
raise errors.bad_request.InvalidId(
|
||||
"not in input view", dataset=required_dataset
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(company_id, task_ids, only=None, allow_public=False):
|
||||
def assert_exists(
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
ids = set(task_ids)
|
||||
@@ -105,14 +126,13 @@ class TaskBLL(object):
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
res = q.only(*only)
|
||||
count = len(res)
|
||||
else:
|
||||
count = q.count()
|
||||
res = q.first()
|
||||
if count != len(ids):
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
return res
|
||||
|
||||
if return_tasks:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
@@ -143,42 +163,122 @@ class TaskBLL(object):
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task, force=False):
|
||||
assert isinstance(task, Task)
|
||||
def clone_task(
|
||||
cls,
|
||||
company_id,
|
||||
user_id,
|
||||
task_id,
|
||||
name: Optional[str] = None,
|
||||
comment: Optional[str] = None,
|
||||
parent: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
if execution_overrides:
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
if task.parent and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
a for a in artifacts if a.get("mode") != ArtifactModes.output
|
||||
]
|
||||
now = datetime.utcnow()
|
||||
|
||||
with translate_errors_context():
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or [],
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project:
|
||||
Project.get_for_writing(company=task.company, id=task.project)
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
model = cls.validate_execution_model(task)
|
||||
if model and not force and not model.ready:
|
||||
raise errors.bad_request.ModelNotReady(
|
||||
"can't be used in a task", model=model.id
|
||||
)
|
||||
|
||||
if task.execution:
|
||||
if task.execution.parameters:
|
||||
cls._validate_execution_parameters(task.execution.parameters)
|
||||
|
||||
if task.output and task.output.destination:
|
||||
parsed_url = urlparse(task.output.destination)
|
||||
if parsed_url.scheme not in OutputDestinationField.schemes:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"unsupported scheme for output destination",
|
||||
dest=task.output.destination,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_execution_parameters(parameters):
|
||||
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
|
||||
if invalid_keys:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"execution.parameters keys contain whitespace", keys=invalid_keys
|
||||
)
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
@@ -218,7 +318,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.objects.aggregate(*pipeline)
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
@@ -236,7 +336,8 @@ class TaskBLL(object):
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_metrics: Mapping[str, Mapping[str, MetricEvent]] = None,
|
||||
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
last_events: Dict[str, Dict[str, dict]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
"""
|
||||
@@ -248,7 +349,8 @@ class TaskBLL(object):
|
||||
task's last iteration value.
|
||||
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
||||
if the current task's last iteration value is smaller than the provided value.
|
||||
:param last_metrics: Last reported metrics summary.
|
||||
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
|
||||
:param last_events: Last reported metrics summary (value, metric, event type).
|
||||
:param extra_updates: Extra task updates to include in this update call.
|
||||
:return:
|
||||
"""
|
||||
@@ -259,10 +361,36 @@ class TaskBLL(object):
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_metrics is not None:
|
||||
extra_updates.update(last_metrics=last_metrics)
|
||||
if last_scalar_values is not None:
|
||||
|
||||
return Task.objects(id=task_id, company=company_id).update(
|
||||
def op_path(op, *path):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
return {
|
||||
event_type: EventStats(last_update=event["timestamp"])
|
||||
for event_type, event in metric_data.items()
|
||||
}
|
||||
|
||||
metric_stats = {
|
||||
dbutils.hash_field_name(metric_key): MetricEventStats(
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
|
||||
)
|
||||
for metric_key, metric_data in last_events.items()
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
Task.objects(id=task_id, company=company_id).update(
|
||||
upsert=False, last_update=last_update, **extra_updates
|
||||
)
|
||||
|
||||
@@ -375,14 +503,30 @@ class TaskBLL(object):
|
||||
:return: updated task fields
|
||||
"""
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = cls.get_task_with_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
only=("status", "project", "tags", "last_update"),
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
"tags",
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
if TaskTags.development in task.tags:
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
"""Checks if there is an active worker running the task"""
|
||||
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
|
||||
return (
|
||||
t.last_worker
|
||||
and t.last_update
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
@@ -398,53 +542,158 @@ class TaskBLL(object):
|
||||
).execute()
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start_non_responsive_tasks_watchdog(cls):
|
||||
log = config.logger("non_responsive_tasks_watchdog")
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
|
||||
)
|
||||
)
|
||||
while True:
|
||||
sleep(
|
||||
config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec",
|
||||
900,
|
||||
)
|
||||
)
|
||||
try:
|
||||
def add_or_update_artifacts(
|
||||
cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
key = attrgetter("key", "mode")
|
||||
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
if not artifacts:
|
||||
return [], []
|
||||
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
|
||||
artifacts: List[Artifact] = [
|
||||
Artifact(**artifact.to_struct()) for artifact in artifacts
|
||||
]
|
||||
|
||||
attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
|
||||
|
||||
for retry in range(attempts):
|
||||
task = cls.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(
|
||||
status__in=relevant_status, last_update__lt=ref_time
|
||||
).only("id", "name", "status", "project", "last_update")
|
||||
current = list(map(key, task.execution.artifacts))
|
||||
updated = [a for a in artifacts if key(a) in current]
|
||||
added = [a for a in artifacts if a not in updated]
|
||||
|
||||
filter = {"_id": task_id, "company": company_id}
|
||||
update = {}
|
||||
array_filters = None
|
||||
if current:
|
||||
filter["execution.artifacts"] = {
|
||||
"$size": len(current),
|
||||
"$all": [
|
||||
*(
|
||||
{"$elemMatch": {"key": key, "mode": mode}}
|
||||
for key, mode in current
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
filter["$or"] = [
|
||||
{"execution.artifacts": {"$exists": False}},
|
||||
{"execution.artifacts": {"$size": 0}},
|
||||
]
|
||||
|
||||
if added:
|
||||
update["$push"] = {
|
||||
"execution.artifacts": {"$each": [a.to_mongo() for a in added]}
|
||||
}
|
||||
if updated:
|
||||
update["$set"] = {
|
||||
f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
|
||||
for index, artifact in enumerate(updated)
|
||||
}
|
||||
array_filters = [
|
||||
{
|
||||
f"artifact{index}.key": artifact.key,
|
||||
f"artifact{index}.mode": artifact.mode,
|
||||
}
|
||||
for index, artifact in enumerate(updated)
|
||||
]
|
||||
|
||||
if not update:
|
||||
return [], []
|
||||
|
||||
result: pymongo.results.UpdateResult = Task._get_collection().update_one(
|
||||
filter=filter,
|
||||
update=update,
|
||||
array_filters=array_filters,
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
if tasks:
|
||||
if result.matched_count >= 1:
|
||||
break
|
||||
|
||||
log.info(f"Stopping {len(tasks)} non-responsive tasks")
|
||||
wait_msec = random() * int(
|
||||
config.get("services.tasks.artifacts.update_retry_msec", 500)
|
||||
)
|
||||
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
log.warning(
|
||||
f"Failed to update artifacts for task {task_id} (updated by another party),"
|
||||
f" retrying {retry+1}/{attempts} in {wait_msec}ms"
|
||||
)
|
||||
|
||||
log.info(f"Done")
|
||||
sleep(wait_msec / 1000)
|
||||
else:
|
||||
raise errors.server_error.UpdateFailed(
|
||||
"task artifacts updated by another party"
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"total": 1,
|
||||
"results": {"$slice": ["$results", page * page_size, page_size]},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
@@ -7,7 +7,7 @@ import six
|
||||
from apierrors import errors
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from database.utils import get_options
|
||||
from timing_context import TimingContext
|
||||
from utilities.attrs import typed_attrs
|
||||
@@ -25,9 +25,10 @@ class ChangeStatusRequest(object):
|
||||
status_message = attr.ib(type=six.string_types, default="")
|
||||
force = attr.ib(type=bool, default=False)
|
||||
allow_same_state_transition = attr.ib(type=bool, default=True)
|
||||
current_status_override = attr.ib(default=None)
|
||||
|
||||
def execute(self, **kwargs):
|
||||
current_status = self.task.status
|
||||
current_status = self.current_status_override or self.task.status
|
||||
project_id = self.task.project
|
||||
|
||||
# Verify new status is allowed from current status (will throw exception if not valid)
|
||||
@@ -44,6 +45,9 @@ class ChangeStatusRequest(object):
|
||||
last_update=now,
|
||||
)
|
||||
|
||||
if self.new_status == TaskStatus.queued:
|
||||
fields["pull__system_tags"] = TaskSystemTags.development
|
||||
|
||||
def safe_mongoengine_key(key):
|
||||
return f"__{key}" if key in control else key
|
||||
|
||||
@@ -66,6 +70,10 @@ class ChangeStatusRequest(object):
|
||||
)
|
||||
|
||||
update_project_time(project_id)
|
||||
|
||||
# make sure that _raw_ queries are not returned back to the client
|
||||
fields.pop("__raw__", None)
|
||||
|
||||
return dict(updated=updated, fields=fields)
|
||||
|
||||
def validate_transition(self, current_status):
|
||||
@@ -95,7 +103,8 @@ def validate_status_change(current_status, new_status):
|
||||
|
||||
|
||||
state_machine = {
|
||||
TaskStatus.created: {TaskStatus.in_progress},
|
||||
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
|
||||
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
|
||||
TaskStatus.in_progress: {
|
||||
TaskStatus.stopped,
|
||||
TaskStatus.failed,
|
||||
@@ -125,7 +134,7 @@ state_machine = {
|
||||
TaskStatus.published,
|
||||
TaskStatus.in_progress,
|
||||
TaskStatus.created,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -135,9 +144,11 @@ def get_possible_status_changes(current_status):
|
||||
:return possible states from current state
|
||||
"""
|
||||
possible = state_machine.get(current_status)
|
||||
assert (
|
||||
possible is not None
|
||||
), f"Current status {current_status} not supported by state machine"
|
||||
if possible is None:
|
||||
raise errors.server_error.InternalError(
|
||||
f"Current status {current_status} not supported by state machine"
|
||||
)
|
||||
|
||||
return possible
|
||||
|
||||
|
||||
|
||||
80
server/bll/util.py
Normal file
80
server/bll/util.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import functools
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set
|
||||
|
||||
from database.model import AttributedDocument
|
||||
from database.model.settings import Settings
|
||||
|
||||
|
||||
def extract_properties_to_lists(
|
||||
key_names: Sequence[str],
|
||||
data: Sequence[dict],
|
||||
extract_func: Optional[Callable[[dict], Tuple]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Given a list of dictionaries and names of dictionary keys
|
||||
builds a dictionary with the requested keys and values lists
|
||||
:param key_names: names of the keys in the resulting dictionary
|
||||
:param data: sequence of dictionaries to extract values from
|
||||
:param extract_func: the optional callable that extracts properties
|
||||
from a dictionary and put them in a tuple in the order corresponding to
|
||||
key_names. If not specified then properties are extracted according to key_names
|
||||
"""
|
||||
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
|
||||
return dict(zip(key_names, map(list, value_sequences)))
|
||||
|
||||
|
||||
class SetFieldsResolver:
|
||||
"""
|
||||
The class receives set fields dictionary
|
||||
and for the set fields that require 'min' or 'max'
|
||||
operation replace them with a simple set in case the
|
||||
DB document does not have these fields set
|
||||
"""
|
||||
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
return self.fields[name]
|
||||
return name
|
||||
|
||||
def get_fields(self, doc: AttributedDocument):
|
||||
"""
|
||||
For the given document return the set fields instructions
|
||||
with min/max operations replaced with a single set in case
|
||||
the document does not have the field set
|
||||
"""
|
||||
return {
|
||||
self._get_updated_name(doc, name): value
|
||||
for name, value in self.orig_fields.items()
|
||||
}
|
||||
|
||||
def get_names(self) -> Set[str]:
|
||||
"""
|
||||
Returns the names of the fields that had min/max modifiers
|
||||
in the format suitable for projection (dot separated)
|
||||
"""
|
||||
return set(name.replace("__", ".") for name in self.fields.values())
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_server_uuid() -> Optional[str]:
|
||||
return Settings.get_by_key("server.uuid")
|
||||
438
server/bll/workers/__init__.py
Normal file
438
server/bll/workers/__init__.py
Normal file
@@ -0,0 +1,438 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
|
||||
import es_factory
|
||||
from apierrors import APIError
|
||||
from apierrors.errors import bad_request, server_error
|
||||
from apimodels.workers import (
|
||||
DEFAULT_TIMEOUT,
|
||||
IdNameEntry,
|
||||
WorkerEntry,
|
||||
StatusReportRequest,
|
||||
WorkerResponseEntry,
|
||||
QueueEntry,
|
||||
MachineStats,
|
||||
)
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@property
|
||||
def stats(self) -> WorkerStats:
|
||||
return self._stats
|
||||
|
||||
def register_worker(
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
worker: str,
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user ID under which this worker is running
|
||||
:param worker: worker ID
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
timeout = timeout or DEFAULT_TIMEOUT
|
||||
queues = queues or []
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=user_id, company=company_id)
|
||||
user = User.objects(**query).only("id", "name").first()
|
||||
if not user:
|
||||
raise bad_request.InvalidUserId(**query)
|
||||
company = Company.objects(id=company_id).only("id", "name").first()
|
||||
if not company:
|
||||
raise server_error.InternalError("invalid company", company=company_id)
|
||||
|
||||
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
|
||||
if len(queue_objs) < len(queues):
|
||||
invalid = set(queues).difference(q.id for q in queue_objs)
|
||||
raise bad_request.InvalidQueueId(ids=invalid)
|
||||
|
||||
now = datetime.utcnow()
|
||||
entry = WorkerEntry(
|
||||
key=key,
|
||||
id=worker,
|
||||
user=user.to_proper_dict(),
|
||||
company=company.to_proper_dict(),
|
||||
ip=ip,
|
||||
queues=queues,
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
|
||||
return entry
|
||||
|
||||
def unregister_worker(self, company_id: str, user_id: str, worker: str) -> None:
|
||||
"""
|
||||
Unregister a worker
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user ID under which this worker is running
|
||||
:param worker: worker ID
|
||||
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
|
||||
"""
|
||||
with TimingContext("redis", "workers_unregister"):
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res:
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
entry = self._get_worker(company_id, user_id, report.worker)
|
||||
|
||||
try:
|
||||
entry.ip = ip
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
company_name=entry.company.name,
|
||||
worker=report.worker,
|
||||
timestamp=report.timestamp,
|
||||
task=report.task,
|
||||
machine_stats=report.machine_stats,
|
||||
)
|
||||
|
||||
entry.queue = report.queue
|
||||
|
||||
if report.queues:
|
||||
entry.queues = report.queues
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
update = dict(
|
||||
last_worker=report.worker,
|
||||
last_worker_report=now,
|
||||
last_update=now,
|
||||
)
|
||||
# modify(new=True, ...) returns the modified object
|
||||
task = Task.objects(**query).modify(new=True, **update)
|
||||
if not task:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
except Exception as e:
|
||||
msg = "Failed processing worker status report"
|
||||
log.exception(msg)
|
||||
raise server_error.DataError(msg, err=e.args[0])
|
||||
finally:
|
||||
self._save_worker(entry)
|
||||
|
||||
def get_all(
|
||||
self, company_id: str, last_seen: Optional[int] = None
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""
|
||||
Get all the company workers that were active during the last_seen period
|
||||
:param company_id: worker's company id
|
||||
:param last_seen: period in seconds to check. Min value is 1 second
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
workers = self._get(company_id)
|
||||
except Exception as e:
|
||||
raise server_error.DataError("failed loading worker entries", err=e.args[0])
|
||||
|
||||
if last_seen:
|
||||
ref_time = datetime.utcnow() - timedelta(seconds=max(1, last_seen))
|
||||
workers = [
|
||||
w
|
||||
for w in workers
|
||||
if w.last_activity_time.replace(tzinfo=None) >= ref_time
|
||||
]
|
||||
|
||||
return workers
|
||||
|
||||
def get_all_with_projection(
|
||||
self, company_id: str, last_seen: int
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(company_id=company_id, last_seen=last_seen),
|
||||
)
|
||||
)
|
||||
|
||||
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
|
||||
all_queues = set(
|
||||
itertools.chain.from_iterable(helper.queue_ids for helper in helpers)
|
||||
)
|
||||
|
||||
queues_info = {}
|
||||
if all_queues:
|
||||
projection = [
|
||||
{"$match": {"_id": {"$in": list(all_queues)}}},
|
||||
{
|
||||
"$project": {
|
||||
"name": 1,
|
||||
"next_entry": {"$arrayElemAt": ["$entries", 0]},
|
||||
"num_entries": {"$size": "$entries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
None,
|
||||
(
|
||||
safe_get(info, "next_entry/task")
|
||||
for info in queues_info.values()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
tasks_info = {}
|
||||
if task_ids:
|
||||
tasks_info = {
|
||||
task.id: task
|
||||
for task in Task.objects(id__in=task_ids).only(
|
||||
"name", "started", "last_iteration"
|
||||
)
|
||||
}
|
||||
|
||||
def update_queue_entries(*entries):
|
||||
for entry in entries:
|
||||
if not entry:
|
||||
continue
|
||||
info = queues_info.get(entry.id, None)
|
||||
if not info:
|
||||
continue
|
||||
entry.name = info.get("name", None)
|
||||
entry.num_tasks = info.get("num_entries", 0)
|
||||
task_id = safe_get(info, "next_entry/task")
|
||||
if task_id:
|
||||
task = tasks_info.get(task_id, None)
|
||||
entry.next_task = IdNameEntry(
|
||||
id=task_id, name=task.name if task else None
|
||||
)
|
||||
|
||||
for helper in helpers:
|
||||
worker = helper.worker
|
||||
if helper.task_id:
|
||||
task = tasks_info.get(helper.task_id, None)
|
||||
if task:
|
||||
worker.task.running_time = (
|
||||
int((datetime.utcnow() - task.started).total_seconds() * 1000)
|
||||
if task.started
|
||||
else 0
|
||||
)
|
||||
worker.task.last_iteration = task.last_iteration
|
||||
|
||||
update_queue_entries(worker.queue)
|
||||
if worker.queues:
|
||||
update_queue_entries(*worker.queues)
|
||||
|
||||
return [helper.worker for helper in helpers]
|
||||
|
||||
@staticmethod
|
||||
def _get_worker_key(company: str, user: str, worker_id: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"worker_{company}_{user}_{worker_id}"
|
||||
|
||||
def _get_worker(self, company_id: str, user_id: str, worker: str) -> WorkerEntry:
|
||||
"""
|
||||
Get a worker entry for the provided worker ID. The entry is loaded from Redis
|
||||
if it exists (i.e. worker has already been registered), otherwise the worker
|
||||
is registered and its entry stored into Redis).
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user ID under which this worker is running
|
||||
:param worker: worker ID
|
||||
:raise bad_request.InvalidWorkerId: in case the worker id was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
key = self._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
with TimingContext("redis", "get_worker"):
|
||||
data = self.redis.get(key)
|
||||
|
||||
if data:
|
||||
try:
|
||||
entry = WorkerEntry.from_json(data)
|
||||
if not entry.key:
|
||||
entry.key = key
|
||||
self._save_worker(entry)
|
||||
return entry
|
||||
except Exception as e:
|
||||
msg = "Failed parsing worker entry"
|
||||
log.exception(msg)
|
||||
raise server_error.DataError(msg, err=e.args[0])
|
||||
|
||||
# Failed loading worker from Redis
|
||||
if config.get("apiserver.workers.auto_register", False):
|
||||
try:
|
||||
return self.register_worker(company_id, user_id, worker)
|
||||
except Exception:
|
||||
log.error(
|
||||
"Failed auto registration of {} for company {}".format(
|
||||
worker, company_id
|
||||
)
|
||||
)
|
||||
|
||||
raise bad_request.InvalidWorkerId(worker=worker)
|
||||
|
||||
def _save_worker(self, entry: WorkerEntry) -> None:
|
||||
"""Save worker entry in Redis"""
|
||||
try:
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
except Exception:
|
||||
msg = "Failed saving worker entry"
|
||||
log.exception(msg)
|
||||
|
||||
def _get(
|
||||
self, company: str, user: str = "*", worker_id: str = "*"
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""Get worker entries matching the company and user, worker patterns"""
|
||||
match = self._get_worker_key(company, user, worker_id)
|
||||
with TimingContext("redis", "workers_get_all"):
|
||||
res = self.redis.scan_iter(match)
|
||||
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
"""Get the index name suffix for storing current month data"""
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
def _log_stats_to_es(
|
||||
self,
|
||||
company_id: str,
|
||||
company_name: str,
|
||||
worker: str,
|
||||
timestamp: int,
|
||||
task: str,
|
||||
machine_stats: MachineStats,
|
||||
) -> bool:
|
||||
"""
|
||||
Actually writing the worker statistics to Elastic
|
||||
:return: True if successful, False otherwise
|
||||
"""
|
||||
es_index = (
|
||||
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
|
||||
f"{self._get_es_index_suffix()}"
|
||||
)
|
||||
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
company=company_name,
|
||||
task=task,
|
||||
category=category,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
value=float(value),
|
||||
),
|
||||
)
|
||||
|
||||
actions = []
|
||||
for field, value in machine_stats.to_struct().items():
|
||||
if not value:
|
||||
continue
|
||||
category = field.partition("_")[0]
|
||||
metric = field
|
||||
if not isinstance(value, (list, tuple)):
|
||||
actions.append(make_doc(category, metric, "total", value))
|
||||
else:
|
||||
actions.extend(
|
||||
make_doc(category, metric, str(i), val)
|
||||
for i, val in enumerate(value)
|
||||
)
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class WorkerConversionHelper:
|
||||
worker: WorkerResponseEntry
|
||||
task_id: str
|
||||
queue_ids: Set[str]
|
||||
|
||||
@classmethod
|
||||
def from_worker_entry(cls, worker: WorkerEntry):
|
||||
data = worker.to_struct()
|
||||
queue = data.pop("queue", None) or None
|
||||
queue_ids = set(data.pop("queues", []))
|
||||
queues = [QueueEntry(id=id) for id in queue_ids]
|
||||
if queue:
|
||||
queue = next((q for q in queues if q.id == queue), None)
|
||||
return cls(
|
||||
worker=WorkerResponseEntry(queues=queues, queue=queue, **data),
|
||||
task_id=worker.task.id if worker.task else None,
|
||||
queue_ids=queue_ids,
|
||||
)
|
||||
243
server/bll/workers/stats.py
Normal file
243
server/bll/workers/stats.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from operator import attrgetter
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
|
||||
from apierrors.errors import bad_request
|
||||
from apimodels.workers import AggregationType, GetStatsRequest, StatItem
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class WorkerStats:
|
||||
def __init__(self, es):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def worker_stats_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"worker_stats_{company_id}_"
|
||||
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
def get_worker_stats_keys(
|
||||
self, company_id: str, worker_ids: Optional[Sequence[str]]
|
||||
) -> dict:
|
||||
"""
|
||||
Get dictionary of metric types grouped by categories
|
||||
:param company_id: company id
|
||||
:param worker_ids: optional list of workers to get metric types from.
|
||||
If not specified them metrics for all the company workers returned
|
||||
:return:
|
||||
"""
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"categories": {
|
||||
"terms": {"field": "category"},
|
||||
"aggs": {"metrics": {"terms": {"field": "metric"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
if worker_ids:
|
||||
es_req["query"] = QueryBuilder.terms("worker", worker_ids)
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
|
||||
return {
|
||||
category["key"]: [
|
||||
metric["key"] for metric in category["metrics"]["buckets"]
|
||||
]
|
||||
for category in res["aggregations"]["categories"]["buckets"]
|
||||
}
|
||||
|
||||
def get_worker_stats(self, company_id: str, request: GetStatsRequest) -> dict:
|
||||
"""
|
||||
Get statistics for company workers metrics in the specified time range
|
||||
Returned as date histograms for different aggregation types
|
||||
grouped by worker, metric type (and optionally metric variant)
|
||||
Buckets with no metrics are not returned
|
||||
Note: all the statistics are retrieved as one ES query
|
||||
"""
|
||||
if request.from_date >= request.to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
||||
def get_dates_agg() -> dict:
|
||||
es_to_agg_types = (
|
||||
("avg", AggregationType.avg.value),
|
||||
("min", AggregationType.min.value),
|
||||
("max", AggregationType.max.value),
|
||||
)
|
||||
|
||||
return {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
agg_type: {es_agg: {"field": "value"}}
|
||||
for es_agg, agg_type in es_to_agg_types
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def get_variants_agg() -> dict:
|
||||
return {
|
||||
"variants": {"terms": {"field": "variant"}, "aggs": get_dates_agg()}
|
||||
}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": get_variants_agg()
|
||||
if request.split_by_variant
|
||||
else get_dates_agg(),
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
query_terms = [
|
||||
QueryBuilder.dates_range(request.from_date, request.to_date),
|
||||
QueryBuilder.terms("metric", {item.key for item in request.items}),
|
||||
]
|
||||
if request.worker_ids:
|
||||
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
|
||||
es_req["query"] = {"bool": {"must": query_terms}}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
return self._extract_results(data, request.items, request.split_by_variant)
|
||||
|
||||
@staticmethod
|
||||
def _extract_results(
|
||||
data: dict, request_items: Sequence[StatItem], split_by_variant: bool
|
||||
) -> dict:
|
||||
"""
|
||||
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
|
||||
leave only aggregation types requested by the user and return a clean dictionary
|
||||
and return a "clean" dictionary of
|
||||
:param data: aggregation data retrieved from ES
|
||||
:param request_items: aggs types requested by the user
|
||||
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
|
||||
"""
|
||||
if "aggregations" not in data:
|
||||
return {}
|
||||
|
||||
items_by_key = bucketize(request_items, key=attrgetter("key"))
|
||||
aggs_per_metric = {
|
||||
key: [item.aggregation for item in items]
|
||||
for key, items in items_by_key.items()
|
||||
}
|
||||
|
||||
def extract_date_stats(date: dict, metric_key) -> dict:
|
||||
return {
|
||||
"date": date["key"],
|
||||
"count": date["doc_count"],
|
||||
**{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
|
||||
}
|
||||
|
||||
def extract_metric_results(
|
||||
metric_or_variant: dict, metric_key: str
|
||||
) -> Sequence[dict]:
|
||||
return [
|
||||
extract_date_stats(date, metric_key)
|
||||
for date in metric_or_variant["dates"]["buckets"]
|
||||
if date["doc_count"]
|
||||
]
|
||||
|
||||
def extract_variant_results(metric: dict) -> dict:
|
||||
metric_key = metric["key"]
|
||||
return {
|
||||
variant["key"]: extract_metric_results(variant, metric_key)
|
||||
for variant in metric["variants"]["buckets"]
|
||||
}
|
||||
|
||||
def extract_worker_results(worker: dict) -> dict:
|
||||
return {
|
||||
metric["key"]: extract_variant_results(metric)
|
||||
if split_by_variant
|
||||
else extract_metric_results(metric, metric["key"])
|
||||
for metric in worker["metrics"]["buckets"]
|
||||
}
|
||||
|
||||
return {
|
||||
worker["key"]: extract_worker_results(worker)
|
||||
for worker in data["aggregations"]["workers"]["buckets"]
|
||||
}
|
||||
|
||||
def get_activity_report(
|
||||
self,
|
||||
company_id: str,
|
||||
from_date: float,
|
||||
to_date: float,
|
||||
interval: int,
|
||||
active_only: bool,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get statistics for company workers metrics in the specified time range
|
||||
Returned as date histograms for different aggregation types
|
||||
grouped by worker, metric type (and optionally metric variant)
|
||||
Note: all the statistics are retrieved using one ES query
|
||||
"""
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
||||
must = [QueryBuilder.dates_range(from_date, to_date)]
|
||||
if active_only:
|
||||
must.append({"exists": {"field": "task"}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_worker_activity_report"
|
||||
):
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if "aggregations" not in data:
|
||||
return {}
|
||||
|
||||
ret = [
|
||||
dict(date=date["key"], count=date["workers_count"]["value"])
|
||||
for date in data["aggregations"]["dates"]["buckets"]
|
||||
]
|
||||
|
||||
if ret and ret[-1]["date"] > (to_date - 0.9 * interval):
|
||||
# remove last interval if it's incomplete. Allow 10% tolerance
|
||||
ret.pop()
|
||||
|
||||
return ret
|
||||
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -14,7 +16,10 @@ from pyparsing import (
|
||||
|
||||
DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';'
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
@@ -43,9 +48,23 @@ class BasicConfig:
|
||||
def logger(self, name):
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
path = ".".join((self.prefix, Path(name).stem))
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self):
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _read_env_paths(self, key):
|
||||
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
|
||||
if value is None:
|
||||
@@ -59,17 +78,22 @@ class BasicConfig:
|
||||
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
|
||||
]
|
||||
if invalid:
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
|
||||
return [path for path in paths if path.is_dir()]
|
||||
|
||||
def _load(self, verbose=True):
|
||||
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
configs = [
|
||||
self._read_recursive(path, verbose=verbose)
|
||||
for path in [self.folder] + extra_config_paths
|
||||
]
|
||||
|
||||
self._config = reduce(
|
||||
lambda config, path: ConfigTree.merge_configs(
|
||||
config, self._read_recursive(path, verbose=verbose), copy_trees=True
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
[self.folder] + extra_config_paths,
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
|
||||
@@ -21,12 +21,41 @@
|
||||
version {
|
||||
required: false
|
||||
default: 1.0
|
||||
# if set then calls to endpoints with the version
|
||||
# greater that the current max version will be rejected
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
# not including the pre_populate
|
||||
db_init_timout: 120
|
||||
|
||||
mongo {
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
strict: false
|
||||
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
elastic {
|
||||
probing {
|
||||
# settings for inital probing of elastic connection
|
||||
max_retries: 4
|
||||
timeout: 30
|
||||
}
|
||||
upgrade_monitoring {
|
||||
v16_migration_verification: true
|
||||
}
|
||||
}
|
||||
|
||||
auth {
|
||||
@@ -72,4 +101,40 @@
|
||||
}
|
||||
|
||||
default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"
|
||||
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
}
|
||||
|
||||
check_for_updates {
|
||||
enabled: true
|
||||
|
||||
# Check for updates every 24 hours
|
||||
check_interval_sec: 86400
|
||||
|
||||
url: "https://updates.trains.allegro.ai/updates"
|
||||
|
||||
component_name: "trains-server"
|
||||
|
||||
# GET request timeout
|
||||
request_timeout_sec: 3.0
|
||||
}
|
||||
|
||||
statistics {
|
||||
# Note: statistics are sent ONLY if the user has actively opted-in
|
||||
supported: true
|
||||
|
||||
url: "https://updates.trains.allegro.ai/stats"
|
||||
|
||||
report_interval_hours: 24
|
||||
agent_relevant_threshold_days: 30
|
||||
|
||||
max_retries: 5
|
||||
max_backoff_sec: 5
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
hosts: [{host: "127.0.0.1", port: 9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
@@ -19,3 +30,16 @@ mongo {
|
||||
host: "mongodb://127.0.0.1:27017/auth"
|
||||
}
|
||||
}
|
||||
|
||||
redis {
|
||||
apiserver {
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
db: 0
|
||||
}
|
||||
workers {
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
db: 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,17 +13,21 @@
|
||||
credentials {
|
||||
# system credentials as they appear in the auth DB, used for intra-service communications
|
||||
apiserver {
|
||||
role: "system"
|
||||
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
|
||||
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
|
||||
}
|
||||
webserver {
|
||||
role: "system"
|
||||
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
|
||||
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
tests {
|
||||
role: "user"
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
16
server/config/default/services/auth.conf
Normal file
16
server/config/default/services/auth.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,13 @@
|
||||
{
|
||||
es_index_prefix:"events"
|
||||
}
|
||||
es_index_prefix: "events"
|
||||
|
||||
ignore_iteration {
|
||||
metrics: [":monitor:machine", ":monitor:gpu"]
|
||||
}
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
events_retrieval {
|
||||
state_expiration_sec: 3600
|
||||
}
|
||||
|
||||
3
server/config/default/services/organization.conf
Normal file
3
server/config/default/services/organization.conf
Normal file
@@ -0,0 +1,3 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
8
server/config/default/services/projects.conf
Normal file
8
server/config/default/services/projects.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured_order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
||||
@@ -1,7 +1,14 @@
|
||||
non_responsive_tasks_watchdog {
|
||||
enabled: true
|
||||
|
||||
# In-progress tasks older than this value in seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
# Watchdog will sleep for this number of seconds after each cycle
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
|
||||
artifacts {
|
||||
update_attempts: 10
|
||||
update_retry_msec: 500
|
||||
}
|
||||
@@ -1,20 +1,47 @@
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
from version import __version__
|
||||
|
||||
from config import config
|
||||
|
||||
root = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def _get(prop_name, env_suffix=None, default=""):
|
||||
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
|
||||
if value:
|
||||
return value
|
||||
|
||||
try:
|
||||
return (root / prop_name).read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return default
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_build_number():
|
||||
try:
|
||||
return (root / "BUILD").read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
return _get("BUILD")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_version():
|
||||
try:
|
||||
return (root / "VERSION").read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
return _get("VERSION", default=__version__)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_commit_number():
|
||||
return _get("COMMIT")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_deployment_type() -> str:
|
||||
return _get("DEPLOY", env_suffix="DEPLOYMENT_TYPE", default="manual")
|
||||
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
es_connection_error = False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from os import getenv
|
||||
|
||||
from boltons.iterutils import first
|
||||
from furl import furl
|
||||
from jsonmodels import models
|
||||
from jsonmodels.errors import ValidationError
|
||||
@@ -15,7 +16,12 @@ log = config.logger("database")
|
||||
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = "MONGODB_SERVICE_SERVICE_HOST"
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"TRAINS_MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
|
||||
|
||||
_entries = []
|
||||
|
||||
@@ -34,19 +40,27 @@ def initialize():
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_hostname = getenv(OVERRIDE_HOST_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
@@ -65,6 +79,10 @@ def get_entries():
|
||||
return _entries
|
||||
|
||||
|
||||
def get_hosts():
|
||||
return [entry.host for entry in get_entries()]
|
||||
|
||||
|
||||
def get_aliases():
|
||||
return [entry.alias for entry in get_entries()]
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from operator import itemgetter
|
||||
from sys import maxsize
|
||||
from typing import Type, Tuple
|
||||
|
||||
import six
|
||||
from mongoengine import (
|
||||
@@ -11,7 +12,11 @@ from mongoengine import (
|
||||
SortedListField,
|
||||
MapField,
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class LengthRangeListField(ListField):
|
||||
@@ -88,104 +93,6 @@ class CustomFloatField(FloatField):
|
||||
self.error("Float value must be greater than %s" % str(self.greater_than))
|
||||
|
||||
|
||||
# TODO: bucket name should be at most 63 characters....
|
||||
aws_s3_bucket_only_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
|
||||
)
|
||||
|
||||
aws_s3_url_with_bucket_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))" # domain...
|
||||
)
|
||||
|
||||
non_aws_s3_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
|
||||
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))" # bucket name
|
||||
)
|
||||
|
||||
google_gs_bucket_only_regex = (
|
||||
r"^gs://"
|
||||
r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)" # bucket name
|
||||
)
|
||||
|
||||
file_regex = r"^file://"
|
||||
|
||||
generic_url_regex = (
|
||||
r"^%s://" # scheme placeholder
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
|
||||
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
|
||||
r"(?::\d+)?" # optional port
|
||||
)
|
||||
|
||||
path_suffix = r"(?:/?|[/?]\S+)$"
|
||||
file_path_suffix = r"(?:/\S*[^/]+)$"
|
||||
|
||||
|
||||
class _RegexURLField(StringField):
|
||||
_regex = []
|
||||
|
||||
def __init__(self, regex, **kwargs):
|
||||
super(_RegexURLField, self).__init__(**kwargs)
|
||||
regex = regex if isinstance(regex, (tuple, list)) else [regex]
|
||||
self._regex = [
|
||||
re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
|
||||
for e in regex
|
||||
]
|
||||
|
||||
def validate(self, value):
|
||||
# Check first if the scheme is valid
|
||||
if not any(regex for regex in self._regex if regex.match(value)):
|
||||
self.error("Invalid URL: {}".format(value))
|
||||
return
|
||||
|
||||
|
||||
class OutputDestinationField(_RegexURLField):
|
||||
""" A field representing task output URL """
|
||||
|
||||
schemes = ["s3", "gs", "file"]
|
||||
_expressions = (
|
||||
aws_s3_bucket_only_regex + path_suffix,
|
||||
aws_s3_url_with_bucket_regex + path_suffix,
|
||||
non_aws_s3_regex + path_suffix,
|
||||
google_gs_bucket_only_regex + path_suffix,
|
||||
file_regex + path_suffix,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(OutputDestinationField, self).__init__(self._expressions, **kwargs)
|
||||
|
||||
|
||||
class SupportedURLField(_RegexURLField):
|
||||
""" A field representing a model URL """
|
||||
|
||||
schemes = ["s3", "gs", "file", "http", "https"]
|
||||
|
||||
_expressions = tuple(
|
||||
pattern + file_path_suffix
|
||||
for pattern in (
|
||||
aws_s3_bucket_only_regex,
|
||||
aws_s3_url_with_bucket_regex,
|
||||
non_aws_s3_regex,
|
||||
google_gs_bucket_only_regex,
|
||||
file_regex,
|
||||
(generic_url_regex % "http"),
|
||||
(generic_url_regex % "https"),
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(SupportedURLField, self).__init__(self._expressions, **kwargs)
|
||||
|
||||
|
||||
class StrippedStringField(StringField):
|
||||
def __init__(
|
||||
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
|
||||
@@ -221,17 +128,82 @@ def contains_empty_key(d):
|
||||
return True
|
||||
|
||||
|
||||
class SafeMapField(MapField):
|
||||
class DictValidationMixin:
|
||||
"""
|
||||
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||
"""
|
||||
|
||||
def _safe_validate(self: DictField, value):
|
||||
if not isinstance(value, dict):
|
||||
self.error("Only dictionaries may be used in a DictField")
|
||||
|
||||
if key_not_string(value):
|
||||
msg = "Invalid dictionary key - documents must have only string keys"
|
||||
self.error(msg)
|
||||
|
||||
if key_starts_with_dollar(value):
|
||||
self.error(
|
||||
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||
)
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
|
||||
class SafeMapField(MapField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeMapField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class SafeDictField(DictField):
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeDictField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
|
||||
|
||||
class SafeSortedListField(SortedListField):
|
||||
"""
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
except TypeError:
|
||||
return self._safe_to_mongo(*args, **kwargs)
|
||||
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
|
||||
def key(v):
|
||||
return str(itemgetter(self._ordering)(v))
|
||||
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
|
||||
class UnionField(DynamicField):
|
||||
def __init__(self, types, *args, **kwargs):
|
||||
super(UnionField, self).__init__(*args, **kwargs)
|
||||
self.types: Tuple[Type] = tuple(types)
|
||||
|
||||
def validate(self, value, clean=True):
|
||||
if not isinstance(value, self.types):
|
||||
type_names = [t.__name__ for t in self.types]
|
||||
expected = " or ".join(
|
||||
filter(
|
||||
None,
|
||||
(", ".join(type_names[:-1]), type_names[-1]))
|
||||
)
|
||||
self.error(
|
||||
f"Expected {expected}, got {type(value).__name__}: {value}"
|
||||
)
|
||||
super(UnionField, self).validate(value, clean)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import Document, StringField
|
||||
|
||||
from apierrors import errors
|
||||
@@ -54,3 +56,7 @@ def validate_id(cls, company, **kwargs):
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
)
|
||||
|
||||
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
|
||||
@@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
@@ -43,15 +45,17 @@ class Role(object):
|
||||
|
||||
|
||||
class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
last_used = DateTimeField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
meta = {"db_alias": Database.auth, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StringField(unique_with="company")
|
||||
name = StringField()
|
||||
|
||||
created = DateTimeField()
|
||||
""" User auth entry creation time """
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document
|
||||
from six import string_types
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from apierrors.base import BaseError
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
from database.props import PropsMixin
|
||||
from database.query import RegexQ, RegexWrapper
|
||||
from database.utils import get_company_or_none_constraint, get_fields_with_attr
|
||||
from database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_choices,
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@@ -28,7 +35,12 @@ class AuthDocument(Document):
|
||||
|
||||
|
||||
class ProperDictMixin(object):
|
||||
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
|
||||
def to_proper_dict(
|
||||
self: Union["ProperDictMixin", Document],
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
@@ -54,8 +66,9 @@ class ProperDictMixin(object):
|
||||
|
||||
class GetMixin(PropsMixin):
|
||||
_text_score = "$text_score"
|
||||
|
||||
_projection_key = "projection"
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
@@ -64,11 +77,13 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
def __init__(
|
||||
self,
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "id"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
datetime_fields=None,
|
||||
fields=None,
|
||||
):
|
||||
@@ -84,11 +99,48 @@ class GetMixin(PropsMixin):
|
||||
self.list_fields = list_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_next = _default
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls, company, id, *, _only=None, include_public=False, **kwargs
|
||||
cls: Union["GetMixin", Document],
|
||||
company,
|
||||
id,
|
||||
*,
|
||||
_only=None,
|
||||
include_public=False,
|
||||
**kwargs,
|
||||
) -> "GetMixin":
|
||||
q = cls.objects(
|
||||
cls._prepare_perm_query(company, allow_public=include_public)
|
||||
@@ -155,17 +207,7 @@ class GetMixin(PropsMixin):
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
exclude = [t for t in data if t.startswith("-")]
|
||||
include = list(set(data).difference(exclude))
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
if include:
|
||||
dict_query[f"{mongoengine_field}__in"] = include
|
||||
if exclude:
|
||||
dict_query[f"{mongoengine_field}__nin"] = [
|
||||
t[1:] for t in exclude
|
||||
]
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -209,12 +251,71 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
is included.
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
if allow_public:
|
||||
return get_company_or_none_constraint(company)
|
||||
return Q(company=company)
|
||||
|
||||
@classmethod
|
||||
def validate_order_by(cls, parameters, search_text) -> Sequence:
|
||||
"""
|
||||
Validate and extract order_by params as a list
|
||||
"""
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if not order_by:
|
||||
return []
|
||||
|
||||
order_by = order_by if isinstance(order_by, list) else [order_by]
|
||||
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
|
||||
if not search_text and cls._text_score in order_by:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
def validate_paging(
|
||||
cls, parameters=None, default_page=None, default_page_size=None
|
||||
@@ -245,11 +346,40 @@ class GetMixin(PropsMixin):
|
||||
return override_projection
|
||||
if not parameters:
|
||||
return []
|
||||
return parameters.get("projection") or parameters.get("only_fields", [])
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def set_default_ordering(cls, parameters, value):
|
||||
parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
parameters[cls._projection_key] = value
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]:
|
||||
return parameters.get(cls._ordering_key)
|
||||
|
||||
@classmethod
|
||||
def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters[cls._ordering_key] = value
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
@@ -332,8 +462,9 @@ class GetMixin(PropsMixin):
|
||||
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
|
||||
be raised.
|
||||
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
|
||||
requested, each contains only the requested projection).
|
||||
If False, a QuerySet object is returned (lazy evaluated)
|
||||
requested, each contains only the requested projection). If False, a QuerySet object is returned
|
||||
(lazy evaluated). If return_dicts is requested then the entities with the None value in order_by field
|
||||
are returned last in the ordering.
|
||||
:param company: Company ID (required)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
|
||||
@@ -356,16 +487,38 @@ class GetMixin(PropsMixin):
|
||||
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
return_dicts=return_dicts,
|
||||
query=_query, parameters=parameters, override_projection=override_projection
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls, query, parameters=None, override_projection=None, return_dicts=True
|
||||
cls: Union["GetMixin", Document],
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
@@ -375,59 +528,138 @@ class GetMixin(PropsMixin):
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
||||
|
||||
:param query: Query object (mongoengine.Q)
|
||||
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
|
||||
requested, each contains only the requested projection).
|
||||
If False, a QuerySet object is returned (lazy evaluated)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
|
||||
if not query:
|
||||
raise ValueError("query or call_data must be provided")
|
||||
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if order_by:
|
||||
order_by = order_by if isinstance(order_by, list) else [order_by]
|
||||
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
|
||||
|
||||
search_text = parameters.get("search_text")
|
||||
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
|
||||
if not search_text and order_by and cls._text_score in order_by:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
qs = qs.search_text(search_text)
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = (
|
||||
qs.order_by(order_by)
|
||||
if isinstance(order_by, string_types)
|
||||
else qs.order_by(*order_by)
|
||||
)
|
||||
if only:
|
||||
qs = qs.order_by(*order_by)
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
|
||||
if return_dicts:
|
||||
return [obj.to_proper_dict(only=only) for obj in qs]
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
def _get_many_override_none_ordering(
|
||||
cls: Union[Document, "GetMixin"],
|
||||
query: Q = None,
|
||||
parameters: dict = None,
|
||||
override_projection: Collection[str] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
the None values are sorted in the end regardless of the sorting order.
|
||||
If the first order field is a user defined parameter (either from execution.parameters,
|
||||
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
||||
|
||||
:param query: Query object (mongoengine.Q)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
"""
|
||||
if not query:
|
||||
raise ValueError("query or call_data must be provided")
|
||||
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
order_field = first(
|
||||
field for field in order_by if not field.startswith("$")
|
||||
)
|
||||
if (
|
||||
order_field
|
||||
and not order_field.startswith("-")
|
||||
and "[" not in order_field
|
||||
):
|
||||
params = {}
|
||||
mongo_field = order_field.replace(".", "__")
|
||||
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
|
||||
params["is_list"] = True
|
||||
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
|
||||
params["empty_value"] = ""
|
||||
non_empty = query & field_exists(mongo_field, **params)
|
||||
empty = query & field_does_not_exist(mongo_field, **params)
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
v
|
||||
for k, v in cls._field_collation_overrides.items()
|
||||
if order_field.startswith(k)
|
||||
)
|
||||
if collation_override:
|
||||
query_sets = [
|
||||
qs.collation(collation=collation_override) for qs in query_sets
|
||||
]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def get_for_writing(
|
||||
cls, *args, _only: Collection[str] = None, **kwargs
|
||||
@@ -464,8 +696,8 @@ class UpdateMixin(object):
|
||||
def user_set_allowed(cls):
|
||||
res = getattr(cls, "__user_set_allowed_fields", None)
|
||||
if res is None:
|
||||
res = cls.__user_set_allowed_fields = dict(
|
||||
get_fields_with_attr(cls, "user_set_allowed")
|
||||
res = cls.__user_set_allowed_fields = get_fields_choices(
|
||||
cls, "user_set_allowed"
|
||||
)
|
||||
return res
|
||||
|
||||
@@ -488,7 +720,13 @@ class UpdateMixin(object):
|
||||
return update_dict
|
||||
|
||||
@classmethod
|
||||
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
|
||||
def safe_update(
|
||||
cls: Union["UpdateMixin", Document],
|
||||
company_id,
|
||||
id,
|
||||
partial_update_dict,
|
||||
injected_update=None,
|
||||
):
|
||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||
if not update_dict:
|
||||
return 0, {}
|
||||
@@ -503,7 +741,52 @@ class UpdateMixin(object):
|
||||
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
""" Provide convenience methods for a subclass of mongoengine.Document """
|
||||
|
||||
pass
|
||||
@classmethod
|
||||
def aggregate(
|
||||
cls: Union["DbModelMixin", Document],
|
||||
pipeline: Sequence[dict],
|
||||
allow_disk_use=None,
|
||||
**kwargs,
|
||||
) -> CommandCursor:
|
||||
"""
|
||||
Aggregate objects of this document class according to the provided pipeline.
|
||||
:param pipeline: a list of dictionaries describing the pipeline stages
|
||||
:param allow_disk_use: if True, allow the server to use disk space if aggregation query cannot fit in memory.
|
||||
If None, default behavior will be used (see apiserver.conf/mongo/aggregate/allow_disk_use)
|
||||
:param kwargs: additional keyword arguments passed to mongoengine
|
||||
:return:
|
||||
"""
|
||||
kwargs.update(
|
||||
allowDiskUse=allow_disk_use
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
@@ -525,5 +808,5 @@ def validate_id(cls, company, **kwargs):
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Invalid {} ids".format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
|
||||
)
|
||||
|
||||
@@ -1,23 +1,36 @@
|
||||
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
EmbeddedDocumentField,
|
||||
StringField,
|
||||
Q,
|
||||
BooleanField,
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class ReportStatsOption(EmbeddedDocument):
|
||||
enabled = BooleanField(default=False) # opt-in for statistics reporting
|
||||
enabled_version = StringField() # server version when enabled
|
||||
enabled_time = DateTimeField() # time when enabled
|
||||
enabled_user = StringField() # ID of user who enabled
|
||||
|
||||
|
||||
class CompanyDefaults(EmbeddedDocument):
|
||||
cluster = StringField()
|
||||
stats_option = EmbeddedDocumentField(ReportStatsOption, default=ReportStatsOption)
|
||||
|
||||
|
||||
class Company(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
}
|
||||
meta = {"db_alias": Database.backend, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(unique=True, min_length=3)
|
||||
defaults = EmbeddedDocumentField(CompanyDefaults)
|
||||
defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
@@ -12,45 +13,63 @@ from database.model.user import User
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
'indexes': [
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
'name': '%s.model.main_text_index' % Database.backend,
|
||||
'fields': [
|
||||
'$name',
|
||||
'$id',
|
||||
'$comment',
|
||||
'$parent',
|
||||
'$task',
|
||||
'$project',
|
||||
],
|
||||
'default_language': 'english',
|
||||
'weights': {
|
||||
'name': 10,
|
||||
'id': 10,
|
||||
'comment': 10,
|
||||
'parent': 5,
|
||||
'task': 3,
|
||||
'project': 3,
|
||||
}
|
||||
}
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"parent": 5,
|
||||
"task": 3,
|
||||
"project": 3,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("ready",),
|
||||
list_fields=(
|
||||
"tags",
|
||||
"system_tags",
|
||||
"framework",
|
||||
"uri",
|
||||
"id",
|
||||
"user",
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field='Model', required=False)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = SupportedURLField(default='', user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default="", user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from mongoengine import MapField, IntField
|
||||
from database.fields import NoneType, UnionField, SafeMapField
|
||||
|
||||
|
||||
class ModelLabels(MapField):
|
||||
class ModelLabels(SafeMapField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
|
||||
super(ModelLabels, self).__init__(
|
||||
field=UnionField(types=(int, NoneType)), *args, **kwargs
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
super(ModelLabels, self).validate(value)
|
||||
if value and (len(set(value.values())) < len(value)):
|
||||
non_empty_values = list(filter(None, value.values()))
|
||||
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
|
||||
self.error("Same label id appears more than once in model labels")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import StringField, DateTimeField, ListField
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import OutputDestinationField, StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
@@ -9,19 +9,21 @@ from database.model.base import GetMixin
|
||||
class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"), list_fields=("tags", "id")
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
("company", "name"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$description"],
|
||||
"default_language": "english",
|
||||
"weights": {"name": 10, "id": 10, "description": 10},
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@@ -34,6 +36,11 @@ class Project(AttributedDocument):
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True), default=list)
|
||||
default_output_destination = OutputDestinationField()
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
46
server/database/model/queue.py
Normal file
46
server/database/model/queue.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
from database.model.task.task import Task
|
||||
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
required=True, unique_with="company", min_length=3, user_set_allowed=True
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
57
server/database/model/settings.py
Normal file
57
server/database/model/settings.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
from mongoengine import Document, StringField, DynamicField, Q
|
||||
from mongoengine.errors import NotUniqueError
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class SettingKeys:
|
||||
server__uuid = "server.uuid"
|
||||
|
||||
|
||||
class Settings(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
key = StringField(primary_key=True)
|
||||
value = DynamicField()
|
||||
|
||||
@classmethod
|
||||
def get_by_key(cls, key: str, default: Optional[Any] = None, sep: str = ".") -> Any:
|
||||
key = key.strip(sep)
|
||||
res = Settings.objects(key=key).first()
|
||||
if not res:
|
||||
return default
|
||||
return res.value
|
||||
|
||||
@classmethod
|
||||
def get_by_prefix(
|
||||
cls, key_prefix: str, default: Optional[Any] = None, sep: str = "."
|
||||
) -> Sequence[Tuple[str, Any]]:
|
||||
key_prefix = key_prefix.strip(sep)
|
||||
query = Q(key=key_prefix) | Q(key__startswith=key_prefix + sep)
|
||||
res = Settings.objects(query)
|
||||
if not res:
|
||||
return default
|
||||
return [(x.key, x.value) for x in res]
|
||||
|
||||
@classmethod
|
||||
def set_or_add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
|
||||
""" Sets a new value or adds a new key/value setting (if key does not exist) """
|
||||
key = key.strip(sep)
|
||||
res = Settings.objects(key=key).update(key=key, value=value, upsert=True)
|
||||
return bool(res)
|
||||
|
||||
@classmethod
|
||||
def add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
|
||||
""" Adds a new key/value settings. Fails if key already exists. """
|
||||
key = key.strip(sep)
|
||||
try:
|
||||
res = cls(key=key, value=value).save(force_insert=True)
|
||||
return bool(res)
|
||||
except NotUniqueError:
|
||||
return False
|
||||
@@ -1,14 +1,39 @@
|
||||
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField
|
||||
from mongoengine import (
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DynamicField,
|
||||
LongField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from database.fields import SafeMapField
|
||||
|
||||
|
||||
class MetricEvent(EmbeddedDocument):
|
||||
metric = StringField(required=True, )
|
||||
variant = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
timestamp = DateTimeField(default=0, required=True)
|
||||
iter = LongField()
|
||||
value = DynamicField(required=True)
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, **kwargs):
|
||||
return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
|
||||
metric = StringField(required=True)
|
||||
variant = StringField(required=True)
|
||||
value = DynamicField(required=True)
|
||||
min_value = DynamicField() # for backwards compatibility reasons
|
||||
max_value = DynamicField() # for backwards compatibility reasons
|
||||
|
||||
|
||||
class EventStats(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
last_update = LongField()
|
||||
|
||||
|
||||
class MetricEventStats(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
metric = StringField(required=True)
|
||||
event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import EmbeddedDocument, StringField
|
||||
from database.utils import get_options
|
||||
|
||||
from database.fields import OutputDestinationField
|
||||
from database.fields import StrippedStringField
|
||||
from database.utils import get_options
|
||||
|
||||
|
||||
class Result(object):
|
||||
@@ -10,7 +10,7 @@ class Result(object):
|
||||
|
||||
|
||||
class Output(EmbeddedDocument):
|
||||
destination = OutputDestinationField()
|
||||
destination = StrippedStringField()
|
||||
model = StringField(reference_field='Model')
|
||||
error = StringField(user_set_allowed=True)
|
||||
result = StringField(choices=get_options(Result))
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
EmbeddedDocumentField,
|
||||
@@ -7,101 +5,184 @@ from mongoengine import (
|
||||
DateTimeField,
|
||||
IntField,
|
||||
ListField,
|
||||
LongField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeMapField, SafeDictField
|
||||
from database.fields import (
|
||||
StrippedStringField,
|
||||
SafeMapField,
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
from database.utils import get_options
|
||||
from .metrics import MetricEvent
|
||||
from .metrics import MetricEvent, MetricEventStats
|
||||
from .output import Output
|
||||
|
||||
DEFAULT_LAST_ITERATION = 0
|
||||
|
||||
|
||||
class TaskStatus(object):
|
||||
created = 'created'
|
||||
in_progress = 'in_progress'
|
||||
stopped = 'stopped'
|
||||
publishing = 'publishing'
|
||||
published = 'published'
|
||||
closed = 'closed'
|
||||
failed = 'failed'
|
||||
completed = 'completed'
|
||||
unknown = 'unknown'
|
||||
created = "created"
|
||||
queued = "queued"
|
||||
in_progress = "in_progress"
|
||||
stopped = "stopped"
|
||||
publishing = "publishing"
|
||||
published = "published"
|
||||
closed = "closed"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class TaskStatusMessage(object):
|
||||
stopping = 'stopping'
|
||||
stopping = "stopping"
|
||||
|
||||
|
||||
class TaskTags(object):
|
||||
development = 'development'
|
||||
class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument):
|
||||
binary = StringField(default='python')
|
||||
repository = StringField(required=True)
|
||||
class Script(EmbeddedDocument, ProperDictMixin):
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(default="")
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
version_num = StringField()
|
||||
entry_point = StringField(required=True)
|
||||
entry_point = StringField(default="")
|
||||
working_dir = StringField()
|
||||
requirements = SafeDictField()
|
||||
diff = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument):
|
||||
class ArtifactTypeData(EmbeddedDocument):
|
||||
preview = StringField()
|
||||
content_type = StringField()
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class ArtifactModes:
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
timestamp = LongField()
|
||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
section = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field='Model')
|
||||
model_desc = SafeMapField(StringField(default=''))
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
|
||||
artifacts = EmbeddedDocumentSortedListField(Artifact)
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
''' Queue ID where task was queued '''
|
||||
""" Queue ID where task was queued """
|
||||
|
||||
|
||||
class TaskType(object):
|
||||
training = 'training'
|
||||
testing = 'testing'
|
||||
training = "training"
|
||||
testing = "testing"
|
||||
inference = "inference"
|
||||
data_processing = "data_processing"
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
'indexes': [
|
||||
'created',
|
||||
'started',
|
||||
'completed',
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"created",
|
||||
"started",
|
||||
"completed",
|
||||
"parent",
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{
|
||||
'name': '%s.task.main_text_index' % Database.backend,
|
||||
'fields': [
|
||||
'$name',
|
||||
'$id',
|
||||
'$comment',
|
||||
'$execution.model',
|
||||
'$output.model',
|
||||
'$script.repository',
|
||||
'$script.entry_point',
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$execution.model",
|
||||
"$output.model",
|
||||
"$script.repository",
|
||||
"$script.entry_point",
|
||||
],
|
||||
'default_language': 'english',
|
||||
'weights': {
|
||||
'name': 10,
|
||||
'id': 10,
|
||||
'comment': 10,
|
||||
'execution.model': 2,
|
||||
'output.model': 2,
|
||||
'script.repository': 1,
|
||||
'script.entry_point': 1,
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"execution.model": 2,
|
||||
"output.model": 2,
|
||||
"script.repository": 1,
|
||||
"script.entry_point": 1,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
@@ -120,15 +201,19 @@ class Task(AttributedDocument):
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script, default=Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
|
||||
|
||||
class TaskVisibility(Enum):
|
||||
active = 'active'
|
||||
archived = 'archived'
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from mongoengine import Document, StringField
|
||||
from mongoengine import Document, StringField, DynamicField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import SafeDictField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.company import Company
|
||||
|
||||
|
||||
class User(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
@@ -18,4 +19,4 @@ class User(DbModelMixin, Document):
|
||||
family_name = StringField(user_set_allowed=True)
|
||||
given_name = StringField(user_set_allowed=True)
|
||||
avatar = StringField()
|
||||
preferences = SafeDictField(default=dict, exclude_by_default=True)
|
||||
preferences = DynamicField(default="", exclude_by_default=True)
|
||||
|
||||
18
server/database/model/version.py
Normal file
18
server/database/model/version.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from mongoengine import Document, DateTimeField, StringField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class Version(DbModelMixin, Document):
|
||||
meta = {
|
||||
"collection": "versions", # custom collection name ('version' is not a proper collection name...)
|
||||
"db_alias": Database.backend, # although we'll use this model for all databases, a default must be defined
|
||||
"strict": strict,
|
||||
"indexes": [("-created", "-num")],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
num = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
desc = StringField()
|
||||
@@ -1,13 +1,17 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby, chain
|
||||
from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
|
||||
import dpath
|
||||
import dpath.path
|
||||
|
||||
from apierrors import errors
|
||||
from database.props import PropsMixin
|
||||
|
||||
SEP = "."
|
||||
|
||||
def project_dict(data, projection, separator='.'):
|
||||
|
||||
def project_dict(data, projection, separator=SEP):
|
||||
"""
|
||||
Project partial data from a dictionary into a new dictionary
|
||||
:param data: Input dictionary
|
||||
@@ -30,19 +34,27 @@ def project_dict(data, projection, separator='.'):
|
||||
if path_part not in dst:
|
||||
dst[path_part] = [{} for _ in range(len(src_part))]
|
||||
elif not isinstance(dst[path_part], (list, tuple)):
|
||||
raise TypeError('Incompatible destination type %s for %s (list expected)'
|
||||
% (type(dst), separator.join(path_parts[:depth + 1])))
|
||||
raise TypeError(
|
||||
"Incompatible destination type %s for %s (list expected)"
|
||||
% (type(dst), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
elif not len(dst[path_part]) == len(src_part):
|
||||
raise ValueError('Destination list length differs from source length for %s'
|
||||
% separator.join(path_parts[:depth + 1]))
|
||||
raise ValueError(
|
||||
"Destination list length differs from source length for %s"
|
||||
% separator.join(path_parts[: depth + 1])
|
||||
)
|
||||
|
||||
dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])]
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
return destination
|
||||
else:
|
||||
raise TypeError('Unsupported projection type %s for %s'
|
||||
% (type(src), separator.join(path_parts[:depth + 1])))
|
||||
raise TypeError(
|
||||
"Unsupported projection type %s for %s"
|
||||
% (type(src), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
|
||||
last_part = path_parts[-1]
|
||||
dst[last_part] = src[last_part]
|
||||
@@ -53,14 +65,38 @@ def project_dict(data, projection, separator='.'):
|
||||
|
||||
for projection_path in sorted(projection):
|
||||
copy_path(
|
||||
path_parts=projection_path.split(separator),
|
||||
source=data,
|
||||
destination=result)
|
||||
path_parts=projection_path.split(separator), source=data, destination=result
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class _ReferenceProxy(dict):
|
||||
def __init__(self, id):
|
||||
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
|
||||
|
||||
|
||||
class _ProxyManager:
|
||||
lock = threading.Lock()
|
||||
|
||||
def __init__(self):
|
||||
self._proxies: Dict[str, _ReferenceProxy] = {}
|
||||
|
||||
def add(self, id):
|
||||
with self.lock:
|
||||
proxy = self._proxies.get(id)
|
||||
if proxy is None:
|
||||
proxy = self._proxies[id] = _ReferenceProxy(id)
|
||||
return proxy
|
||||
|
||||
def update(self, result):
|
||||
proxy = self._proxies.get(result.get("id"))
|
||||
if proxy is not None:
|
||||
proxy.update(result)
|
||||
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
@@ -72,6 +108,11 @@ class ProjectionHelper(object):
|
||||
self._doc_cls = doc_cls
|
||||
self._doc_projection = None
|
||||
self._ref_projection = None
|
||||
self._proxy_manager = _ProxyManager()
|
||||
|
||||
# Cached dpath paths for each of the result documents
|
||||
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
|
||||
|
||||
self._parse_projection(projection)
|
||||
|
||||
def _collect_projection_fields(self, doc_cls, projection):
|
||||
@@ -81,32 +122,46 @@ class ProjectionHelper(object):
|
||||
:param projection: List of projection fields
|
||||
:return: A tuple of document projection and reference fields information
|
||||
"""
|
||||
doc_projection = set() # Projection fields for this class (used in the main query)
|
||||
ref_projection_info = [] # Projection information for reference fields (used in join queries)
|
||||
doc_projection = (
|
||||
set()
|
||||
) # Projection fields for this class (used in the main query)
|
||||
ref_projection_info = (
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field == ref_field:
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
if not subfield.startswith('.'):
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
# We strip any trailing '*' since it means nothing for simple fields and for embedded documents
|
||||
orig_field = field
|
||||
if field.endswith('.*'):
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
doc_projection.add(field)
|
||||
return doc_projection, ref_projection_info
|
||||
|
||||
@@ -124,12 +179,14 @@ class ProjectionHelper(object):
|
||||
if not projection:
|
||||
return [], {}
|
||||
|
||||
doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)
|
||||
doc_projection, ref_projection_info = self._collect_projection_fields(
|
||||
doc_cls, projection
|
||||
)
|
||||
|
||||
def normalize_cls_projection(cls_, fields):
|
||||
""" Normalize projection for this class and group (expand *, for once) """
|
||||
if '*' in fields:
|
||||
return list(fields.difference('*').union(cls_.get_fields()))
|
||||
if "*" in fields:
|
||||
return list(fields.difference("*").union(cls_.get_fields()))
|
||||
return list(fields)
|
||||
|
||||
def compute_ref_cls_projection(cls_, group):
|
||||
@@ -143,12 +200,16 @@ class ProjectionHelper(object):
|
||||
# Aggregate by reference field. We'll leave out '*' from the projected items since
|
||||
ref_projection = {
|
||||
ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
|
||||
for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
|
||||
for (ref_field, ref_cls), g in groupby(
|
||||
sorted(ref_projection_info, key=sort_key), sort_key
|
||||
)
|
||||
}
|
||||
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
# This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
|
||||
@@ -158,13 +219,23 @@ class ProjectionHelper(object):
|
||||
doc_projection = [
|
||||
field
|
||||
for field in doc_projection
|
||||
if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
|
||||
if not any(
|
||||
field.startswith(f"{other_field}.")
|
||||
for other_field in projection_set - {field}
|
||||
)
|
||||
]
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
|
||||
invalid_fields = [
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)
|
||||
raise errors.bad_request.InvalidFields(
|
||||
fields=invalid_fields, object=doc_cls.__name__
|
||||
)
|
||||
|
||||
if ref_projection:
|
||||
# Join mode - use both normal projection fields and top-level reference fields
|
||||
@@ -175,14 +246,54 @@ class ProjectionHelper(object):
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@staticmethod
|
||||
def _search(doc_cls, obj, path, only_values=True):
|
||||
""" Call dpath.search with yielded=True, collect result values """
|
||||
def _search(
|
||||
self,
|
||||
doc_cls: PropsMixin,
|
||||
obj: dict,
|
||||
path: str,
|
||||
factory: Callable[[str], dict] = None,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Search for a path in the given object, return the list of values found for the
|
||||
given path (multiple values may exist if the path is a glob expression)
|
||||
:param doc_cls: The document class represented by the object
|
||||
:param obj: Data object
|
||||
:param path: Path to a leaf in the data object ("." separated, may contain "*")
|
||||
(in case the path contains "*", there may be multiple values)
|
||||
:param factory: If provided, replace each value found with an instance provided by the factory.
|
||||
"""
|
||||
norm_path = doc_cls.get_dpath_translated_path(path)
|
||||
return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]
|
||||
globlist = norm_path.strip(SEP).split(SEP)
|
||||
|
||||
obj_paths = self._cached_results_paths.get(id(obj))
|
||||
if obj_paths is None:
|
||||
obj_paths = self._cached_results_paths[id(obj)] = list(
|
||||
dpath.path.paths(obj, dirs=True, skip=True)
|
||||
)
|
||||
|
||||
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
|
||||
|
||||
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
|
||||
parent = None
|
||||
target = obj
|
||||
for part in p:
|
||||
parent = target
|
||||
target = target[part[0]]
|
||||
if parent and factory:
|
||||
parent[p[-1][0]] = factory(target)
|
||||
return target
|
||||
|
||||
return [search_and_replace(p) for p in paths]
|
||||
|
||||
def project(self, results, projection_func):
|
||||
"""
|
||||
@@ -197,28 +308,51 @@ class ProjectionHelper(object):
|
||||
|
||||
if ref_projection:
|
||||
# Join mode - get results for each reference fields projection required (this is the join step)
|
||||
# Note: this is a recursive step, so we support nested reference fields
|
||||
# Note: this is a recursive step, so nested reference fields are supported
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data = item
|
||||
res = {}
|
||||
ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
|
||||
for res in results))))
|
||||
if ids:
|
||||
doc_type = data['cls']
|
||||
doc_only = list(filter(None, data['only']))
|
||||
doc_only = list({'id'} | set(doc_only)) if doc_only else None
|
||||
res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
|
||||
data['res'] = res
|
||||
def collect_ids(ref_field_name):
|
||||
"""
|
||||
Collect unique IDs for the given reference path from all result documents.
|
||||
All collected IDs are replaced in the result dictionaries with a reference proxy generated by the
|
||||
proxies manager to allow rapid update later on when projection results are obtained.
|
||||
"""
|
||||
all_ids = (
|
||||
self._search(
|
||||
cls, res, ref_field_name, factory=self._proxy_manager.add
|
||||
)
|
||||
for res in results
|
||||
)
|
||||
return list(filter(None, set(chain.from_iterable(all_ids))))
|
||||
|
||||
items = list(ref_projection.items())
|
||||
if len(ref_projection) == 1:
|
||||
do_projection(items[0])
|
||||
else:
|
||||
for _ in self.pool.map(do_projection, items):
|
||||
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
|
||||
# will be raised when its value is retrieved from the map() iterator
|
||||
pass
|
||||
items = [
|
||||
tup
|
||||
for tup in (
|
||||
(*item, collect_ids(item[0])) for item in ref_projection.items()
|
||||
)
|
||||
if tup[2]
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
doc_type = data["cls"]
|
||||
doc_only = list(filter(None, data["only"]))
|
||||
doc_only = list({"id"} | set(doc_only)) if doc_only else None
|
||||
|
||||
for res in projection_func(
|
||||
doc_type=doc_type, projection=doc_only, ids=ids
|
||||
):
|
||||
self._proxy_manager.update(res)
|
||||
|
||||
if len(ref_projection) == 1:
|
||||
do_projection(items[0])
|
||||
else:
|
||||
for _ in self.pool.map(do_projection, items):
|
||||
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
|
||||
# will be raised when its value is retrieved from the map() iterator
|
||||
pass
|
||||
|
||||
def do_expand_reference_ids(result, skip_fields=None):
|
||||
ref_fields = cls.get_reference_fields()
|
||||
@@ -226,44 +360,18 @@ class ProjectionHelper(object):
|
||||
ref_fields = set(ref_fields) - set(skip_fields)
|
||||
self._expand_reference_fields(cls, result, ref_fields)
|
||||
|
||||
def merge_projection_result(result):
|
||||
for ref_field_name, data in ref_projection.items():
|
||||
res = data.get('res')
|
||||
if not res:
|
||||
self._expand_reference_fields(cls, result, [ref_field_name])
|
||||
continue
|
||||
ref_ids = self._search(cls, result, ref_field_name, only_values=False)
|
||||
if not ref_ids:
|
||||
continue
|
||||
for path, value in ref_ids:
|
||||
obj = res.get(value) or {'id': value}
|
||||
dpath.new(result, path, obj, separator='.')
|
||||
|
||||
# any reference field not projected should be expanded
|
||||
do_expand_reference_ids(result, skip_fields=list(ref_projection))
|
||||
|
||||
update_func = merge_projection_result if ref_projection else \
|
||||
do_expand_reference_ids if self._should_expand_reference_ids else None
|
||||
|
||||
if update_func:
|
||||
# any reference field not projected should be expanded
|
||||
if self._should_expand_reference_ids:
|
||||
for result in results:
|
||||
update_func(result)
|
||||
do_expand_reference_ids(
|
||||
result, skip_fields=list(ref_projection) if ref_projection else None
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def _expand_reference_fields(cls, doc_cls, result, fields):
|
||||
def _expand_reference_fields(self, doc_cls, result, fields):
|
||||
for ref_field_name in fields:
|
||||
ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
|
||||
if not ref_ids:
|
||||
continue
|
||||
for path, value in ref_ids:
|
||||
dpath.set(
|
||||
result,
|
||||
path,
|
||||
{'id': value} if value else {},
|
||||
separator='.')
|
||||
self._search(doc_cls, result, ref_field_name, factory=_ReferenceProxy)
|
||||
|
||||
@classmethod
|
||||
def expand_reference_ids(cls, doc_cls, result):
|
||||
cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
|
||||
def expand_reference_ids(self, doc_cls, result):
|
||||
self._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine.base import get_document
|
||||
from mongoengine.base import get_document, BaseField
|
||||
|
||||
from database.fields import (
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
UniqueEmbeddedDocumentListField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
)
|
||||
from database.utils import get_fields, get_fields_and_attr
|
||||
from database.utils import get_fields, get_fields_attr
|
||||
|
||||
|
||||
class PropsMixin(object):
|
||||
@@ -19,6 +21,7 @@ class PropsMixin(object):
|
||||
__cached_reference_fields = None
|
||||
__cached_exclude_fields = None
|
||||
__cached_fields_with_instance = None
|
||||
__cached_field_names_per_type = None
|
||||
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
@@ -29,6 +32,39 @@ class PropsMixin(object):
|
||||
cls.__cached_fields = get_fields(cls)
|
||||
return cls.__cached_fields
|
||||
|
||||
@classmethod
|
||||
def get_field_names_for_type(cls, of_type=BaseField):
|
||||
"""
|
||||
Return field names per type including subfields
|
||||
The fields of derived types are also returned
|
||||
"""
|
||||
assert issubclass(of_type, BaseField)
|
||||
if cls.__cached_field_names_per_type is None:
|
||||
fields = defaultdict(list)
|
||||
for name, field in get_fields(cls, return_instance=True, subfields=True):
|
||||
fields[type(field)].append(name)
|
||||
for type_ in fields:
|
||||
fields[type_].extend(
|
||||
chain.from_iterable(
|
||||
fields[other_type]
|
||||
for other_type in fields
|
||||
if other_type != type_ and issubclass(other_type, type_)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type = fields
|
||||
|
||||
if of_type not in cls.__cached_field_names_per_type:
|
||||
names = list(
|
||||
chain.from_iterable(
|
||||
field_names
|
||||
for type_, field_names in cls.__cached_field_names_per_type.items()
|
||||
if issubclass(type_, of_type)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type[of_type] = names
|
||||
|
||||
return cls.__cached_field_names_per_type[of_type]
|
||||
|
||||
@classmethod
|
||||
def get_fields_with_instance(cls, doc_cls):
|
||||
if cls.__cached_fields_with_instance is None:
|
||||
@@ -42,7 +78,7 @@ class PropsMixin(object):
|
||||
@staticmethod
|
||||
def _get_fields_with_attr(cls_, attr):
|
||||
""" Get all fields with the specified attribute (supports nested fields) """
|
||||
res = get_fields_and_attr(cls_, attr=attr)
|
||||
res = get_fields_attr(cls_, attr=attr)
|
||||
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
@@ -122,6 +158,14 @@ class PropsMixin(object):
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@classmethod
|
||||
def get_extra_projection(cls, fields: Sequence) -> tuple:
|
||||
if isinstance(fields, str):
|
||||
fields = [fields]
|
||||
return tuple(
|
||||
set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
@@ -140,3 +184,18 @@ class PropsMixin(object):
|
||||
result = separator.join(translated)
|
||||
cls.__cached_dpath_computed_fields[path] = result
|
||||
return cls.__cached_dpath_computed_fields[path]
|
||||
|
||||
def get_field_value(self, field_path: str, default=None):
|
||||
"""
|
||||
Return the document field_path value by the field_path name.
|
||||
The path may contain '.'. If on any level the path is
|
||||
not found then the default value is returned
|
||||
"""
|
||||
path_elements = field_path.split(".")
|
||||
current = self
|
||||
for name in path_elements:
|
||||
current = getattr(current, name, default)
|
||||
if current == default:
|
||||
break
|
||||
|
||||
return current
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import copy
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from mongoengine import Q
|
||||
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
|
||||
from mongoengine.queryset.visitor import (
|
||||
QueryCompilerVisitor,
|
||||
SimplificationVisitor,
|
||||
QCombination,
|
||||
QNode,
|
||||
)
|
||||
|
||||
|
||||
class RegexWrapper(object):
|
||||
@@ -17,17 +23,16 @@ class RegexWrapper(object):
|
||||
|
||||
|
||||
class RegexMixin(object):
|
||||
|
||||
def to_query(self, document):
|
||||
def to_query(self: Union["RegexMixin", QNode], document):
|
||||
query = self.accept(SimplificationVisitor())
|
||||
query = query.accept(RegexQueryCompilerVisitor(document))
|
||||
return query
|
||||
|
||||
def _combine(self, other, operation):
|
||||
def _combine(self: Union["RegexMixin", QNode], other, operation):
|
||||
"""Combine this node with another node into a QCombination
|
||||
object.
|
||||
"""
|
||||
if getattr(other, 'empty', True):
|
||||
if getattr(other, "empty", True):
|
||||
return self
|
||||
|
||||
if self.empty:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
from inspect import ismethod, getmembers
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
|
||||
@@ -8,61 +9,65 @@ from mongoengine.base import BaseField
|
||||
from .errors import translate_errors_context, ParseCallError
|
||||
|
||||
|
||||
def get_fields(cls, of_type=BaseField, return_instance=False):
|
||||
def get_fields(cls, of_type=BaseField, return_instance=False, subfields=False):
|
||||
return _get_fields(
|
||||
cls,
|
||||
of_type=of_type,
|
||||
subfields=subfields,
|
||||
selector=lambda k, v: (k, v) if return_instance else k,
|
||||
)
|
||||
|
||||
|
||||
def get_fields_attr(cls, attr):
|
||||
""" get field names from a class containing mongoengine fields """
|
||||
res = []
|
||||
for cls_ in reversed(cls.mro()):
|
||||
res.extend([k if not return_instance else (k, v)
|
||||
for k, v in vars(cls_).items()
|
||||
if isinstance(v, of_type)])
|
||||
return res
|
||||
return dict(
|
||||
_get_fields(cls, with_attr=attr, selector=lambda k, v: (k, getattr(v, attr)))
|
||||
)
|
||||
|
||||
|
||||
def get_fields_and_attr(cls, attr):
|
||||
""" get field names from a class containing mongoengine fields """
|
||||
res = {}
|
||||
for cls_ in reversed(cls.mro()):
|
||||
res.update({k: getattr(v, attr)
|
||||
for k, v in vars(cls_).items()
|
||||
if isinstance(v, BaseField) and hasattr(v, attr)})
|
||||
return res
|
||||
def get_fields_choices(cls, attr):
|
||||
def get_choices(field_name: str, field: BaseField) -> Tuple:
|
||||
if isinstance(field, ListField):
|
||||
return field_name, field.field.choices
|
||||
return field_name, field.choices
|
||||
|
||||
return dict(_get_fields(cls, with_attr=attr, subfields=True, selector=get_choices))
|
||||
|
||||
|
||||
def _get_field_choices(name, field):
|
||||
field_t = type(field)
|
||||
if issubclass(field_t, EmbeddedDocumentField):
|
||||
obj = field.document_type_obj
|
||||
n, choices = _get_field_choices(field.name, obj.field)
|
||||
return '%s__%s' % (name, n), choices
|
||||
elif issubclass(type(field), ListField):
|
||||
return name, field.field.choices
|
||||
return name, field.choices
|
||||
|
||||
|
||||
def get_fields_with_attr(cls, attr, default=False):
|
||||
def _get_fields(
|
||||
cls,
|
||||
with_attr=None,
|
||||
of_type=BaseField,
|
||||
subfields=False,
|
||||
selector: Optional[Callable[[str, BaseField], Any]] = None,
|
||||
path: Tuple[str, ...] = (),
|
||||
):
|
||||
fields = []
|
||||
for field_name, field in cls._fields.items():
|
||||
if not getattr(field, attr, default):
|
||||
continue
|
||||
field_t = type(field)
|
||||
if issubclass(field_t, EmbeddedDocumentField):
|
||||
fields.extend((('%s__%s' % (field_name, name), choices)
|
||||
for name, choices in get_fields_with_attr(field.document_type, attr, default)))
|
||||
elif issubclass(type(field), ListField):
|
||||
fields.append((field_name, field.field.choices))
|
||||
else:
|
||||
fields.append((field_name, field.choices))
|
||||
field_path = path + (field_name,)
|
||||
if isinstance(field, of_type) and (not with_attr or hasattr(field, with_attr)):
|
||||
full_name = "__".join(field_path)
|
||||
fields.append(selector(full_name, field) if selector else full_name)
|
||||
|
||||
if subfields and isinstance(field, EmbeddedDocumentField):
|
||||
fields.extend(
|
||||
_get_fields(
|
||||
field.document_type,
|
||||
with_attr=with_attr,
|
||||
of_type=of_type,
|
||||
subfields=subfields,
|
||||
selector=selector,
|
||||
path=field_path,
|
||||
)
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def get_items(cls):
|
||||
""" get key/value items from an enum-like class (members represent enumeration key/value) """
|
||||
|
||||
res = {
|
||||
k: v
|
||||
for k, v in getmembers(cls)
|
||||
if not (k.startswith("_") or ismethod(v))
|
||||
}
|
||||
res = {k: v for k, v in getmembers(cls) if not (k.startswith("_") or ismethod(v))}
|
||||
return res
|
||||
|
||||
|
||||
@@ -81,7 +86,7 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
fields = {k: None for k in fields}
|
||||
fields = {k: v for k, v in fields.items() if k in cls_fields}
|
||||
res = {}
|
||||
with translate_errors_context('parsing call data'):
|
||||
with translate_errors_context("parsing call data"):
|
||||
for field, desc in fields.items():
|
||||
value = call_data.get(field)
|
||||
if value is None:
|
||||
@@ -90,23 +95,34 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
res[field] = None
|
||||
continue
|
||||
if desc:
|
||||
if callable(desc):
|
||||
desc(value)
|
||||
else:
|
||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
|
||||
raise ParseCallError('expecting %s' % desc.__name__, field=field)
|
||||
if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
|
||||
raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
|
||||
if issubclass(desc, Document):
|
||||
if not desc.objects(id=value).only("id"):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
elif callable(desc):
|
||||
try:
|
||||
desc(value)
|
||||
except TypeError:
|
||||
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
||||
except Exception as ex:
|
||||
raise ParseCallError(str(ex), field=field)
|
||||
res[field] = value
|
||||
return res
|
||||
|
||||
|
||||
def init_cls_from_base(cls, instance):
|
||||
return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in instance.to_mongo(use_db_field=False).to_dict().items()
|
||||
if k[0] != "_"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_company_or_none_constraint(company=None):
|
||||
return Q(company__in=(company, None, '')) | Q(company__exists=False)
|
||||
return Q(company__in=(company, None, "")) | Q(company__exists=False)
|
||||
|
||||
|
||||
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
|
||||
@@ -118,23 +134,43 @@ def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
|
||||
the length of the array will be used (len==0 means empty)
|
||||
:return:
|
||||
"""
|
||||
query = (Q(**{f"{field}__exists": False}) |
|
||||
Q(**{f"{field}__in": {empty_value, None}}))
|
||||
query = Q(**{f"{field}__exists": False}) | Q(
|
||||
**{f"{field}__in": {empty_value, None}}
|
||||
)
|
||||
if is_list:
|
||||
query |= Q(**{f"{field}__size": 0})
|
||||
return query
|
||||
|
||||
|
||||
def field_exists(field: str, empty_value=None, is_list=False) -> Q:
|
||||
"""
|
||||
Creates a query object used for finding a field that exists and is not None or empty.
|
||||
:param field: Field name
|
||||
:param empty_value: The empty value to test for (None means no specific empty value will be used)
|
||||
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
|
||||
the length of the array will be used (len==0 means empty)
|
||||
:return:
|
||||
"""
|
||||
query = Q(**{f"{field}__exists": True}) & Q(
|
||||
**{f"{field}__nin": {empty_value, None}}
|
||||
)
|
||||
if is_list:
|
||||
query &= Q(**{f"{field}__not__size": 0})
|
||||
return query
|
||||
|
||||
|
||||
def get_subkey(d, key_path, default=None):
|
||||
""" Get a key from a nested dictionary. kay_path is a '.' separated string of keys used to traverse
|
||||
the nested dictionary.
|
||||
"""
|
||||
keys = key_path.split('.')
|
||||
keys = key_path.split(".")
|
||||
for i, key in enumerate(keys):
|
||||
if not isinstance(d, dict):
|
||||
raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
|
||||
raise KeyError(
|
||||
"Expecting a dict (%s)" % (".".join(keys[:i]) if i else "bad input")
|
||||
)
|
||||
d = d.get(key)
|
||||
if key is None:
|
||||
if d is None:
|
||||
return default
|
||||
return d
|
||||
|
||||
@@ -158,3 +194,42 @@ def merge_dicts(*dicts):
|
||||
def filter_fields(cls, fields):
|
||||
"""From the fields dictionary return only the fields that match cls fields"""
|
||||
return {key: fields[key] for key in fields if key in get_fields(cls)}
|
||||
|
||||
|
||||
def _names_set(*names: str) -> Set[str]:
|
||||
"""
|
||||
Given a list of names return set with names and '-names'
|
||||
"""
|
||||
return set(names) | set(f"-{name}" for name in names)
|
||||
|
||||
|
||||
system_tag_names = {
|
||||
"model": _names_set("active", "archived"),
|
||||
"project": _names_set("archived", "public", "default"),
|
||||
"task": _names_set("active", "archived", "development"),
|
||||
"queue": _names_set("default"),
|
||||
}
|
||||
|
||||
system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
|
||||
|
||||
def partition_tags(
|
||||
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Partition the given tags sequence into system and user-defined tags
|
||||
:param entity: The name of the entity that defines the list of the system tags
|
||||
:param tags: The tags to partition
|
||||
:param system_tags: Optional. If passed then these tags are considered system together
|
||||
with those defined for the entity.
|
||||
:return: a tuple where the first element is the sequence of user-defined tags and
|
||||
the second element is the sequence of system tags
|
||||
"""
|
||||
tags = set(tags)
|
||||
system_tags = set(system_tags)
|
||||
system_tags |= tags & system_tag_names[entity]
|
||||
|
||||
prefixes = system_tag_prefixes.get(entity, [])
|
||||
system_tags |= {t for t in tags for p in prefixes if t.lower().startswith(p)}
|
||||
|
||||
return list(tags - system_tags), list(system_tags)
|
||||
|
||||
@@ -4,53 +4,54 @@ Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
HERE = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def apply_mappings_to_host(host: str):
|
||||
def _send_mapping(f):
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
def _send_template(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
es_server = host
|
||||
url = f"{es_server}/_template/{f.stem}"
|
||||
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
|
||||
session.mount('http://', adapter)
|
||||
|
||||
session.delete(url)
|
||||
r = session.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(data),
|
||||
)
|
||||
return {"mapping": f.stem, "result": r.text}
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
return [
|
||||
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
|
||||
]
|
||||
if key:
|
||||
files = (p / key).glob("*.json")
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("hosts", nargs="+")
|
||||
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
|
||||
parser.add_argument(
|
||||
"--hosts",
|
||||
nargs="+",
|
||||
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
for host in parse_args().hosts:
|
||||
print(">>>>> Applying mapping to " + host)
|
||||
res = apply_mappings_to_host(host)
|
||||
print(res)
|
||||
args = parse_args()
|
||||
print(">>>>> Applying mapping to " + str(args.hosts))
|
||||
res = apply_mappings_to_cluster(args.hosts, args.key)
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
64
server/elastic/initialize.py
Normal file
64
server/elastic/initialize.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from time import sleep
|
||||
|
||||
from elasticsearch import Elasticsearch, exceptions
|
||||
|
||||
import es_factory
|
||||
from config import config
|
||||
from elastic.apply_mappings import apply_mappings_to_cluster
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration is not found in config files
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ElasticConnectionError(Exception):
|
||||
"""
|
||||
Exception when could not connect to elastic during init
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def check_elastic_empty() -> bool:
|
||||
"""
|
||||
Check for elasticsearch connection
|
||||
Use probing settings and not the default es cluster ones
|
||||
so that we can handle correctly the connection rejects due to ES not fully started yet
|
||||
:return:
|
||||
"""
|
||||
cluster_conf = es_factory.get_cluster_config("events")
|
||||
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
|
||||
timeout = config.get("apiserver.elastic.probing.timeout", 30)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
return True
|
||||
except exceptions.ConnectionError:
|
||||
if retry >= max_retries - 1:
|
||||
raise ElasticConnectionError()
|
||||
log.warn(
|
||||
f"Could not connect to es server. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
|
||||
)
|
||||
sleep(timeout)
|
||||
|
||||
|
||||
def init_es_data():
|
||||
for name in es_factory.get_all_cluster_names():
|
||||
cluster_conf = es_factory.get_cluster_config(name)
|
||||
hosts_config = cluster_conf.get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration(f"for cluster '{name}'")
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
log.info(res)
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 5
|
||||
},
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"_routing": {
|
||||
"required": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": { "type": "date" },
|
||||
"task": { "type": "keyword" },
|
||||
"type": { "type": "keyword" },
|
||||
"worker": { "type": "keyword" },
|
||||
"timestamp": { "type": "date" },
|
||||
"iter": { "type": "long" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
40
server/elastic/mappings/events/events.json
Normal file
40
server/elastic/mappings/events/events.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
server/elastic/mappings/events/events_log.json
Normal file
15
server/elastic/mappings/events/events_log.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"index_patterns": "events-log-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"msg": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"level": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
12
server/elastic/mappings/events/events_plot.json
Normal file
12
server/elastic/mappings/events/events_plot.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"index_patterns": "events-plot-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"plot_str": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"index_patterns": "events-training_debug_image-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"url": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-log-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"msg": { "type":"text", "index": false },
|
||||
"level": { "type":"keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"template": "events-plot-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"plot_str": { "type":"text", "index": false }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-training_debug_image-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"key": { "type": "keyword" },
|
||||
"url": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
server/elastic/mappings/workers/worker_stats.json
Normal file
37
server/elastic/mappings/workers/worker_stats.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"category": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"unit": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,28 @@
|
||||
from datetime import datetime
|
||||
from os import getenv
|
||||
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch, Transport
|
||||
|
||||
from config import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = "ELASTIC_SERVICE_SERVICE_HOST"
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"TRAINS_ELASTIC_SERVICE_HOST",
|
||||
"ELASTIC_SERVICE_HOST",
|
||||
"ELASTIC_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_ELASTIC_SERVICE_PORT", "ELASTIC_SERVICE_PORT")
|
||||
|
||||
OVERRIDE_HOST = getenv(OVERRIDE_HOST_ENV_KEY)
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
log.info(f"Using override elastic host {OVERRIDE_HOST}")
|
||||
|
||||
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
@@ -20,6 +30,7 @@ class MissingClusterConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration is not found in config files
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -27,6 +38,7 @@ class InvalidClusterConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration does not contain required properties
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -41,16 +53,22 @@ def connect(cluster_name):
|
||||
"""
|
||||
if cluster_name not in _instances:
|
||||
cluster_config = get_cluster_config(cluster_name)
|
||||
hosts = cluster_config.get('hosts', None)
|
||||
hosts = cluster_config.get("hosts", None)
|
||||
if not hosts:
|
||||
raise InvalidClusterConfiguration(cluster_name)
|
||||
|
||||
args = cluster_config.get('args', {})
|
||||
_instances[cluster_name] = Elasticsearch(hosts=hosts, transport_class=Transport, **args)
|
||||
args = cluster_config.get("args", {})
|
||||
_instances[cluster_name] = Elasticsearch(
|
||||
hosts=hosts, transport_class=Transport, **args
|
||||
)
|
||||
|
||||
return _instances[cluster_name]
|
||||
|
||||
|
||||
def get_all_cluster_names():
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
|
||||
def get_cluster_config(cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
@@ -58,14 +76,20 @@ def get_cluster_config(cluster_name):
|
||||
:return: config section for the cluster
|
||||
:raises MissingClusterConfiguration: in case no config section is found for the cluster
|
||||
"""
|
||||
cluster_key = '.'.join(('hosts.elastic', cluster_name))
|
||||
cluster_key = ".".join(("hosts.elastic", cluster_name))
|
||||
cluster_config = config.get(cluster_key, None)
|
||||
if not cluster_config:
|
||||
raise MissingClusterConfiguration(cluster_name)
|
||||
|
||||
def set_host_prop(key, value):
|
||||
for host in cluster_config.get("hosts", []):
|
||||
host[key] = value
|
||||
|
||||
if OVERRIDE_HOST:
|
||||
for host in cluster_config.get('hosts', []):
|
||||
host["host"] = OVERRIDE_HOST
|
||||
set_host_prop("host", OVERRIDE_HOST)
|
||||
|
||||
if OVERRIDE_PORT:
|
||||
set_host_prop("port", OVERRIDE_PORT)
|
||||
|
||||
return cluster_config
|
||||
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
import attr
|
||||
from furl import furl
|
||||
|
||||
from config import config
|
||||
from database.model.auth import Role
|
||||
from database.model.auth import User as AuthUser, Credentials
|
||||
from database.model.company import Company
|
||||
from database.model.user import User
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration is not found in config files
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
log.info(res)
|
||||
|
||||
|
||||
def _ensure_company():
|
||||
company_id = config.get("apiserver.default_company")
|
||||
company = Company.objects(id=company_id).only("id").first()
|
||||
if company:
|
||||
return company_id
|
||||
|
||||
company_name = "trains"
|
||||
log.info(f"Creating company: {company_name}")
|
||||
company = Company(id=company_id, name=company_name)
|
||||
company.save()
|
||||
return company_id
|
||||
|
||||
|
||||
def _ensure_auth_user(user_data, company_id):
|
||||
ensure_credentials = {"key", "secret"}.issubset(user_data.keys())
|
||||
if ensure_credentials:
|
||||
user = AuthUser.objects(
|
||||
credentials__match=Credentials(
|
||||
key=user_data["key"], secret=user_data["secret"]
|
||||
)
|
||||
).first()
|
||||
if user:
|
||||
return user.id
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
user = AuthUser(
|
||||
id=user_data.get("id", f"__{user_data['name']}__"),
|
||||
name=user_data["name"],
|
||||
company=company_id,
|
||||
role=user_data["role"],
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
|
||||
if ensure_credentials
|
||||
else None,
|
||||
)
|
||||
|
||||
user.save()
|
||||
|
||||
return user.id
|
||||
|
||||
|
||||
def _ensure_user(user: FixedUser, company_id: str):
|
||||
if User.objects(id=user.user_id).first():
|
||||
return
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
|
||||
_ensure_auth_user(
|
||||
user_data=data,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user.user_id,
|
||||
company=company_id,
|
||||
name=user.name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
company_id = _ensure_company()
|
||||
users = [
|
||||
{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
|
||||
{"name": "webserver", "role": Role.system, "email": "webserver@example.com"},
|
||||
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
|
||||
]
|
||||
|
||||
for user in users:
|
||||
credentials = config.get(f"secure.credentials.{user['name']}")
|
||||
user["key"] = credentials.user_key
|
||||
user["secret"] = credentials.user_secret
|
||||
_ensure_auth_user(user, company_id)
|
||||
|
||||
if FixedUser.enabled():
|
||||
log.info("Fixed users mode is enabled")
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
_ensure_user(user, company_id)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user['name']}: {ex}")
|
||||
except Exception as ex:
|
||||
pass
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user