Compare commits

87 Commits

Author SHA1 Message Date
allegroai
11d76e7d8c Update AWS AMIs for v0.15.0 2020-06-01 23:07:38 +03:00
allegroai
e76c0fbc63 Version bump to 0.15.0 2020-06-01 22:20:58 +03:00
allegroai
fdc9956da3 Update trains-agent-services docker image 2020-06-01 21:53:33 +03:00
allegroai
f4addaa653 Add new services mode agent container to the docker-compose 2020-06-01 21:02:49 +03:00
allegroai
667964cc82 Add clear_all flag to tasks.reset 2020-06-01 13:07:35 +03:00
allegroai
e1309e30b7 Fix UPLOAD_FOLDER handling when provided as env var or when fileserver is run by gunicorn 2020-06-01 13:05:45 +03:00
allegroai
9403942ef7 Add support for additional task types as well as tasks.get_types to obtain actual types used globally or per project 2020-06-01 13:05:12 +03:00
allegroai
84a75d9e70 Add server uid to server.info response in API v2.8 2020-06-01 13:01:31 +03:00
allegroai
c85ab66ae6 Add organization.get_tags to obtain the set of all used task, model, queue and project tags 2020-06-01 13:00:35 +03:00
allegroai
bf7f0f646b Sort hyper parameters numeric values as numbers and not strings 2020-06-01 12:27:56 +03:00
allegroai
dcdf2a3d58 Fix task can't be cloned if input model was deleted 2020-06-01 12:23:29 +03:00
allegroai
f8d8fc40a6 Support filtering users by activity in projects 2020-06-01 11:55:40 +03:00
allegroai
45d434a123 When clearing a task do not delete draft models used by other tasks 2020-06-01 11:51:43 +03:00
allegroai
1834abe5bc Better handling of execution parameter paths 2020-06-01 11:49:35 +03:00
allegroai
d6321588f3 Fix role checked for endpoints not requiring authorization 2020-06-01 11:43:55 +03:00
allegroai
c17b10ff1d Revoke built-in webserver system-role credentials (used by the WebApp) in case we're running in fixed-mode 2020-06-01 11:41:43 +03:00
allegroai
b125a56f86 Make sure configuration path loaded from an environment variable name is lower-case 2020-06-01 11:40:34 +03:00
allegroai
c43ce3a17b Update 0.15 mongo migration to drop indices (so new ones will be automatically created) 2020-06-01 11:36:22 +03:00
allegroai
b0b09616a8 Fix single bad event causes events.add_batch to skip remaining events 2020-06-01 11:33:39 +03:00
allegroai
ede5586ccc Extract non-responsive tasks watchdog from main tasks logic 2020-06-01 11:31:36 +03:00
allegroai
a1dcdffa53 Update pymongo and mongoengine versions 2020-06-01 11:29:50 +03:00
allegroai
35a11db58e Support task log retrieval with no scroll 2020-06-01 11:27:36 +03:00
allegroai
d9bdebefc7 Update AWS AMIs 2020-05-14 17:54:30 +03:00
allegroai
f29884f05a Version bump to v0.14.2 2020-05-14 17:53:56 +03:00
allegroai
0f72d662f8 Update GCP documentation 2020-05-04 17:31:11 +03:00
allegroai
6202219034 Update README 2020-05-03 11:08:21 +03:00
allegroai
bb3218f65d Update GCP installation instructions 2020-04-06 12:59:29 +03:00
allegroai
cbcaa7c789 Add MongoDB performance optimization 2020-04-01 19:20:53 +03:00
allegroai
427322a424 Update schema 2020-04-01 19:16:34 +03:00
allegroai
0e7d7d36a9 Update docs for GCP Custom Images 2020-03-30 15:51:58 +03:00
allegroai
06032a6d66 Update documentation 2020-03-20 10:51:43 +02:00
allegroai
b48f4eb2eb Make sure time intervals are calculated in ms 2020-03-20 10:50:56 +02:00
Allegro AI
383b2666c4 Update AWS AMIs 2020-03-16 21:57:07 +02:00
allegroai
50c373cf0d Version bump to v0.14.1 2020-03-16 18:47:35 +02:00
allegroai
394a9de5fa Update docs with AMI IDs for v0.14.1 2020-03-16 18:47:20 +02:00
allegroai
fb5c06e9c3 Version bump to v0.14.0 2020-03-05 20:03:48 +02:00
allegroai
1a9bbc9420 Update docs with AMI IDs for v0.14.0 2020-03-05 20:03:33 +02:00
allegroai
294da32401 Fix getting empty metrics from task 2020-03-05 14:57:20 +02:00
allegroai
7f00672010 Fix missing routing value when downloading tasks events 2020-03-05 14:55:40 +02:00
allegroai
99bf89a360 Add pre-populate feature to allow starting a new server installation with packaged example experiments 2020-03-05 14:54:34 +02:00
allegroai
6c8508eb7f Add support for pagination in events.debug_images 2020-03-01 18:00:07 +02:00
allegroai
69714d5b5c Use top-level module for api version number instead of a fixed value 2020-03-01 17:51:03 +02:00
allegroai
f9516ec7d3 Fix ActualEnumField initialization in case default was not provided 2020-03-01 17:47:47 +02:00
allegroai
6fdde93dee Add migration script 2020-03-01 17:46:10 +02:00
allegroai
7afc71ec91 Update requirements 2020-02-26 17:26:59 +02:00
allegroai
4595117d91 Support setting fileserver upload folder using an environment variable 2020-02-26 17:26:46 +02:00
allegroai
8630cc1021 Fix queue update time to update when task is taken from queue, not when queried 2020-02-20 18:26:56 +02:00
allegroai
135885b609 Improve unit test for entity ordering 2020-02-04 18:21:13 +02:00
allegroai
eb0865662c Fix projects aggregation on tasks with invalid status 2020-02-04 18:21:04 +02:00
allegroai
b7b94e7ae5 Add more validation when parsing task call 2020-02-04 18:19:07 +02:00
allegroai
72be8bee19 Limit metrics and variants to avoid ES error 2020-02-04 18:18:26 +02:00
allegroai
0722b20c1c Fix task scalars comparison aggregation 2020-02-04 18:16:27 +02:00
allegroai
a392a0e6ff Fix request field required constraint 2020-02-04 18:12:30 +02:00
allegroai
e22fa2f478 Limit dpath requirement 2020-02-04 18:09:55 +02:00
allegroai
8b49c1ac06 Update docs with AWS AMI IDs for v0.13.0 2020-01-07 14:40:09 +02:00
allegroai
da1182a405 Update docs with AWS AMI IDs for v0.13.0 2020-01-06 18:41:09 +02:00
allegroai
53e995ee8c Version bump to v0.13.0 2020-01-06 15:28:31 +02:00
allegroai
4732dc1a88 Remove deprecated env vars from docker compose files 2020-01-06 12:23:06 +02:00
allegroai
e325bcaf67 Hash ROI id to make sure it does not violate Elastic's 512 bytes id limitation 2020-01-05 09:20:38 +02:00
allegroai
a7c30453db Update documentation 2020-01-05 09:19:37 +02:00
allegroai
dedac3b2fe Allow using "$", "." and whitespaces in hyper-parameter keys 2020-01-02 15:28:50 +02:00
allegroai
7d10bbdf8e Update requirement 2020-01-02 15:27:04 +02:00
allegroai
72213dffa4 Update migration to convert user preferences to JSON 2020-01-02 15:26:45 +02:00
allegroai
f778837d4b Change the way user preferences are stored (JSON instead of plain dict) 2020-01-02 15:23:47 +02:00
allegroai
153ed6a7b7 Update documentation 2020-01-02 15:21:35 +02:00
allegroai
5d279c8c5a Add fixed user validation
Fix the way a fixed user id is generated
2020-01-02 15:20:55 +02:00
allegroai
ed910d5f6a Improve server threads shutdown on SIGTERM 2019-12-29 09:04:07 +02:00
allegroai
87d2b6fa15 Add some missing definitions 2019-12-29 09:03:19 +02:00
allegroai
94cfb17291 Add minor updates 2019-12-29 09:02:32 +02:00
allegroai
3f641d37b7 Optimize empty schema validator usage 2019-12-29 08:59:52 +02:00
allegroai
551be12f01 Move mongodb migrations inside the server's folder 2019-12-29 08:58:54 +02:00
allegroai
b536020058 Update documentation 2019-12-29 08:47:47 +02:00
Allegro AI
fb6fbc0a06 Update README.md 2019-12-25 14:21:16 +02:00
allegroai
5ae64fd791 Add support for tasks.clone 2019-12-24 18:01:48 +02:00
allegroai
f9776e4319 Allow two users to have the same full name 2019-12-24 17:58:59 +02:00
allegroai
75e736e7d5 Update readme files 2019-12-24 17:58:02 +02:00
allegroai
1e4756aa1d Add support for atomic add/update of task artifacts 2019-12-24 17:57:26 +02:00
allegroai
52529d3c55 Avoid updating experiment last iteration for metric events related to machine/gpu monitoring 2019-12-21 18:14:13 +02:00
allegroai
53296e8891 Use a single definitive way to obtain server version and build 2019-12-21 18:13:05 +02:00
allegroai
1c87ebc900 Use trains-specific environment variables for server configuration 2019-12-21 18:10:48 +02:00
allegroai
14d9924ea0 Update .gitignore 2019-12-21 18:09:04 +02:00
allegroai
69f9b424c7 Update readme and documentation 2019-12-19 18:27:16 +02:00
allegroai
1a6da301a8 Update internal version string 2019-12-19 18:26:19 +02:00
allegroai
2728b3ed14 Add labels to standalone models 2019-12-14 23:54:24 +02:00
allegroai
38284eef1f Add safe guards 2019-12-14 23:53:09 +02:00
allegroai
9debe1adcd Improve resource monitoring 2019-12-14 23:52:39 +02:00
allegroai
cc93c15f8a Optimize ELK 2019-12-14 23:50:26 +02:00
112 changed files with 4933 additions and 1568 deletions

4
.gitignore vendored
View File

@@ -1,11 +1,10 @@
syntax: glob
.idea
apierrors/errors
static/build.json
static/dashboard/node_modules
static/webapp/node_modules
static/webapp/.git
scripts/
generators/
*.pyc
__pycache__
.ropeproject
@@ -20,3 +19,4 @@ build
dist
code.tar.gz
server/schema/services/_cache.json
server/apierrors/errors/*

226
README.md
View File

@@ -1,4 +1,4 @@
# TRAINS Server
# Trains Server
## Auto-Magical Experiment Manager & Version Control for AI
@@ -7,27 +7,24 @@
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![PyPI status](https://img.shields.io/badge/status-beta-yellow.svg)](https://img.shields.io/badge/status-beta-yellow.svg)
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
## Introduction
The **trains-server** is the backend service infrastructure for [TRAINS](https://github.com/allegroai/trains).
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
It allows multiple users to collaborate and manage their experiments.
By default, TRAINS is set up to work with the TRAINS demo server, which is open to anyone and resets periodically.
In order to host your own server, you will need to install **trains-server** and point TRAINS to it.
By default, **Trains** is set up to work with the **Trains** demo server, which is open to anyone and resets periodically.
In order to host your own server, you will need to launch **trains-server** and point **Trains** to it.
**trains-server** contains the following components:
* The TRAINS Web-App, a single-page UI for experiment management and browsing
* The **Trains** Web-App, a single-page UI for experiment management and browsing
* RESTful API for:
* Documenting and logging experiment information, statistics and results
* Querying experiments history, logs and results
* Locally-hosted file server for storing images and models making them easily accessible using the Web-App
You can quickly setup your **trains-server** using:
- [Docker Installation](#installation)
- Pre-built Amazon [AWS image](#aws)
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#trains-server-for-kubernetes-clusters-using-helm)
or manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#trains-server-for-kubernetes-clusters)
You can quickly [deploy](#launching-trains-server) your **trains-server** using Docker, AWS EC2 AMI, or Kubernetes.
## System design
@@ -44,136 +41,43 @@ You can quickly setup your **trains-server** using:
- Web application on sub-domain: app.\*.\*
- API service on sub-domain: api.\*.\*
- File storage service on sub-domain: files.\*.\*
## Launching trains-server
## Install / Upgrade - AWS <a name="aws"></a>
### Prerequisites
Use one of our pre-installed Amazon Machine Images for easy deployment in AWS.
For details and instructions, see [TRAINS-server: AWS pre-installed images](docs/install_aws.md).
## Docker Installation - Linux, macOS, and Windows <a name="installation"></a>
Use our pre-built Docker image for easy deployment in Linux and macOS. <br>
For [Windows](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#docker_compose_win10), please see detailed docker-compose installation instructions on our [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#docker_compose_win10).<br>
Latest docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
1. Setup Docker (docker-compose installation details: [Ubuntu](docs/faq.md#ubuntu) / [macOS](docs/faq.md#mac-osx))
<details>
<summary>Make sure ports 8080/8081/8008 are available for the TRAINS-server services:</summary>
The ports 8080/8081/8008 must be available for the **trains-server** services.
For example, to see if port `8080` is in use:
For example, to see if port `8080` is in use:
```bash
$ sudo lsof -Pn -i4 | grep :8080 | grep LISTEN
```
* Linux or macOS:
sudo lsof -Pn -i4 | grep :8080 | grep LISTEN
* Windows:
netstat -an |find /i "8080"
### Launching
</details>
Increase vm.max_map_count for `ElasticSearch` docker
Launch **trains-server** in any of the following formats:
- Linux
```bash
$ echo "vm.max_map_count=262144" > /tmp/99-trains.conf
$ sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
$ sudo sysctl -w vm.max_map_count=262144
$ sudo service docker restart
```
- macOS
```bash
$ screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
$ sysctl -w vm.max_map_count=262144
```
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
- Pre-built [GCP Custom Image](https://github.com/allegroai/trains-server/blob/master/docs/install_gcp.md)
- Pre-built Docker Image
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
- Kubernetes
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
1. Create local directories for the databases and storage.
## Connecting Trains to your trains-server
```bash
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/data/redis
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo mkdir -p /opt/trains/config
```
Set folder permissions
- Linux
```bash
$ sudo chown -R 1000:1000 /opt/trains
```
- macOS
```bash
$ sudo chown -R $(whoami):staff /opt/trains
```
1. Download the `docker-compose.yml` file, either download [manually](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml) or execute:
```bash
$ curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
```
1. Launch the Docker containers <a name="launch-docker"></a>
```bash
$ docker-compose -f docker-compose.yml up
```
1. Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
* API server on port `8008`
* File server on port `8081`
**\* If something went wrong along the way, check our FAQ: [Docker Setup](docs/docker_setup.md#setup-docker), [Ubuntu Support](docs/faq.md#ubuntu), [macOS Support](docs/faq.md#mac-osx)**
## Optional Configuration
The **trains-server** default configuration can be easily overridden using external configuration files. By default, the server will look for these files in `/opt/trains/config`.
In order to apply the new configuration, you must restart the server (see [Restarting trains-server](#restart-server)).
### Adding Web Login Authentication
By default anyone can login to the **trains-server** Web-App.
You can configure the **trains-server** to allow only a specific set of users to access the system.
Enable this feature by placing `apiserver.conf` file under `/opt/trains/config`.
Sample `apiserver.conf` configuration file can be found [here](https://github.com/allegroai/trains-server/blob/master/docs/apiserver.conf)
To apply the changes, you must [restart the *trains-server*](#restart-server).
### Configuring the Non-Responsive Experiments Watchdog
The non-responsive experiment watchdog, monitors experiments that were not updated for a given period of time,
and marks them as `aborted`. The watchdog is always active with a default of 7200 seconds (2 hours) of inactivity threshold.
To change the watchdog's timeouts, place a `services.conf` file under `/opt/trains/config`.
Sample watchdog `services.conf` configuration file can be found [here](https://github.com/allegroai/trains-server/blob/master/docs/services.conf)
To apply the changes, you must [restart the *trains-server*](#restart-server).
### Restarting trains-server <a name="restart-server"></a>
To restart the **trains-server**, you must first stop the containers, and then restart them.
```bash
$ docker-compose down
$ docker-compose -f docker-compose.yml up
```
## Configuring **TRAINS** client
Once you have installed the **trains-server**, make sure to configure **TRAINS** [client](https://github.com/allegroai/trains)
to use your locally installed server (and not the demo server).
- Run the `trains-init` command for an interactive setup
- Or manually edit `~/trains.conf` file, making sure the `api_server` value is configured correctly, for example:
By default, the **Trains** client is set up to work with the [**Trains** demo server](https://demoapp.trains.allegro.ai/).
To have the **Trains** client use your **trains-server** instead:
- Run the `trains-init` command for an interactive setup.
- Or manually edit `~/trains.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
api {
# API server on port 8008
@@ -186,26 +90,42 @@ to use your locally installed server (and not the demo server).
files_server: "http://localhost:8081"
}
* Notice that if you setup **trains-server** in a sub-domain configuration, there is no need to specify a port number,
**Note**: If you have set up **trains-server** in a sub-domain configuration, then there is no need to specify a port number,
it will be inferred from the http/s scheme.
See [Installing and Configuring TRAINS](https://github.com/allegroai/trains#configuration) for more details.
After launching the **trains-server** and configuring the **Trains** client to use the **trains-server**,
you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in your experiments and view them in your **trains-server** web server,
for example http://localhost:8080.
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).
## What next?
## Advanced Functionality
Now that the **trains-server** is installed, and TRAINS is configured to use it,
you can [use](https://github.com/allegroai/trains#using-trains) TRAINS in your experiments and view them in the web server,
for example http://localhost:8080
**trains-server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
## Restarting trains-server
To restart the **trains-server**, you must first stop the containers, and then restart them.
```bash
docker-compose down
docker-compose -f docker-compose.yml up
```
## Upgrading <a name="upgrade"></a>
We are constantly updating, improving and adding to the **trains-server**.
New releases will include new pre-built Docker images.
When we release a new version and include a new pre-built Docker image for it, upgrade as follows:
**trains-server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker-compose.yml).
We strongly encourage you to keep your **trains-server** up to date, by keeping up with the current release.
**Note**: The following upgrade instructions use the Linux OS as an example.
To upgrade your existing **trains-server** deployment:
1. Shut down the docker containers
```bash
$ docker-compose down
docker-compose down
```
1. We highly recommend backing up your data directory before upgrading.
@@ -213,7 +133,7 @@ When we release a new version and include a new pre-built Docker image for it, u
Assuming your data directory is `/opt/trains`, to archive all data into `~/trains_backup.tgz` execute:
```bash
$ sudo tar czvf ~/trains_backup.tgz /opt/trains/data
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
<details>
@@ -221,29 +141,29 @@ When we release a new version and include a new pre-built Docker image for it, u
To restore this example backup, execute:
```bash
$ sudo rm -R /opt/trains/data
$ sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
sudo rm -R /opt/trains/data
sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```
</details>
1. Download the latest `docker-compose.yml` file, either [manually](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml) or execute:
1. Download the latest `docker-compose.yml` file.
```bash
$ curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
```
1. Spin up the docker containers, it will automatically pull the latest trains-server build
1. Spin up the docker containers, it will automatically pull the latest **trains-server** build
```bash
$ docker-compose -f docker-compose.yml pull
$ docker-compose -f docker-compose.yml up
docker-compose -f docker-compose.yml pull
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Docker Upgrade](docs/docker_setup.md#common-docker-upgrade-errors)**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
## Community & Support
If you have any questions, look to the TRAINS-server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).

View File

@@ -20,9 +20,12 @@ services:
- mongo
- elasticsearch
environment:
ELASTIC_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_HOST: mongo
REDIS_SERVICE_HOST: redis
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
networks:
- backend
elasticsearch:

View File

@@ -16,9 +16,14 @@ services:
- elasticsearch
- fileserver
environment:
ELASTIC_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_HOST: mongo
REDIS_SERVICE_HOST: redis
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:
@@ -114,4 +119,4 @@ networks:
driver: bridge
volumes:
mongodata:
mongodata:

View File

@@ -16,9 +16,14 @@ services:
- elasticsearch
- fileserver
environment:
ELASTIC_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_HOST: mongo
REDIS_SERVICE_HOST: redis
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:
@@ -110,6 +115,36 @@ services:
ports:
- "8080:80"
agent-services:
container_name: trains-agent-services
image: allegroai/trains-agent-services:latest
restart: unless-stopped
privileged: true
environment:
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
TRAINS_API_HOST: ${TRAINS_API_HOST:-}
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
TRAINS_WORKER_ID: "trains-services"
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/trains/agent:/root/.trains
depends_on:
- apiserver
networks:
backend:
driver: bridge

View File

@@ -1,5 +1,5 @@
auth {
# Fixed users login credetials
# Fixed users login credentials
# No other user will be able to login
fixed_users {
enabled: true

View File

@@ -1,166 +0,0 @@
# TRAINS-server: Using Docker Pre-Built Images
The pre-built Docker image for the **trains-server** is the quickest way to get started with your own **TRAINS** server.
You can also build the entire **trains-server** architecture using the code available in the [trains-server](https://github.com/allegroai/trains-server) repository.
**Note**: We tested this pre-built Docker image with Linux, only. For Windows users, we recommend installing the pre-built image on a Linux virtual machine.
## Prerequisites
* You must be logged in as a user with sudo privileges
* Use `bash` for all command-line instructions in this installation
## Setup Docker
### Step 1: Install Docker CE
You must first install Docker. For instructions about installing Docker, see [Supported platforms](https://docs.docker.com/install//#support) in the Docker documentation.
For example, to [install in Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/) / Mint (x86_64/amd64):
```bash
sudo apt-get install -y apt-transport-https ca-certificates curl software-properties-common
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
. /etc/os-release
sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $UBUNTU_CODENAME stable"
sudo apt-get update
sudo apt-get install -y docker-ce
```
### Step 2: Set the Maximum Number of Memory Map Areas
Elastic requires that the `vm.max_map_count` kernel setting, which is the maximum number of memory map areas a process can use, is set to at least 262144.
For CentOS 7, Ubuntu 16.04, Mint 18.3, Ubuntu 18.04 and Mint 19.x, we tested the following commands to set `vm.max_map_count`:
```bash
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
```
For information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
### Step 3: Restart the Docker daemon
Restart the Docker daemon.
```bash
sudo service docker restart
```
### Step 4: Choose a Data Directory
Choose a directory on your system in which all data maintained by the **trains-server** is stored.
Create this directory, and set its owner and group to `uid` 1000. The data stored in this directory includes the database, uploaded files and logs.
For example, if your data directory is `/opt/trains`, then use the following command:
```bash
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/data/redis
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/data/fileserver
sudo mkdir -p /opt/trains/config
sudo chown -R 1000:1000 /opt/trains
```
## TRAINS-server: Manually Launching Docker Containers <a name="launch"></a>
You can manually launch the Docker containers using the following commands.
If your data directory is not `/opt/trains`, then in the five `docker run` commands below, you must replace all occurrences of `/opt/trains` with your data directory path.
1. Launch the **trains-elastic** Docker container.
sudo docker run -d --restart="always" --name="trains-elastic" -e "bootstrap.memory_lock=true" --ulimit memlock=-1:-1 -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
1. Launch the **trains-mongo** Docker container.
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
1. Launch the **trains-redis** Docker container.
sudo docker run -d --restart="always" --name="trains-redis" -v /opt/trains/data/redis:/data --network="host" redis:5.0
1. Launch the **trains-fileserver** Docker container.
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
1. Launch the **trains-apiserver** Docker container.
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/config:/opt/trains/config allegroai/trains:latest apiserver
1. Launch the **trains-webserver** Docker container.
sudo docker run -d --restart="always" --name="trains-webserver" -p 8080:80 allegroai/trains:latest webserver
1. Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* API server on port `8008`
* Web server on port `8080`
* File server on port `8081`
## Manually Upgrading TRAINS-server Containers <a name="upgrade"></a>
We are constantly updating, improving and adding to the **trains-server**.
New releases will include new pre-built Docker images.
When we release a new version and include a new pre-built Docker image for it, upgrade as follows:
1. Shut down and remove each of your Docker instances using the following commands:
```bash
$ sudo docker stop <docker-name>
$ sudo docker rm -v <docker-name>
```
The Docker names are (see [Launching Docker Containers](#launch-docker)):
* `trains-elastic`
* `trains-mongo`
* `trains-redis`
* `trains-fileserver`
* `trains-apiserver`
* `trains-webserver`
2. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
For example, if your data directory is `/opt/trains`, use the following command:
```bash
$ sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
This backups all data to an archive in your home directory.
To restore this example backup, use the following command:
```bash
$ sudo rm -R /opt/trains/data
$ sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```
3. Pull the new **trains-server** docker image using the following command:
```bash
$ sudo docker pull allegroai/trains:latest
```
If you wish to pull a different version, replace `latest` with the required version number, for example:
```bash
$ sudo docker pull allegroai/trains:0.11.0
```
4. Launch the newly released Docker image (see [Launching Docker Containers](#trains-server-manually-launching-docker-containers-)).
#### Common Docker Upgrade Errors
* In case of a docker error: "... The container name "/trains-???" is already in use by ..."
Try removing deprecated images with:
```bash
$ docker rm -f $(docker ps -a -q)
```

View File

@@ -1,77 +1,122 @@
# TRAINS-server FAQ
# trains-server FAQ
* [Deploying trains-server on Kubernetes clusters](#kubernetes)
Launching **trains-server**
* [Creating a Helm Chart for trains-server Kubernetes deployment](#helm)
* How do I launch **trains-server** on:
* [Running trains-server on Mac OS X](#mac-osx)
* [Stand alone Linux Ubuntu systems?](#ubuntu)
* [macOS?](#mac-osx)
* [Windows 10?](#docker_compose_win10)
* [Running trains-server on Windows 10](#docker_compose_win10)
* [How do I restart trains-server?](#restart)
* [Installing trains-server on stand alone Linux Ubuntu systems ](#ubuntu)
Kubernetes
* [Resolving port conflicts preventing fixed users mode authentication and login](#port-conflict)
* [Can I deploy trains-server on Kubernetes clusters?](#kubernetes)
* [Configuring trains-server for sub-domains and load balancers](#sub-domains)
* [Can I create a Helm Chart for trains-server Kubernetes deployment?](#helm)
Configuration
### Deploying trains-server on Kubernetes clusters <a name="kubernetes"></a>
* [How do I configure trains-server for sub-domains and load balancers?](#sub-domains)
**trains-server** supports Kubernetes. See [trains-server-k8s](https://github.com/allegroai/trains-server-k8s)
which contains the YAML files describing the required services and detailed instructions for deploying
**trains-server** to a Kubernetes clusters.
* [Can I add web login authentication to trains-server?](#web-auth)
### Creating a Helm Chart for trains-server Kubernetes deployment <a name="helm"></a>
* [Can I modify the non-responsive experiment watchdog settings?](#watchdog)
**trains-server** supports creating a Helm chart for Kubernetes deployment. See [trains-server-helm](https://github.com/allegroai/trains-server-helm)
which you can use to create a Helm chart for **trains-server** and contains detailed instructions for deploying
**trains-server** to a Kubernetes clusters using Helm.
Troubleshooting
### Running trains-server on Mac OS X <a name="mac-osx"></a>
* [How do I fix Docker upgrade errors?](#common-docker-upgrade-errors)
To install and configure **trains-server** on Mac OS X, follow the steps below.
* [Why is web login authentication not working?](#port-conflict)
1. Install [docker for OS X](https://docs.docker.com/docker-for-mac/install/).
## Launching **trains-server**
1. Configure [Docker](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode).
### How do I launch trains-server on stand alone Linux Ubuntu systems? <a name="ubuntu"></a>
$ screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
$ sysctl -w vm.max_map_count=262144
To launch **trains-server** on a stand alone Linux Ubuntu:
1. Install [docker for Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/).
1. Install `docker-compose` using the following commands (for more detailed information, see the [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
1. Remove the previous installation of **trains-server**.
**WARNING**: This clears all existing **Trains** databases.
sudo rm -R /opt/trains/
1. Create local directories for the databases and storage.
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/data/redis
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/config
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo chown -R $(whoami):staff /opt/trains
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/config
sudo mkdir -p /opt/trains/data/fileserver
sudo chown -R 1000:1000 /opt/trains
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
git clone https://github.com/allegroai/trains-server.git
cd trains-server
1. Run `docker-compose`
/usr/local/bin/docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080)
### How do I launch trains-server on macOS? <a name="mac-osx"></a>
To launch **trains-server** on macOS:
1. Install [docker for macOS](https://docs.docker.com/docker-for-mac/install/).
1. Configure [Docker](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode).
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
sysctl -w vm.max_map_count=262144
1. Create local directories for the databases and storage.
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/data/redis
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/config
sudo mkdir -p /opt/trains/data/fileserver
sudo chown -R $(whoami):staff /opt/trains
1. Open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
$ git clone https://github.com/allegroai/trains-server.git
$ cd trains-server
git clone https://github.com/allegroai/trains-server.git
cd trains-server
1. Run `docker-compose` with the unified docker image.
1. Run `docker-compose` with the docker compose file.
$ docker-compose -f docker-compose-unified.yml up
docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Running trains-server on Windows 10 <a name="docker_compose_win10"></a>
### How do I launch trains-server on Windows 10? <a name="docker_compose_win10"></a>
You can run **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).
To run **trains-server** on Windows 10, follow the steps below.
To launch **trains-server** on Windows 10:
1. Install the Docker Desktop for Windows application by either:
* Following the [Install Docker Desktop on Windows](https://docs.docker.com/docker-for-windows/install/) instructions.
* Running the Docker installation [wizard](https://hub.docker.com/?overlay=onboarding).
* following the [Install Docker Desktop on Windows](https://docs.docker.com/docker-for-windows/install/) instructions.
* running the Docker installation [wizard](https://hub.docker.com/?overlay=onboarding).
1. Increase the memory allocation in Docker Desktop to `4GB`.
@@ -83,110 +128,46 @@ To run **trains-server** on Windows 10, follow the steps below.
1. Create local directories for data and logs. Open PowerShell and execute the following commands:
mkdir c:\opt\trains\logs
mkdir c:\opt\trains\config
cd c:
mkdir c:\opt\trains\data
mkdir c:\opt\trains\data\elastic
mkdir c:\opt\trains\data\redis
mkdir c:\opt\trains\data\fileserver
mkdir c:\opt\trains\logs
1. Save the **trains-server** docker-compose YAML file [docker-compose-win10.yml](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose-win10.yml) as `c:\opt\trains\docker-compose.yml`.
1. Download the **trains-server** docker-compose YAML file [docker-compose-win10.yml](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose-win10.yml) as `c:\opt\trains\docker-compose.yml`.
1. Run `docker-compose`. In PowerShell, execute the following commands:
cd c:\opt\trains\
docker-compose up
docker-compose -f up docker-compose-win10.yml
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Installing trains-server on stand alone Linux Ubuntu systems <a name="ubuntu"></a>
### How do I restart trains-server? <a name="restart"></a>
To install **trains-server** on a stand alone Linux Ubuntu, follow the steps belows.
Restart *trains-server* by first stopping the Docker containers and then restarting them.
1. Install [docker for Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/).
```bash
docker-compose down
docker-compose up -f docker-compose.yml
```
**Note**: If you are using a different docker-compose YAML file, specify that file.
1. Install `docker-compose` using the following commands (for more detailed information, see the [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
## Kubernetes
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
### Can I deploy trains-server on Kubernetes clusters? <a name="kubernetes"></a>
1. Remove the previous installation of **trains-server**.
**trains-server** supports Kubernetes. See [trains-server-k8s](https://github.com/allegroai/trains-server-k8s)
which contains the YAML files describing the required services and detailed instructions for deploying
**trains-server** to a Kubernetes clusters.
**WARNING**: This clears all existing **TRAINS** databases.
### Can I create a Helm Chart for trains-server Kubernetes deployment? <a name="helm"></a>
$ sudo rm -R /opt/trains/
**trains-server** supports creating a Helm chart for Kubernetes deployment. See [trains-server-helm](https://github.com/allegroai/trains-server-helm)
which you can use to create a Helm chart for **trains-server** and contains detailed instructions for deploying
**trains-server** to a Kubernetes clusters using Helm.
1. Create local directories for the databases and storage.
## Configuration
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/config
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo chown -R 1000:1000 /opt/trains
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
$ git clone https://github.com/allegroai/trains-server.git
$ cd trains-server
1. Run `docker-compose`
$ /usr/local/bin/docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Resolving port conflicts preventing fixed users mode authentication and login <a name="port-conflict"></a>
A port conflict may occur between the **trains-server** MongoDB and Elastic instances and other
instances running on your system. **trains-server** uses the following default ports which may be in conflict with other instances:
* MongoDB port `27017`
* Elastic port `9200`
You can check for port conflicts in the logs in `/opt/trains/log`.
If a port conflict occurs, first change the port in your **trains-server** `/opt/trains/server/config/default/hosts.conf` file to the new port and then
run the `docker run` command with the `port` option specifying the new port to restart the **trains-server** instance.
For example, to resolve a MongoDB port conflict change port `27017` to `27018`:
1. Modify `/opt/trains/server/config/default/hosts.conf` changing the ports in the `mongo` section:
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
retry_on_timeout: true
}
index_version: "1"
}
}
mongo {
backend {
host: "mongodb://127.0.0.1:27018/backend"
}
auth {
host: "mongodb://127.0.0.1:27018/auth"
}
}
2. Start the **trains-server** MongoDB container using `--port 27018`.
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5 mongod --port 27018
In a future version of **trains-server**, to start the API server, environment variables will be available to use instead of modifying the configuration file (instead of Step 1 above).
The environment variables will be available to set different ports for both MongoDB and Elastic instances:
* `MONGODB_SERVICE_PORT` (e.g., `MONGODB_SERVICE_PORT=27018`)
* `ELASTIC_SERVICE_POST` (e.g., `ELASTIC_SERVICE_POST=9201`)
### Configuring trains-server for sub-domains and load balancers <a name="sub-domains"></a>
### How do I configure trains-server for sub-domains and load balancers? <a name="sub-domains"></a>
You can configure **trains-server** for sub-domains and a load balancer.
@@ -222,3 +203,126 @@ For example, if your domain is `trains.mydomain.com` and your sub-domains are `a
1. Run the Docker containers with our updated `docker run` commands (see [Launching Docker Containers](#https://github.com/allegroai/trains-server#launching-docker-containers)).
### Can I add web login authentication to trains-server? <a name="web-auth"></a>
By default, anyone can login to the **trains-server** Web-App.
You can configure the **trains-server** to allow only a specific set of users to access the system.
To add web login authentication to **trains-server**:
1. If you are not using the current **trains-server** version, then [upgrade](https://github.com/allegroai/trains-server#upgrade).
1. In `/opt/trains/config/apiserver.conf`, add the `auth` section and in it specify the users, for example:
**Note**: A sample `apiserver.conf` configuration file is also available [here](https://github.com/allegroai/trains-server/blob/master/docs/apiserver.conf).
auth {
# Fixed users login credentials
# No other user will be able to login
fixed_users {
enabled: true
users: [
{
username: "jane"
password: "12345678"
name: "Jane Doe"
},
{
username: "john"
password: "12345678"
name: "John Doe"
},
]
}
}
1. Restart **trains-server** (see the [Restarting trains-server](#restart) FAQ).
### Can I modify the experiment watchdog settings? <a name="watchdog"></a>
The non-responsive experiment watchdog monitors experiments that were not updated for a specified period of time
and marks them as `aborted`. The watchdog is always active.
You can modify the following settings for the watchdog:
* the time threshold (in seconds) of experiment inactivity (default value is 7200 seconds (2 hours))
* the time interval (in seconds) between watchdog cycles
To change the watchdog's settings:
1. In `/opt/trains/config`, add the `services.conf` file and in it specify the watchdog settings, for example:
**Note**: A sample watchdog `services.conf` configuration file is also available [here](https://github.com/allegroai/trains-server/blob/master/docs/services.conf).
tasks {
non_responsive_tasks_watchdog {
# In-progress tasks that haven't been updated for at least 'value' seconds will be stopped by the watchdog
threshold_sec: 7200
# Watchdog will sleep for this number of seconds after each cycle
watch_interval_sec: 900
}
}
1. Restart **trains-server** (see the [Restarting trains-server](#restart) FAQ).
## Troubleshooting
### How do I fix Docker upgrade errors? <a name="common-docker-upgrade-errors"></a>
To resolve the Docker error "... The container name "/trains-???" is already in use by ...", try removing deprecated images:
docker rm -f $(docker ps -a -q)
### Why is web login authentication not working?
A port conflict between the **trains-server** MongoDB and / or Elastic instances, and other
instances running on your system may prevent web login authentication
from working correctly.
**trains-server** uses the following default ports which may be in conflict with other instances:
* MongoDB port `27017`
* Elastic port `9200`
You can check for port conflicts in the logs in `/opt/trains/log`.
If a port conflict occurs, change the MongoDB and / or Elastic ports in the `docker-compose.yml`,
and then run the Docker compose commands to restart the **trains-server** instance.
To change the MongoDB and / or Elastic ports for **trains-server**:
1. Edit the `docker-compose.yml` file.
1. In the `services/trainsserver/environment` section, add the following environment variable(s):
* For MongoDB:
MONGODB_SERVICE_PORT: <new-mongodb-port>
* For Elastic:
ELASTIC_SERVICE_PORT: <new-elasticsearch-port>
For example:
MONGODB_SERVICE_PORT: 27018
ELASTIC_SERVICE_PORT: 9201
1. For MongoDB, in the `services/mongo/ports` section, expose the new MongoDB port:
<new-mongodb-port>:27017
For example:
20718:27017
1. For Elastic, in the `services/elasticsearch/ports` section, expose the new Elastic port:
<new-elsticsearch-port>:9200
For example:
9201:9200
2. Restart **trains-server** (see the [Restarting trains-server](#restart) FAQ).

View File

@@ -1,32 +1,36 @@
# **TRAINS-server**: AWS pre-installed images
# Deploying **trains-server** on AWS
In order to easily deploy **trains-server** on AWS, we created the following Amazon Machine Images (AMIs).
To easily deploy **trains-server** on AWS, use one of our pre-built Amazon Machine Images (AMIs).
We provide AMIs per region for each released version of **trains-server**, see [Released versions](#released-versions) below.
Service port numbers on these AMIs are:
- Web: 8080
- API: 8008
- File Server: 8081
Once the AMI is up and running, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
The service port numbers on our **trains-server** AMIs:
Persistent storage configuration:
- MongoDB: /opt/trains/data/mongo/
- ElasticSearch: /opt/trains/data/elastic/
- File Server: /mnt/fileserver/
- Web application: `8080`
- API Server: `8008`
- File Server: `8081`
Instructions on launching a custom AMI from the EC2 console can be found [here](https://aws.amazon.com/premiumsupport/knowledge-center/launch-instance-custom-ami/)
and a detailed version [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/launching-instance.html).
The persistent storage configuration:
The minimum recommended instance type is **t3a.large**
- MongoDB: `/opt/trains/data/mongo/`
- ElasticSearch: `/opt/trains/data/elastic/`
- File Server: `/mnt/fileserver/`
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
For instructions on launching a custom AMI from the EC2 console, see the [AWS Knowledge Center](https://aws.amazon.com/premiumsupport/knowledge-center/launch-instance-custom-ami/) or detailed instructions in the [AWS Documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/launching-instance.html).
The minimum recommended amount of RAM is 8GB. For example, **t3.large** or **t3a.large** would have the minimum recommended amount of resources.
## Upgrading
In order to upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
To upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
### Upgrading AMI's to v0.12
**Including the automatically updated AMI**
### Note on upgrading AMIs to v0.12
Version 0.12 introduced an additional REDIS docker to the trains-server setup.
This upgrade includes the automatically updated AMI in Version 0.12. It also includes an additional REDIS docker to the **trains-server** setup.
AMI upgrading instructions:
To upgrade the AMI:
1. SSH to the EC2 machine running one of the `Latest Version AMI's`
2. Execute the following bash commands
@@ -44,47 +48,142 @@ AMI upgrading instructions:
## Released versions
The following sections provide a list containing AMI Image ID per region for each released **trains-server** version.
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
### Latest Version AMI <a name="autoupdate"></a>
**For easier upgrades: The following AMI automatically update to the latest release every reboot**
### Latest version AMI - v0.15.0 (auto update)<a name="autoupdate"></a>
* **eu-north-1** : ami-055909c1b9471451d
* **ap-south-1** : ami-0476123cc77226faf
* **eu-west-3** : ami-01df7d35ab63cca70
* **eu-west-2** : ami-00e8004c11fd0228e
* **eu-west-1** : ami-04293fbba6d3acad1
* **ap-northeast-2** : ami-004331f9c5eb13e94
* **ap-northeast-1** : ami-08cc80e2049b30e61
* **sa-east-1** : ami-06d814a0b6ffa3153
* **ca-central-1** : ami-069210ff757e9c1b7
* **ap-southeast-1** : ami-0d12cc70d6e9c0f39
* **ap-southeast-2** : ami-0b4615aa76c055267
* **eu-central-1** : ami-06537f431e52e4763
* **us-east-2** : ami-0c3cfbcb8e72ecfc5
* **us-west-1** : ami-0d83de031b83b6880
* **us-west-2** : ami-06968633c4f7187c4
* **us-east-1** : ami-07ff2f5f7ef99e8f6
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
### v0.12.1
* **eu-north-1** : ami-003118a8103286d84
* **ap-south-1** : ami-02dfe86baa48e096f
* **eu-west-3** : ami-0cc1f01267d2a780d
* **eu-west-2** : ami-0e4c8332e5ce09585
* **eu-west-1** : ami-03459a2f0b0a3b1ab
* **ap-northeast-2** : ami-08f6c2aed3a53f24c
* **ap-northeast-1** : ami-0b798eab95a7c5435
* **sa-east-1** : ami-0d3ee166c09f0d1b2
* **ca-central-1** : ami-00a758c56bd63acd5
* **ap-southeast-1** : ami-0be64d4988cd03fbb
* **ap-southeast-2** : ami-02087310d43a63f31
* **eu-central-1** : ami-097bbefeac0c74225
* **us-east-2** : ami-07eda256712b90f4d
* **us-west-1** : ami-02ef2b55cbd01c7df
* **us-west-2** : ami-037c6176ef4735360
* **us-east-1** : ami-08715c20c0e3f1c15
* **eu-north-1** : ami-0a05eb5b384a84609
* **ap-south-1** : ami-00f190b50e60b1eb5
* **eu-west-3** : ami-044fad585e1d1798e
* **eu-west-2** : ami-04ab930416a4af8c5
* **eu-west-1** : ami-00c022f333417e78e
* **ap-northeast-2** : ami-0c436e94f461a9a22
* **ap-northeast-1** : ami-018e761ad0009d5d4
* **sa-east-1** : ami-0b6c0e8e93b6ebbdd
* **ca-central-1** : ami-0cf12aab70c14237d
* **ap-southeast-1** : ami-0fe7840b9bde05581
* **ap-southeast-2** : ami-00f230e86e1afda91
* **eu-central-1** : ami-0635d13b79f76e04f
* **us-east-2** : ami-0b323078d0206db0e
* **us-west-1** : ami-07fdc1d461906f957
* **us-west-2** : ami-0a5cac167c3ebdedb
* **us-east-1** : ami-0d03956bea3aa5a44
### v0.15.0 (static update)
* **eu-north-1** : ami-0475a5068d615769b
* **ap-south-1** : ami-00c7e642badaa2ebf
* **eu-west-3** : ami-0655f769c28843e25
* **eu-west-2** : ami-04d82f48f09e2b846
* **eu-west-1** : ami-07a2aab2dc7b4ec5f
* **ap-northeast-2** : ami-0257ab220a8bc7a52
* **ap-northeast-1** : ami-0c4900af758b91dde
* **sa-east-1** : ami-021f758a4a21d5725
* **ca-central-1** : ami-0ce9703b3b47cfe70
* **ap-southeast-1** : ami-0b38689fdb8f71b74
* **ap-southeast-2** : ami-0c2b3a171e7ae4b00
* **eu-central-1** : ami-0fdd3420d6e6b4a1f
* **us-east-2** : ami-0288e9654da36ed1c
* **us-west-1** : ami-0f1d6ee0b73fe9ca2
* **us-west-2** : ami-025f0c5bfeacbf390
* **us-east-1** : ami-0b17b0bfa8b91f805
### v0.14.2 (static update)
* **eu-north-1** : ami-006d491e9e8869248
* **ap-south-1** : ami-0e55ec221687f98e7
* **eu-west-3** : ami-06ad9cf3c05c83e91
* **eu-west-2** : ami-0d05839268e748cff
* **eu-west-1** : ami-0d14c297789ce0d7a
* **ap-northeast-2** : ami-0d7fd775f0e76cc6f
* **ap-northeast-1** : ami-0c0a6e1daeb3f7a9c
* **sa-east-1** : ami-01e0c5e30e94ec887
* **ca-central-1** : ami-07a31896832734897
* **ap-southeast-1** : ami-0886d5b2d4b7fccd5
* **ap-southeast-2** : ami-0397d5a2db3c356fe
* **eu-central-1** : ami-0629f26eea22f5c17
* **us-east-2** : ami-0499c3d7bb45a1a6e
* **us-west-1** : ami-02fa8a961a4daf9f0
* **us-west-2** : ami-05c711cfab4342468
* **us-east-1** : ami-0b97d99a08012c726
### v0.14.1 (static update)
* **eu-north-1** : ami-036defe1885dced2e
* **ap-south-1** : ami-0b403aa1da6a5dc17
* **eu-west-3** : ami-0d30c2d330d1255c4
* **eu-west-2** : ami-06f0e8d075e50a029
* **eu-west-1** : ami-0da721d874f282b6d
* **ap-northeast-2** : ami-03bffe94675dd5f8c
* **ap-northeast-1** : ami-0f96520d646423673
* **sa-east-1** : ami-0c2f706a3b7d97282
* **ca-central-1** : ami-0da74525dcfd74e32
* **ap-southeast-1** : ami-066368a21cf6d232b
* **ap-southeast-2** : ami-0bfd09170067f7318
* **eu-central-1** : ami-06aa99b1c41492986
* **us-east-2** : ami-065c1880f59d03272
* **us-west-1** : ami-0b7f6b896f5058eba
* **us-west-2** : ami-0041e10ca68eef29a
* **us-east-1** : ami-0b7125e4305bbd7eb
### v0.14.0 (static update)
* **eu-north-1** : ami-02de71586ec496e38
* **ap-south-1** : ami-074b03849b51852e5
* **eu-west-3** : ami-022c388835e0eeb03
* **eu-west-2** : ami-0a151c236c6b27707
* **eu-west-1** : ami-06de69b06b4e73312
* **ap-northeast-2** : ami-0ee821b72d9f669b1
* **ap-northeast-1** : ami-03687ae215e64e100
* **sa-east-1** : ami-01eb83364b7f667af
* **ca-central-1** : ami-02e9b35f9c90377e6
* **ap-southeast-1** : ami-0d3ab5ab0048fea51
* **ap-southeast-2** : ami-0bd39d908fe3a9e06
* **eu-central-1** : ami-0b8638701311b35c4
* **us-east-2** : ami-02ff039693fc3a614
* **us-west-1** : ami-08634f7dfb608a9a7
* **us-west-2** : ami-034d693ef742b9333
* **us-east-1** : ami-0b828b05c323dde7f
### v0.13.0 (static update)
* **eu-north-1** : ami-0d9c74a015e7510d8
* **ap-south-1** : ami-02acd6dd0659bb5c1
* **eu-west-3** : ami-0f0cc5cb6d9afd194
* **eu-west-2** : ami-0298fdc0860206ed9
* **eu-west-1** : ami-0cdc072e528401d5e
* **ap-northeast-2** : ami-0055579cc95b0e53e
* **ap-northeast-1** : ami-0ced7becb9b83b5d0
* **sa-east-1** : ami-033345d0f16a1b5e4
* **ca-central-1** : ami-06c63b05aed47ae67
* **ap-southeast-1** : ami-09f0355f367f30602
* **ap-southeast-2** : ami-0bd2314163ce0fba0
* **eu-central-1** : ami-05fbae957df63e366
* **us-east-2** : ami-050c51b5b4074d3fc
* **us-west-1** : ami-06ad513073d4e5a19
* **us-west-2** : ami-0c96e1361d1d4ca94
* **us-east-1** : ami-07b669040d1eea213
### v0.12.1 (static update)
* **eu-north-1** : ami-003118a8103286d84
* **ap-south-1** : ami-02dfe86baa48e096f
* **eu-west-3** : ami-0cc1f01267d2a780d
* **eu-west-2** : ami-0e4c8332e5ce09585
* **eu-west-1** : ami-03459a2f0b0a3b1ab
* **ap-northeast-2** : ami-08f6c2aed3a53f24c
* **ap-northeast-1** : ami-0b798eab95a7c5435
* **sa-east-1** : ami-0d3ee166c09f0d1b2
* **ca-central-1** : ami-00a758c56bd63acd5
* **ap-southeast-1** : ami-0be64d4988cd03fbb
* **ap-southeast-2** : ami-02087310d43a63f31
* **eu-central-1** : ami-097bbefeac0c74225
* **us-east-2** : ami-07eda256712b90f4d
* **us-west-1** : ami-02ef2b55cbd01c7df
* **us-west-2** : ami-037c6176ef4735360
* **us-east-1** : ami-08715c20c0e3f1c15
### v0.12.0 (static update)
### v0.12.0
* **eu-north-1** : ami-03ff8ab48cd43e77e
* **ap-south-1** : ami-079c1a41ff836487c
* **eu-west-3** : ami-0121ef0398ae87ab0
@@ -102,7 +201,8 @@ The following sections provide a list containing AMI Image ID per region for eac
* **us-west-2** : ami-0018d5a7e58966848
* **us-east-1** : ami-08f24178fc14a84d2
### v0.11.0
### v0.11.0 (static update)
* **eu-north-1** : ami-0cbe338f058018c97
* **ap-south-1** : ami-06d72ff894f7a5e5d
* **eu-west-3** : ami-00f2a45d67df2d2f3
@@ -120,7 +220,8 @@ The following sections provide a list containing AMI Image ID per region for eac
* **us-west-2** : ami-0e384b6f78bf96ebe
* **us-east-1** : ami-0a7b46f907d5d9c4a
### v0.10.1
### v0.10.1 (static update)
* **eu-north-1** : ami-09937ec4d18350c32
* **ap-south-1** : ami-089d6ba7541ec4c7f
* **eu-west-3** : ami-0accb1a94bdd5c5c1
@@ -138,7 +239,8 @@ The following sections provide a list containing AMI Image ID per region for eac
* **us-west-2** : ami-0d1cb8ba7de246ff0
* **us-east-1** : ami-049ccba6abdb40cba
### v0.10.0
### v0.10.0 (static update)
* **eu-north-1** : ami-05ba33c763877e54e
* **ap-south-1** : ami-0529eec569161cae5
* **eu-west-3** : ami-03cb9396f63e26ff6
@@ -157,7 +259,7 @@ The following sections provide a list containing AMI Image ID per region for eac
* **us-west-2** : ami-04a522ecb2250fb44
* **us-east-1** : ami-0a66ddbd50959f91e
### v0.9.0
### v0.9.0 (static update)
* **us-east-1** : ami-0991ad536ecbacdac
* **eu-north-1** : ami-07cbcdff501b14afe
@@ -175,3 +277,4 @@ The following sections provide a list containing AMI Image ID per region for eac
* **us-east-2** : ami-03b01914b07428488
* **us-west-1** : ami-0cf4768e9d47ed076
* **us-west-2** : ami-0b145f37da31eb9fb

58
docs/install_gcp.md Normal file
View File

@@ -0,0 +1,58 @@
# Deploying Trains Server on Google Cloud Platform
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
The service port numbers on our Trains Server GCP Custom Image are:
- Web application: `8080`
- API Server: `8008`
- File Server: `8081`
The persistent storage configuration:
- MongoDB: `/opt/trains/data/mongo/`
- ElasticSearch: `/opt/trains/data/elastic/`
- File Server: `/mnt/fileserver/`
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
## Importing the Custom Image to your GCP account
In order to launch an instance using the Trains Server GCP Custom Image, you'll need to import the image to your custom images list.
**Note:** there's **no need** to upload the image file to Google Cloud Storage - we already provide links to image files stored in Google Storage
To import the image to your custom images list:
1. In the Cloud Console, go to the [Images](https://console.cloud.google.com/compute/images) page.
1. At the top of the page, click **Create image**.
1. In the **Name** field, specify a unique name for the image.
1. Optionally, specify an image family for your new image, or configure specific encryption settings for the image.
1. Click the **Source** menu and select **Cloud Storage file**.
1. Enter the Trains Server image bucket path (see [Trains Server GCP Custom Image](#released-versions)), for example:
`allegro-files/trains-server/trains-server.tar.gz`
1. Click the **Create** button to import the image. The process can take several minutes depending on the size of the boot disk image.
For more information see [Import the image to your custom images list](https://cloud.google.com/compute/docs/import/import-existing-image#import_image) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
## Launching an instance with a Custom Image
For instructions on launching an instance using a GCP Custom Image, see the [Manually importing virtual disks](https://cloud.google.com/compute/docs/import/import-existing-image#overview) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
For more information on Custom Images, see [Custom Images](https://cloud.google.com/compute/docs/images#custom_images) in the Compute Engine Documentation.
The minimum recommended requirements for Trains Server are:
- 2 vCPUs
- 7.5GB RAM
## Upgrading
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
## Released versions
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
### Latest version image (v0.14.1)
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz

97
docs/install_linux_mac.md Normal file
View File

@@ -0,0 +1,97 @@
# Launching the **trains-server** Docker in Linux or macOS
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux users:
* You must be logged in as a user with sudo privileges.
* Use `bash` for all command-line instructions in this installation.
To launch **trains-server** on Linux or macOS:
1. Install Docker.
* Linux - see [Docker for Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/).
* macOS - see [Docker for macOS](https://docs.docker.com/docker-for-mac/install/).
1. Verify the Docker CE installation. Execute the command:
sudo docker run hello-world
The expected is output is:
Hello from Docker!
This message shows that your installation appears to be working correctly.
To generate this message, Docker took the following steps:
1. The Docker client contacted the Docker daemon.
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
1. Increase `vm.max_map_count` for ElasticSearch docker.
Linux:
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
sudo service docker restart
macOS:
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
sysctl -w vm.max_map_count=262144
1. Remove any previous installation of **trains-server**.
**WARNING**: This clears all existing **Trains** databases.
sudo rm -R /opt/trains/
1. Create local directories for the databases and storage.
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/data/redis
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/config
sudo mkdir -p /opt/trains/data/fileserver
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
1. Grant access to the Dockers.
Linux:
sudo chown -R 1000:1000 /opt/trains
macOS:
sudo chown -R $(whoami):staff /opt/trains
1. Download the **trains-server** docker-compose YAML file.
cd /opt/trains
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
1. Run `docker-compose` with the downloaded configuration file.
sudo docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
* API server on port `8008`
* File server on port `8081`
## Next Step
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).

50
docs/install_win.md Normal file
View File

@@ -0,0 +1,50 @@
# Launching the **trains-server** Docker in Windows 10
For Windows, we recommend launching our pre-built Docker image on a Linux virtual machine.
However, you can launch **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).
To launch **trains-server** on Windows 10:
1. Install the Docker Desktop for Windows application by either:
* Following the [Install Docker Desktop on Windows](https://docs.docker.com/docker-for-windows/install/) instructions.
* Running the Docker installation [wizard](https://hub.docker.com/?overlay=onboarding).
1. Increase the memory allocation in Docker Desktop to `4GB`.
1. In your Windows notification area (system tray), right click the Docker icon.
1. Click *Settings*, *Advanced*, and then set the memory to at least `4096`.
1. Click *Apply*.
1. Remove any previous installation of **trains-server**.
**WARNING**: This clears all existing **Trains** databases.
rmdir c:\opt\trains /s
1. Create local directories for data and logs. Open PowerShell and execute the following commands:
cd c:
mkdir c:\opt\trains\data
mkdir c:\opt\trains\logs
1. Save the **trains-server** docker-compose YAML file.
cd c:\opt\trains
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose-win10.yml -o docker-compose-win10.yml
1. Run `docker-compose`. In PowerShell, execute the following commands:
docker-compose -f docker-compose-win10.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
* API server on port `8008`
* File server on port `8081`
## Next Step
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).

View File

@@ -10,10 +10,14 @@ from flask_cors import CORS
from config import config
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
@app.route("/", methods=["POST"])
def upload():
@@ -54,12 +58,13 @@ def main():
parser.add_argument(
"--upload-folder",
"-u",
default="/mnt/fileserver",
default=DEFAULT_UPLOAD_FOLDER,
help="Upload folder (default %(default)s)",
)
args = parser.parse_args()
app.config["UPLOAD_FOLDER"] = args.upload_folder
if app.config.get("UPLOAD_FOLDER") is None:
app.config["UPLOAD_FOLDER"] = args.upload_folder
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)

1
server/api_version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "2.8.0"

View File

@@ -47,6 +47,7 @@ _error_codes = {
128: ('invalid_task_output', 'invalid task output'),
129: ('task_publish_in_progress', 'Task publish in progress'),
130: ('task_not_found', 'task not found'),
131: ('events_not_added', 'events not added'),
# Models
200: ('model_error', 'general task error'),
@@ -89,6 +90,8 @@ _error_codes = {
1003: ('worker_registered', 'worker is already registered'),
1004: ('worker_not_registered', 'worker is not registered'),
1005: ('worker_stats_not_found', 'worker stats not found'),
1104: ('invalid_scroll_id', 'Invalid scroll id'),
},
(401, 'unauthorized'): {
@@ -105,7 +108,6 @@ _error_codes = {
(403, 'forbidden'): {
10: ('routing_error', 'forbidden (routing error)'),
11: ('missing_routing_header', 'forbidden (missing routing header)'),
12: ('blocked_internal_endpoint', 'forbidden (blocked internal endpoint)'),
20: ('role_not_allowed', 'forbidden (not allowed for role)'),
21: ('no_write_permission', 'forbidden (modification not allowed)'),
@@ -121,6 +123,7 @@ _error_codes = {
100: ('data_error', 'general data error'),
101: ('inconsistent_data', 'inconsistent data encountered in document'),
102: ('database_unavailable', 'database is temporarily unavailable'),
110: ('update_failed', 'update failed'),
# Index-related issues
201: ('missing_index', 'missing internal index'),

View File

@@ -5,14 +5,15 @@ from typing import Union, Type, Iterable
import jsonmodels.errors
import six
import validators
from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator
from apierrors import errors
from utilities.json import loads, dumps
def make_default(field_cls, default_value):
@@ -66,9 +67,7 @@ class DictField(fields.BaseField):
value_types = tuple()
return tuple(
_LazyType(type_)
if isinstance(type_, six.string_types)
else type_
_LazyType(type_) if isinstance(type_, six.string_types) else type_
for type_ in value_types
)
@@ -78,6 +77,9 @@ class DictField(fields.BaseField):
if not self.value_types:
return
if not value:
return
for item in value.values():
self.validate_single_value(item)
@@ -104,7 +106,7 @@ class IntField(fields.IntField):
def validate_lucene_query(value):
if value == '':
if value == "":
return
try:
parser.parse(value)
@@ -122,6 +124,7 @@ class LuceneQueryField(fields.StringField):
class NullableEnumValidator(EnumValidator):
"""Validator for enums that allows a None value."""
def validate(self, value):
if value is not None:
super(NullableEnumValidator, self).validate(value)
@@ -150,10 +153,6 @@ class EnumField(fields.StringField):
class ActualEnumField(fields.StringField):
@property
def types(self):
return (self.__enum,)
def __init__(
self,
enum_class: Type[Enum],
@@ -164,12 +163,13 @@ class ActualEnumField(fields.StringField):
**kwargs
):
self.__enum = enum_class
self.types = (enum_class,)
# noinspection PyTypeChecker
choices = list(enum_class)
validator_cls = EnumValidator if required else NullableEnumValidator
validators = [*(validators or []), validator_cls(*choices)]
super().__init__(
default=default and self.parse_value(default),
default=self.parse_value(default) if default else NotSet,
*args,
required=required,
validators=validators,
@@ -194,7 +194,7 @@ class EmailField(fields.StringField):
super().validate(value)
if value is None:
return
if validators.email(value) is not True:
if email_validator(value) is not True:
raise errors.bad_request.InvalidEmailAddress()
@@ -203,14 +203,14 @@ class DomainField(fields.StringField):
super().validate(value)
if value is None:
return
if validators.domain(value) is not True:
if domain_validator(value) is not True:
raise errors.bad_request.InvalidDomainName()
class StringEnum(Enum):
def __str__(self):
return self.value
class JsonSerializableMixin:
def to_json(self: ModelBase):
return dumps(self.to_struct())
# noinspection PyMethodParameters
def _generate_next_value_(name, start, count, last_values):
return name
@classmethod
def from_json(cls: Type[ModelBase], s):
return cls(**loads(s))

View File

@@ -58,3 +58,7 @@ class UpdateResponse(models.Base):
class PagedRequest(models.Base):
page = fields.IntField()
page_size = fields.IntField()
class IdResponse(models.Base):
id = fields.StringField(required=True)

View File

@@ -1,9 +1,12 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apimodels import ListField, IntField, ActualEnumField
from bll.event.event_metrics import EventType
from bll.event.scalar_key import ScalarKeyEnum
@@ -17,4 +20,52 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(items_types=str)
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
class DebugImagesRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
iters: int = IntField(default=1, validators=validators.Min(1))
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
class LogEventsRequest(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class TaskMetricsRequest(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)

View File

@@ -9,7 +9,7 @@ from apimodels.tasks import PublishResponse as TaskPublishResponse
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)
labels = DictField(value_types=string_types+(int,), required=True)
labels = DictField(value_types=string_types+(int,))
tags = ListField(items_types=string_types)
system_tags = ListField(items_types=string_types)
comment = fields.StringField()

View File

@@ -0,0 +1,10 @@
from jsonmodels import fields, models
class Filter(models.Base):
system_tags = fields.ListField([str])
class TagsRequest(models.Base):
include_system = fields.BoolField(default=False)
filter = fields.EmbeddedField(Filter)

View File

@@ -12,3 +12,4 @@ class ReportStatsOptionResponse(Base):
enabled_time = DateTimeField(nullable=True)
enabled_version = StringField(nullable=True)
enabled_user = StringField(nullable=True)
current_version = StringField()

View File

@@ -1,6 +1,6 @@
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum
from apimodels import DictField, ListField
@@ -9,6 +9,24 @@ from database.model.task.task import TaskType
from database.utils import get_options
class ArtifactTypeData(models.Base):
preview = StringField()
content_type = StringField()
data_hash = StringField()
class Artifact(models.Base):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(validators=Enum("input", "output"), default="output")
uri = StringField()
hash = StringField()
content_size = IntField()
timestamp = IntField()
type_data = EmbeddedField(ArtifactTypeData)
display_data = ListField([list])
class StartedResponse(UpdateResponse):
started = IntField()
@@ -72,3 +90,31 @@ class CreateRequest(TaskData):
class PingRequest(TaskRequest):
pass
class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
new_task_tags = ListField([str])
new_task_system_tags = ListField([str])
new_task_parent = StringField()
new_task_project = StringField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
class AddOrUpdateArtifactsRequest(TaskRequest):
artifacts = ListField([Artifact], required=True)
class AddOrUpdateArtifactsResponse(models.Base):
added = ListField([str])
updated = ListField([str])
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)

View File

@@ -1,4 +1,3 @@
import json
from enum import Enum
import six
@@ -13,7 +12,7 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apimodels import make_default, ListField, EnumField
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
@@ -61,7 +60,7 @@ class IdNameEntry(Base):
name = StringField()
class WorkerEntry(Base):
class WorkerEntry(Base, JsonSerializableMixin):
key = StringField() # not required due to migration issues
id = StringField(required=True)
user = EmbeddedField(IdNameEntry)
@@ -75,13 +74,6 @@ class WorkerEntry(Base):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
def to_json(self):
return json.dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**json.loads(s))
class CurrentTaskEntry(IdNameEntry):
running_time = IntField()

View File

@@ -0,0 +1,462 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
@property
def _max_workers(self):
return config.get("services.events.max_metrics_concurrency", 4)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, es_index, state_)
for metric_state in state_.metrics:
metric_state.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(self._max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
es_index=es_index,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
)
)
return res
def _reinit_outdated_metric_states(
self, company_id, es_index, state: DebugImageEventsScrollState
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
)
for stats in metric_stats.values()
if self.EVENT_TYPE in stats.event_stats_by_type
]
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
es_index,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
]
def _init_metric_states(
self, es_index, metrics: Sequence[Tuple[str, str]]
) -> Sequence[MetricScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(self._max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(self._init_metric_states_for_task, es_index=es_index),
tasks.items(),
)
)
)
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], es_index
) -> Sequence[MetricScrollState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
}
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
variants=[
init_variant_scroll_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
metric: MetricScrollState,
es_index: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
"""
if metric.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_term": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
if "aggregations" not in es_res:
return metric.task, metric.name, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for ev in dpath.get(v, "events/hits/hits")
]
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]

View File

@@ -1,11 +1,10 @@
import hashlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from enum import Enum
from operator import attrgetter
from typing import Sequence
from typing import Sequence, Set, Tuple
import attr
import six
from elasticsearch import helpers
from mongoengine import Q
@@ -14,67 +13,92 @@ from nested_dict import nested_dict
import database.utils as dbutils
import es_factory
from apierrors import errors
from bll.event.event_metrics import EventMetrics
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task, TaskStatus
from redis_manager import redman
from timing_context import TimingContext
from utilities.dicts import flatten_nested_items
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
# noinspection PyTypeChecker
EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s
class TaskEventsResult(object):
events = attr.ib(type=list, default=attr.Factory(list))
total_events = attr.ib(type=int, default=0)
next_scroll_id = attr.ib(type=str, default=None)
class EventBLL(object):
id_fields = ["task", "iter", "metric", "variant", "key"]
id_fields = ("task", "iter", "metric", "variant", "key")
def __init__(self, events_es=None):
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
self._metrics = EventMetrics(self.es)
self._skip_iteration_for_metric = set(
config.get("services.events.ignore_iteration.metrics", [])
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
@property
def metrics(self) -> EventMetrics:
return self._metrics
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
@staticmethod
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
"""Verify that task exists and can be updated"""
if not task_ids:
return set()
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_events = nested_dict(
task_last_scalar_events = nested_dict(
3, dict
) # task_id -> metric_hash -> variant_hash -> MetricEvent
task_last_events = nested_dict(
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked_tasks,
)
for event in events:
# remove spaces from event type
if "type" not in event:
raise errors.BadRequest("Event must have a 'type' field", event=event)
event_type = event.get("type")
if event_type is None:
errors_per_type["Event must have a 'type' field"] += 1
continue
event_type = event["type"].replace(" ", "_")
event_type = event_type.replace(" ", "_")
if event_type not in EVENT_TYPES:
raise errors.BadRequest(
"Invalid event type {}".format(event_type),
event=event,
types=EVENT_TYPES,
)
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
task_id = event.get("task")
if task_id is None:
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_id not in valid_tasks:
errors_per_type["Invalid task id"] += 1
continue
event["type"] = event_type
@@ -103,6 +127,9 @@ class EventBLL(object):
event["value"] = event["values"]
del event["values"]
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
index_name = EventMetrics.get_index_name(company_id, event_type)
es_action = {
"_op_type": "index", # overwrite if exists with same ID
@@ -117,89 +144,82 @@ class EventBLL(object):
else:
es_action["_id"] = dbutils.id()
task_id = event.get("task")
if task_id is not None:
es_action["_routing"] = task_id
task_ids.add(task_id)
if iter is not None:
task_iteration[task_id] = max(iter, task_iteration[task_id])
es_action["_routing"] = task_id
task_ids.add(task_id)
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
if event_type == EventType.metrics_scalar.value:
self._update_last_metric_event_for_task(
task_last_events=task_last_events, task_id=task_id, event=event
)
else:
es_action["_routing"] = task_id
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
actions.append(es_action)
if task_ids:
# verify task_ids
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
extra_msg = None
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id")
if len(res) < len(task_ids):
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, ids=invalid_task_ids
added = 0
if actions:
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
errors_in_bulk = []
added = 0
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_in_bulk.append(info)
if not updated:
remaining_tasks.add(task_id)
continue
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
# who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
added = min(added, len(actions))
return added, errors_in_bulk
if not added:
raise errors.bad_request.EventsNotAdded(**errors_per_type)
def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
errors_count = sum(errors_per_type.values())
return added, errors_count, errors_per_type
def _update_last_scalar_events_for_task(self, last_events, event):
"""
Update task_last_events structure for the provided task_id with the provided event details if this event is more
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
@@ -210,13 +230,34 @@ class EventBLL(object):
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_events = task_last_events[task_id]
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
def _update_last_metric_events_for_task(self, last_events, event):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/event_type combination.
last_events contains [metric_name -> event_type -> event]
"""
metric = event.get("metric")
event_type = event.get("type")
if not (metric and event_type):
return
timestamp = last_events[metric][event_type].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric][event_type] = event
def _update_task(
self,
company_id,
task_id,
now,
iter_max=None,
last_scalar_events=None,
last_events=None,
):
"""
Update task information in DB with aggregated results after handling event(s) related to this task.
@@ -229,15 +270,18 @@ class EventBLL(object):
if iter_max is not None:
fields["last_iteration_max"] = iter_max
if last_events:
fields["last_values"] = list(
if last_scalar_events:
fields["last_scalar_values"] = list(
flatten_nested_items(
last_events,
last_scalar_events,
nesting=2,
include_leaves=["value", "metric", "variant"],
)
)
if last_events:
fields["last_events"] = last_events
if not fields:
return False
@@ -245,7 +289,7 @@ class EventBLL(object):
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.id_fields if field in event)
return "-".join(id_values)
return hashlib.md5("-".join(id_values).encode()).hexdigest()
def scroll_task_events(
self,
@@ -276,7 +320,9 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
es_res = self.es.search(
index=es_index, body=es_req, scroll="1h", routing=task_id
)
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
next_scroll_id = es_res["_scroll_id"]
@@ -294,10 +340,16 @@ class EventBLL(object):
"size": 0,
"aggs": {
"metrics": {
"terms": {"field": "metric"},
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"variants": {
"terms": {"field": "variant"},
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"iters": {
"terms": {
@@ -496,8 +548,18 @@ class EventBLL(object):
"size": 0,
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": 200},
"aggs": {"variants": {"terms": {"field": "variant", "size": 200}}},
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
}
}
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
@@ -537,14 +599,14 @@ class EventBLL(object):
"metrics": {
"terms": {
"field": "metric",
"size": 1000,
"size": EventMetrics.MAX_METRICS_COUNT,
"order": {"_term": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": 1000,
"size": EventMetrics.MAX_VARIANTS_COUNT,
"order": {"_term": "asc"},
},
"aggs": {

View File

@@ -1,12 +1,13 @@
import itertools
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Callable, Iterable
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from typing import Sequence, Tuple, Callable
from mongoengine import Q
from apierrors import errors
@@ -20,10 +21,19 @@ from utilities import safe_get
log = config.logger(__file__)
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
class EventMetrics:
MAX_TASKS_COUNT = 100
MAX_TASKS_COUNT = 50
MAX_METRICS_COUNT = 200
MAX_VARIANTS_COUNT = 500
MAX_AGGS_ELEMENTS_COUNT = 50
def __init__(self, es: Elasticsearch):
self.es = es
@@ -62,6 +72,12 @@ class EventMetrics:
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
if len(task_ids) > self.MAX_TASKS_COUNT:
raise errors.BadRequest(
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
len(task_ids),
)
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
@@ -97,6 +113,31 @@ class EventMetrics:
MetricInterval = Tuple[int, Sequence[TaskMetric]]
MetricData = Tuple[str, dict]
def _split_metrics_by_max_aggs_count(
self, task_metrics: Sequence[TaskMetric]
) -> Iterable[Sequence[TaskMetric]]:
"""
Return task metrics in groups where amount of task metrics in each group
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
variants while always preserving all their tasks in the same group
"""
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
yield task_metrics
return
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
groups = []
for group in tm_grouped.values():
groups.append(group)
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
yield list(itertools.chain(*groups))
groups = []
if groups:
yield list(itertools.chain(*groups))
return
def _run_get_scalar_metrics_as_parallel(
self,
company_id: str,
@@ -126,21 +167,25 @@ class EventMetrics:
if not intervals:
return {}
with ThreadPoolExecutor(len(intervals)) as pool:
metrics = list(
itertools.chain.from_iterable(
pool.map(
partial(
get_func, task_ids=task_ids, es_index=es_index, key=key
),
intervals,
)
intervals = list(
itertools.chain.from_iterable(
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
for i, tms in intervals
)
)
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
intervals,
)
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
return ret
def _get_metric_intervals(
@@ -310,7 +355,13 @@ class EventMetrics:
"variants": {
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
"aggs": {
"tasks": {"terms": {"field": "task"}, "aggs": aggregation}
"tasks": {
"terms": {
"field": "task",
"size": self.MAX_TASKS_COUNT,
},
"aggs": aggregation,
}
},
}
},
@@ -396,3 +447,50 @@ class EventMetrics:
]
}
}
def get_tasks_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence[Tuple]:
"""
For the requested tasks return all the metrics that
reported events of the requested types
"""
es_index = EventMetrics.get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index):
return [(tid, []) for tid in task_ids]
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_concurrency) as pool:
res = pool.map(
partial(
self._get_task_metrics, es_index=es_index, event_type=event_type,
),
task_ids,
)
return list(zip(task_ids, res))
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"type": event_type.value}},
]
}
},
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
]

View File

@@ -0,0 +1,169 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext
class LogEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
task: str = StringField(required=True)
last_min_timestamp: Optional[int] = IntField()
last_max_timestamp: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state """
self.last_min_timestamp = self.last_max_timestamp = None
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = "log"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=LogEventsScrollState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> TaskEventsResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
def init_state(state_: LogEventsScrollState):
state_.task = task_id
def validate_state(state_: LogEventsScrollState):
"""
Checks that the task id stored in the state
is equal to the one passed with the current call
Refresh the state if requested
"""
if state_.task != task_id:
raise errors.bad_request.InvalidScrollId(
"Task stored in the state does not match the passed one",
scroll_id=state_.id,
)
if refresh:
state_.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res = TaskEventsResult(next_scroll_id=state.id)
res.events, res.total_events = self._get_events(
es_index=es_index,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
state=state,
)
return res
def _get_events(
self,
es_index,
batch_size: int,
navigate_earlier: bool,
state: LogEventsScrollState,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": state.task}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if navigate_earlier and state.last_min_timestamp is not None:
es_req["search_after"] = [state.last_min_timestamp]
elif not navigate_earlier and state.last_max_timestamp is not None:
es_req["search_after"] = [state.last_max_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
if navigate_earlier:
state.last_max_timestamp = events[0]["timestamp"]
state.last_min_timestamp = events[-1]["timestamp"]
else:
state.last_min_timestamp = events[0]["timestamp"]
state.last_max_timestamp = events[-1]["timestamp"]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": state.task}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = es_result["hits"]["hits"]
if not hits or len(hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
already_present_ids = set(ev["_id"] for ev in events)
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[
*events,
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
],
hits_total,
)

View File

@@ -4,7 +4,7 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from apimodels import StringEnum
from utilities.stringenum import StringEnum
from bll.util import extract_properties_to_lists
from config import config
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": interval,
"interval": f"{interval}ms",
"min_doc_count": 1,
}
}
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": interval,
"interval": f"{interval}ms",
"min_doc_count": 1,
"format": "strict_date_time",
}

View File

@@ -0,0 +1,85 @@
from typing import Sequence
from mongoengine import Q
from config import config
from database.model.base import GetMixin
from database.model.model import Model
from database.model.task.task import Task
from redis_manager import redman
from utilities import json
log = config.logger(__file__)
class OrgBLL:
_tags_field = "tags"
_system_tags_field = "system_tags"
_settings_prefix = "services.organization"
def __init__(self, redis=None):
self.redis = redis or redman.connection("apiserver")
@property
def _tags_cache_expiration_seconds(self):
return config.get(
f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
)
@staticmethod
def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
filter_str = "_".join(filter_) if filter_ else ""
return f"{field}_{company}_{filter_str}"
@staticmethod
def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
query = Q(company=company)
if filter_:
query &= GetMixin.get_list_field_query("system_tags", filter_)
tags = set()
for cls_ in (Task, Model):
tags |= set(cls_.objects(query).distinct(field))
return tags
def get_tags(
self, company, include_system: bool = False, filter_: Sequence[str] = None
) -> dict:
"""
Get tags and optionally system tags for the company
Return the dictionary of tags per tags field name
The function retrieves both cached values from Redis in one call
and re calculates any of them if missing in Redis
"""
fields = [
self._tags_field,
*([self._system_tags_field] if include_system else []),
]
redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
cached = self.redis.mget(redis_keys)
ret = {}
for field, tag_data, key in zip(fields, cached, redis_keys):
if tag_data is not None:
tags = json.loads(tag_data)
else:
tags = list(self._get_tags_from_db(company, field, filter_))
self.redis.setex(
key,
time=self._tags_cache_expiration_seconds,
value=json.dumps(tags),
)
ret[field] = tags
return ret
def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
"""
Updates system tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
"""
if reset or tags is not None:
self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
if reset or system_tags is not None:
self.redis.delete(
self._get_tags_cache_key(company, self._system_tags_field)
)

View File

@@ -0,0 +1 @@
from .project_bll import ProjectBLL

View File

@@ -0,0 +1,33 @@
from typing import Sequence, Optional
from mongoengine import Q
from config import config
from database.model.model import Model
from database.model.task.task import Task
from timing_context import TimingContext
log = config.logger(__file__)
class ProjectBLL:
@classmethod
def get_active_users(
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
) -> set:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if project_ids:
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res

View File

@@ -9,9 +9,12 @@ import es_factory
from apierrors import errors
from bll.queue.queue_metrics import QueueMetrics
from bll.workers import WorkerBLL
from config import config
from database.errors import translate_errors_context
from database.model.queue import Queue, Entry
log = config.logger(__file__)
class QueueBLL(object):
def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
@@ -189,9 +192,7 @@ class QueueBLL(object):
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(**query).modify(
pop__entries=-1, last_update=datetime.utcnow(), upsert=False
)
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
@@ -200,6 +201,11 @@ class QueueBLL(object):
if not queue.entries:
return
try:
Queue.objects(**query).update(last_update=datetime.utcnow())
except Exception:
log.exception("Error while updating Queue.last_update")
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:

View File

@@ -0,0 +1,79 @@
from contextlib import contextmanager
from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
import database
from timing_context import TimingContext
T = TypeVar("T")
def _do_nothing(_: T):
return
class RedisCacheManager(Generic[T]):
"""
Class for store/retrieve of state objects from redis
self.state_class - class of the state
self.redis - instance of redis
self.expiration_interval - expiration interval in seconds
"""
def __init__(
self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
):
self.state_class = state_class
self.redis = redis
self.expiration_interval = expiration_interval
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
@contextmanager
def get_or_create_state(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
):
"""
Try to retrieve state with the given id from the Redis cache if yes then validates it
If no then create a new one with randomly generated id
Yield the state and write it back to redis once the user code block exits
:param state_id: id of the state to retrieve
:param init_state: user callback to init the newly created state
If not passed then no init except for the id generation is done
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
try:
yield state
finally:
self.set_state(state)

View File

@@ -6,6 +6,8 @@ from time import sleep
import attr
import psutil
from utilities.threads_manager import ThreadsManager
class ResourceMonitor(Thread):
@attr.s(auto_attribs=True)
@@ -58,7 +60,9 @@ class ResourceMonitor(Thread):
)
def run(self):
while True:
while not ThreadsManager.terminating:
sleep(self.sample_interval_sec)
sample = self._get_sample()
with self._lock:
@@ -67,21 +71,20 @@ class ResourceMonitor(Thread):
self._avg = self._avg.avg(sample, self._count)
self._count += 1
sleep(self.sample_interval_sec)
def get_stats(self) -> dict:
""" Returns current resource statistics and clears internal resource statistics """
with self._lock:
min_ = attr.asdict(self._min)
max_ = attr.asdict(self._max)
avg = attr.asdict(self._avg)
res = {
"interval_sec": (datetime.utcnow() - self._clear_time).total_seconds(),
"num_cores": psutil.cpu_count(),
**{
k: {"min": v, "max": max_[k], "avg": avg[k]}
for k, v in min_.items()
}
}
interval = datetime.utcnow() - self._clear_time
self._clear()
return res
return {
"interval_sec": interval.total_seconds(),
"num_cores": psutil.cpu_count(),
**{
k: {"min": v, "max": max_[k], "avg": avg[k]}
for k, v in min_.items()
}
}

View File

@@ -53,11 +53,8 @@ class StatisticsReporter:
report_interval = timedelta(
hours=config.get("apiserver.statistics.report_interval_hours", 24)
)
while True:
sleep(report_interval.total_seconds())
sleep(report_interval.total_seconds())
while not ThreadsManager.terminating:
try:
for company in Company.objects(
defaults__stats_option__enabled=True
@@ -68,6 +65,8 @@ class StatisticsReporter:
except Exception as ex:
log.exception(f"Failed collecting stats: {str(ex)}")
sleep(report_interval.total_seconds())
@classmethod
@threads.register("sender", daemon=True)
def start_sender(cls):
@@ -86,7 +85,7 @@ class StatisticsReporter:
WarningFilter.attach()
while True:
while not ThreadsManager.terminating:
try:
report = cls.send_queue.get()
@@ -281,7 +280,7 @@ class StatisticsReporter:
]
return {
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
for group in Task.aggregate(*pipeline)
for group in Task.aggregate(pipeline)
}

View File

@@ -4,4 +4,5 @@ from .utils import (
update_project_time,
validate_status_change,
split_by,
ParameterKeyEscaper,
)

View File

@@ -0,0 +1,89 @@
from datetime import timedelta, datetime
from time import sleep
from apierrors import errors
from bll.task import ChangeStatusRequest
from config import config
from database.model.task.task import TaskStatus, Task
from utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
class NonResponsiveTasksWatchdog:
threads = ThreadsManager()
class _Settings:
"""
Retrieves watchdog settings from the config file
The properties are not cached so that the updates in
the config file are reflected
"""
_prefix = "services.tasks.non_responsive_tasks_watchdog"
@property
def enabled(self):
return config.get(f"{self._prefix}.enabled", True)
@property
def watch_interval_sec(self):
return config.get(f"{self._prefix}.watch_interval_sec", 900)
@property
def threshold_sec(self):
return config.get(f"{self._prefix}.threshold_sec", 7200)
settings = _Settings()
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start(cls):
sleep(cls.settings.watch_interval_sec)
while not ThreadsManager.terminating:
watch_interval = cls.settings.watch_interval_sec
if cls.settings.enabled:
try:
stopped = cls.cleanup_tasks(
threshold_sec=cls.settings.threshold_sec
)
log.info(f"{stopped} non-responsive tasks stopped")
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
sleep(watch_interval)
@classmethod
def cleanup_tasks(cls, threshold_sec):
relevant_status = (TaskStatus.in_progress,)
threshold = timedelta(seconds=threshold_sec)
ref_time = datetime.utcnow() - threshold
log.info(
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
)
tasks = list(
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
"id", "name", "status", "project", "last_update"
)
)
log.info(f"{len(tasks)} non-responsive tasks found")
if not tasks:
return 0
err_count = 0
for task in tasks:
log.info(
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
)
try:
ChangeStatusRequest(
task=task,
new_status=TaskStatus.stopped,
status_reason="Forced stop (non-responsive)",
status_message="Forced stop (non-responsive)",
force=True,
).execute()
except errors.bad_request.FailedChangingTaskStatus:
err_count += 1
return len(tasks) - err_count

View File

@@ -1,41 +1,64 @@
import re
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from bll.organization import OrgBLL
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import EventStats, MetricEventStats
from database.model.task.output import Output
from database.model.task.task import (
Task,
TaskStatus,
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
Artifact,
external_task_types,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from services.utils import validate_tags
from timing_context import TimingContext
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change
from utilities.dicts import deep_merge
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
log = config.logger(__file__)
org_bll = OrgBLL()
class TaskBLL(object):
threads = ThreadsManager("TaskBLL")
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
@@ -145,30 +168,96 @@ class TaskBLL(object):
return model
@classmethod
def validate(cls, task: Task):
assert isinstance(task, Task)
def clone_task(
cls,
company_id,
user_id,
task_id,
name: Optional[str] = None,
comment: Optional[str] = None,
parent: Optional[str] = None,
project: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
validate_tags(tags, system_tags)
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
if task.parent and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
artifacts = execution_dict.get("artifacts")
if artifacts:
execution_dict["artifacts"] = [
a for a in artifacts if a.get("mode") != ArtifactModes.output
]
now = datetime.utcnow()
with translate_errors_context():
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination)
if task.output
else None,
execution=execution_dict,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
return new_task
@classmethod
def validate(
cls,
task: Task,
validate_model=True,
validate_parent=True,
validate_project=True,
):
if (
validate_parent
and task.parent
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if task.project:
Project.get_for_writing(company=task.company, id=task.project)
if (
validate_project
and task.project
and not Project.get_for_writing(company=task.company, id=task.project)
):
raise errors.bad_request.InvalidProjectId(id=task.project)
cls.validate_execution_model(task)
if task.execution:
if task.execution.parameters:
cls._validate_execution_parameters(task.execution.parameters)
@staticmethod
def _validate_execution_parameters(parameters):
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
if invalid_keys:
raise errors.bad_request.ValidationError(
"execution.parameters keys contain whitespace", keys=invalid_keys
)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
@@ -208,7 +297,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = Task.aggregate(*pipeline)
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
@@ -226,7 +315,8 @@ class TaskBLL(object):
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_events: Dict[str, Dict[str, dict]] = None,
**extra_updates,
):
"""
@@ -238,7 +328,8 @@ class TaskBLL(object):
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_values: Last reported metrics summary (value, metric, variant).
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:
"""
@@ -249,17 +340,33 @@ class TaskBLL(object):
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_values is not None:
if last_scalar_values is not None:
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
for path, value in last_values:
for path, value in last_scalar_values:
extra_updates[op_path("set", *path)] = value
if path[-1] == "value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
extra_updates[op_path("max", *path[:-1], "max_value")] = value
if last_events is not None:
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
return {
event_type: EventStats(last_update=event["timestamp"])
for event_type, event in metric_data.items()
}
metric_stats = {
dbutils.hash_field_name(metric_key): MetricEventStats(
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
)
for metric_key, metric_data in last_events.items()
}
extra_updates["metric_stats"] = metric_stats
Task.objects(id=task_id, company=company_id).update(
upsert=False, last_update=last_update, **extra_updates
)
@@ -373,7 +480,7 @@ class TaskBLL(object):
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
task = cls.get_task_with_access(
task_id,
company_id=company_id,
only=(
@@ -412,56 +519,95 @@ class TaskBLL(object):
).execute()
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start_non_responsive_tasks_watchdog(cls):
log = config.logger("non_responsive_tasks_watchdog")
relevant_status = (TaskStatus.in_progress,)
threshold = timedelta(
seconds=config.get(
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
)
)
while True:
sleep(
config.get(
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec",
900,
)
)
try:
def add_or_update_artifacts(
cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
) -> Tuple[List[str], List[str]]:
key = attrgetter("key", "mode")
ref_time = datetime.utcnow() - threshold
if not artifacts:
return [], []
log.info(
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
artifacts: List[Artifact] = [
Artifact(**artifact.to_struct()) for artifact in artifacts
]
attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
for retry in range(attempts):
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
tasks = list(
Task.objects(
status__in=relevant_status, last_update__lt=ref_time
).only("id", "name", "status", "project", "last_update")
current = list(map(key, task.execution.artifacts))
updated = [a for a in artifacts if key(a) in current]
added = [a for a in artifacts if a not in updated]
filter = {"_id": task_id, "company": company_id}
update = {}
array_filters = None
if current:
filter["execution.artifacts"] = {
"$size": len(current),
"$all": [
*(
{"$elemMatch": {"key": key, "mode": mode}}
for key, mode in current
)
],
}
else:
filter["$or"] = [
{"execution.artifacts": {"$exists": False}},
{"execution.artifacts": {"$size": 0}},
]
if added:
update["$push"] = {
"execution.artifacts": {"$each": [a.to_mongo() for a in added]}
}
if updated:
update["$set"] = {
f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
for index, artifact in enumerate(updated)
}
array_filters = [
{
f"artifact{index}.key": artifact.key,
f"artifact{index}.mode": artifact.mode,
}
for index, artifact in enumerate(updated)
]
if not update:
return [], []
result: pymongo.results.UpdateResult = Task._get_collection().update_one(
filter=filter,
update=update,
array_filters=array_filters,
upsert=False,
)
if tasks:
if result.matched_count >= 1:
break
log.info(f"Stopping {len(tasks)} non-responsive tasks")
wait_msec = random() * int(
config.get("services.tasks.artifacts.update_retry_msec", 500)
)
for task in tasks:
log.info(
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
)
ChangeStatusRequest(
task=task,
new_status=TaskStatus.stopped,
status_reason="Forced stop (non-responsive)",
status_message="Forced stop (non-responsive)",
force=True,
).execute()
log.warning(
f"Failed to update artifacts for task {task_id} (updated by another party),"
f" retrying {retry+1}/{attempts} in {wait_msec}ms"
)
log.info(f"Done")
sleep(wait_msec / 1000)
else:
raise errors.server_error.UpdateFailed(
"task artifacts updated by another party"
)
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
return [a.key for a in added], [a.key for a in updated]
@staticmethod
def get_aggregated_project_execution_parameters(
@@ -502,10 +648,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = next(
Task.aggregate(*pipeline),
None,
)
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
@@ -513,7 +656,10 @@ class TaskBLL(object):
if result:
total = int(result.get("total", -1))
results = [r["_id"] for r in result.get("results", [])]
results = [
ParameterKeyEscaper.unescape(r["_id"])
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results

View File

@@ -3,6 +3,7 @@ from typing import TypeVar, Callable, Tuple, Sequence
import attr
import six
from boltons.dictutils import OneToOne
from apierrors import errors
from database.errors import translate_errors_context
@@ -171,3 +172,26 @@ def split_by(
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
class ParameterKeyEscaper:
_mapping = OneToOne({".": "%2E", "$": "%24"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
return "%".join(map(cls._unescape, value.split("%%")))

View File

@@ -33,8 +33,8 @@ log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es if es is not None else es_factory.connect("workers")
self.redis = redis if redis is not None else redman.connection("workers")
self.es_client = es or es_factory.connect("workers")
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@property
@@ -223,7 +223,7 @@ class WorkerBLL:
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(*projection)
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
task_ids = task_ids.union(
filter(

View File

@@ -47,7 +47,7 @@ class BasicConfig:
def logger(self, name):
if Path(name).is_file():
name = Path(name).stem
path = ".".join((self.prefix, Path(name).stem))
path = ".".join((self.prefix, name))
return logging.getLogger(path)
def _read_extra_env_config_values(self):
@@ -57,7 +57,7 @@ class BasicConfig:
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
@@ -77,7 +77,7 @@ class BasicConfig:
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
]
if invalid:
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
return [path for path in paths if path.is_dir()]
def _load(self, verbose=True):

View File

@@ -34,6 +34,12 @@
aggregate {
allow_disk_use: true
}
pre_populate {
enabled: false
zip_file: "/path/to/export.zip"
fail_on_error: false
}
}
auth {

View File

@@ -32,6 +32,11 @@ mongo {
}
redis {
apiserver {
host: "127.0.0.1"
port: 6379
db: 0
}
workers {
host: "127.0.0.1"
port: 6379

View File

@@ -13,17 +13,21 @@
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
revoke_in_fixed_mode: true
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}
}
}

View File

@@ -1,3 +1,13 @@
{
es_index_prefix:"events"
}
es_index_prefix: "events"
ignore_iteration {
metrics: [":monitor:machine", ":monitor:gpu"]
}
# max number of concurrent queries to ES when calculating events metrics
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
events_retrieval {
state_expiration_sec: 3600
}

View File

@@ -0,0 +1,3 @@
tags_cache {
expiration_seconds: 3600
}

View File

@@ -1,7 +1,14 @@
non_responsive_tasks_watchdog {
enabled: true
# In-progress tasks older than this value in seconds will be stopped by the watchdog
threshold_sec: 7200
# Watchdog will sleep for this number of seconds after each cycle
watch_interval_sec: 900
}
artifacts {
update_attempts: 10
update_retry_msec: 500
}

View File

@@ -1,43 +1,43 @@
from functools import lru_cache
from pathlib import Path
from os import getenv
from pathlib import Path
from version import __version__
from config import config
root = Path(__file__).parent.parent
@lru_cache()
def get_build_number():
try:
return (root / "BUILD").read_text().strip()
except FileNotFoundError:
return ""
@lru_cache()
def get_version():
try:
return (root / "VERSION").read_text().strip()
except FileNotFoundError:
return ""
@lru_cache()
def get_commit_number():
try:
return (root / "COMMIT").read_text().strip()
except FileNotFoundError:
return ""
@lru_cache()
def get_deployment_type() -> str:
value = getenv("TRAINS_SERVER_DEPLOYMENT_TYPE")
def _get(prop_name, env_suffix=None, default=""):
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
if value:
return value
try:
value = (root / "DEPLOY").read_text().strip()
return (root / prop_name).read_text().strip()
except FileNotFoundError:
pass
return default
return value or "manual"
@lru_cache()
def get_build_number():
return _get("BUILD")
@lru_cache()
def get_version():
return _get("VERSION", default=__version__)
@lru_cache()
def get_commit_number():
return _get("COMMIT")
@lru_cache()
def get_deployment_type() -> str:
return _get("DEPLOY", env_suffix="DEPLOYMENT_TYPE", default="manual")
def get_default_company():
return config.get("apiserver.default_company")

View File

@@ -14,6 +14,9 @@ from mongoengine import (
DictField,
DynamicField,
)
from mongoengine.fields import key_not_string, key_starts_with_dollar
NoneType = type(None)
class LengthRangeListField(ListField):
@@ -125,17 +128,39 @@ def contains_empty_key(d):
return True
class SafeMapField(MapField):
class DictValidationMixin:
"""
DictField validation in MongoEngine requires default alias and permissions to access DB version:
https://github.com/MongoEngine/mongoengine/issues/2239
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
"""
def _safe_validate(self: DictField, value):
if not isinstance(value, dict):
self.error("Only dictionaries may be used in a DictField")
if key_not_string(value):
msg = "Invalid dictionary key - documents must have only string keys"
self.error(msg)
if key_starts_with_dollar(value):
self.error(
'Invalid dictionary key name - keys may not startswith "$" characters'
)
super(DictField, self).validate(value)
class SafeMapField(MapField, DictValidationMixin):
def validate(self, value):
super(SafeMapField, self).validate(value)
self._safe_validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a MapField")
class SafeDictField(DictField):
class SafeDictField(DictField, DictValidationMixin):
def validate(self, value):
super(SafeDictField, self).validate(value)
self._safe_validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField")
@@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
SortedListField that does not raise an error in case items are not comparable
(in which case they will be sorted by their string representation)
"""
def to_mongo(self, *args, **kwargs):
try:
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
@@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
if self._ordering is not None:
def key(v): return str(itemgetter(self._ordering)(v))
def key(v):
return str(itemgetter(self._ordering)(v))
else:
key = str
return sorted(value, key=key, reverse=self._order_reverse)

View File

@@ -43,6 +43,7 @@ class Role(object):
class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
last_used = DateTimeField()
@@ -52,7 +53,7 @@ class User(DbModelMixin, AuthDocument):
meta = {"db_alias": Database.auth, "strict": strict}
id = StringField(primary_key=True)
name = StringField(unique_with="company")
name = StringField()
created = DateTimeField()
""" User auth entry creation time """

View File

@@ -1,9 +1,9 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union
from typing import Collection, Sequence, Union, Optional
from boltons.iterutils import first
from boltons.iterutils import first, bucketize
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
@@ -34,7 +34,12 @@ class AuthDocument(Document):
class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
def to_proper_dict(
self: Union["ProperDictMixin", Document],
strip_private=True,
only=None,
extra_dict=None,
) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
@@ -60,7 +65,7 @@ class ProperDictMixin(object):
class GetMixin(PropsMixin):
_text_score = "$text_score"
_projection_key = "projection"
_ordering_key = "order_by"
_search_text_key = "search_text"
@@ -71,6 +76,8 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_field_collation_overrides = {}
class QueryParameterOptions(object):
def __init__(
self,
@@ -91,11 +98,48 @@ class GetMixin(PropsMixin):
self.list_fields = list_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_default = "in"
_ops = {"not": "nin"}
_next = _default
def __init__(self, legacy=False):
self._legacy = legacy
def key(self, v):
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"]
elif v.startswith(self.op_prefix):
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
return None
next_ = self._next
self._next = self._default
return next_
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
get_all_query_options = QueryParameterOptions()
@classmethod
def get(
cls, company, id, *, _only=None, include_public=False, **kwargs
cls: Union["GetMixin", Document],
company,
id,
*,
_only=None,
include_public=False,
**kwargs,
) -> "GetMixin":
q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public)
@@ -162,17 +206,7 @@ class GetMixin(PropsMixin):
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
exclude = [t for t in data if t.startswith("-")]
include = list(set(data).difference(exclude))
mongoengine_field = field.replace(".", "__")
if include:
dict_query[f"{mongoengine_field}__in"] = include
if exclude:
dict_query[f"{mongoengine_field}__nin"] = [
t[1:] for t in exclude
]
query &= cls.get_list_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
@@ -216,6 +250,47 @@ class GetMixin(PropsMixin):
return query & RegexQ(**dict_query)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
is included.
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
or by a preceding "__$not" value (operator)
"""
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
)
if not allow_empty:
return q
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
if allow_public:
@@ -270,11 +345,26 @@ class GetMixin(PropsMixin):
return override_projection
if not parameters:
return []
return parameters.get("projection") or parameters.get("only_fields", [])
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod
def set_default_ordering(cls, parameters, value):
parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters.pop("only_fields", None)
parameters[cls._projection_key] = value
return value
@classmethod
def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]:
return parameters.get(cls._ordering_key)
@classmethod
def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters[cls._ordering_key] = value
return value
@classmethod
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
@classmethod
def get_many_with_join(
@@ -394,7 +484,12 @@ class GetMixin(PropsMixin):
)
@classmethod
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
parameters=None,
override_projection=None,
):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
@@ -445,6 +540,8 @@ class GetMixin(PropsMixin):
"""
Fetch all documents matching a provided query. For the first order by field
the None values are sorted in the end regardless of the sorting order.
If the first order field is a user defined parameter (either from execution.parameters,
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
@@ -485,6 +582,16 @@ class GetMixin(PropsMixin):
query_sets = [cls.objects(non_empty), cls.objects(empty)]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
@@ -578,7 +685,13 @@ class UpdateMixin(object):
return update_dict
@classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
def safe_update(
cls: Union["UpdateMixin", Document],
company_id,
id,
partial_update_dict,
injected_update=None,
):
update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict:
return 0, {}
@@ -595,7 +708,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
@classmethod
def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
cls: Union["DbModelMixin", Document],
pipeline: Sequence[dict],
allow_disk_use=None,
**kwargs,
) -> CommandCursor:
"""
Aggregate objects of this document class according to the provided pipeline.
@@ -610,7 +726,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
)
return cls.objects.aggregate(*pipeline, **kwargs)
return cls.objects.aggregate(pipeline, **kwargs)
def validate_id(cls, company, **kwargs):
@@ -632,5 +748,5 @@ def validate_id(cls, company, **kwargs):
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
)

View File

@@ -1,8 +1,9 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from mongoengine import Document, StringField, DateTimeField, BooleanField
from database import Database, strict
from database.fields import StrippedStringField, SafeDictField
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
@@ -12,46 +13,61 @@ from database.model.user import User
class Model(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"parent",
"project",
"task",
("company", "name"),
("company", "user"),
{
'name': '%s.model.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$parent',
'$task',
'$project',
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'parent': 5,
'task': 3,
'project': 3,
}
}
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
"default_language": "english",
"weights": {
"name": 10,
"id": 10,
"comment": 10,
"parent": 5,
"task": 3,
"project": 3,
},
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"user",
"project",
"task",
"parent",
),
)
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field='Model', required=False)
parent = StringField(reference_field="Model", required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default='', user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default="", user_set_allowed=True)
framework = StringField()
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)

View File

@@ -1,11 +1,14 @@
from mongoengine import MapField, IntField
from database.fields import NoneType, UnionField, SafeMapField
class ModelLabels(MapField):
class ModelLabels(SafeMapField):
def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
super(ModelLabels, self).__init__(
field=UnionField(types=(int, NoneType)), *args, **kwargs
)
def validate(self, value):
super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)):
non_empty_values = list(filter(None, value.values()))
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
self.error("Same label id appears more than once in model labels")

View File

@@ -1,7 +1,7 @@
from mongoengine import StringField, DateTimeField, ListField
from mongoengine import StringField, DateTimeField
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import AttributedDocument
from database.model.base import GetMixin
@@ -17,12 +17,13 @@ class Project(AttributedDocument):
"db_alias": Database.backend,
"strict": strict,
"indexes": [
("company", "name"),
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
"default_language": "english",
"weights": {"name": 10, "id": 10, "description": 10},
}
},
],
}
@@ -35,7 +36,7 @@ class Project(AttributedDocument):
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True))
system_tags = ListField(StringField(required=True))
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
default_output_destination = StrippedStringField()
last_update = DateTimeField()

View File

@@ -4,11 +4,10 @@ from mongoengine import (
StringField,
DateTimeField,
EmbeddedDocumentListField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.company import Company
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()

View File

@@ -7,6 +7,10 @@ from database import Database, strict
from database.model import DbModelMixin
class SettingKeys:
server__uuid = "server.uuid"
class Settings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
@@ -40,10 +44,6 @@ class Settings(DbModelMixin, Document):
""" Sets a new value or adds a new key/value setting (if key does not exist) """
key = key.strip(sep)
res = Settings.objects(key=key).update(key=key, value=value, upsert=True)
# if Settings.objects(key=key).only("key"):
#
# else:
# res = Settings(key=key, value=value).save()
return bool(res)
@classmethod
@@ -51,7 +51,7 @@ class Settings(DbModelMixin, Document):
""" Adds a new key/value settings. Fails if key already exists. """
key = key.strip(sep)
try:
res = Settings(key=key, value=value).save(force_insert=True)
res = cls(key=key, value=value).save(force_insert=True)
return bool(res)
except NotUniqueError:
return False

View File

@@ -1,10 +1,18 @@
from mongoengine import EmbeddedDocument, StringField, DynamicField
from mongoengine import (
EmbeddedDocument,
StringField,
DynamicField,
LongField,
EmbeddedDocumentField,
)
from database.fields import SafeMapField
class MetricEvent(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
'strict': False,
"strict": False,
}
metric = StringField(required=True)
@@ -12,3 +20,20 @@ class MetricEvent(EmbeddedDocument):
value = DynamicField(required=True)
min_value = DynamicField() # for backwards compatibility reasons
max_value = DynamicField() # for backwards compatibility reasons
class EventStats(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
"strict": False,
}
last_update = LongField()
class MetricEventStats(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
"strict": False,
}
metric = StringField(required=True)
event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))

View File

@@ -18,10 +18,11 @@ from database.fields import (
SafeSortedListField,
)
from database.model import AttributedDocument
from database.model.base import ProperDictMixin, GetMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .metrics import MetricEvent, MetricEventStats
from .output import Output
DEFAULT_LAST_ITERATION = 0
@@ -66,10 +67,15 @@ class ArtifactTypeData(EmbeddedDocument):
data_hash = StringField()
class ArtifactModes:
input = "input"
output = "output"
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=("input", "output"), default="output")
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
uri = StringField()
hash = StringField()
content_size = LongField()
@@ -78,7 +84,7 @@ class Artifact(EmbeddedDocument):
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class Execution(EmbeddedDocument):
class Execution(EmbeddedDocument, ProperDictMixin):
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
@@ -94,9 +100,26 @@ class Execution(EmbeddedDocument):
class TaskType(object):
training = "training"
testing = "testing"
inference = "inference"
data_processing = "data_processing"
application = "application"
monitor = "monitor"
controller = "controller"
optimizer = "optimizer"
service = "service"
qc = "qc"
custom = "custom"
external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -104,6 +127,13 @@ class Task(AttributedDocument):
"created",
"started",
"completed",
"parent",
"project",
("company", "name"),
("company", "user"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
@@ -128,6 +158,12 @@ class Task(AttributedDocument):
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
id = StringField(primary_key=True)
name = StrippedStringField(
@@ -146,13 +182,14 @@ class Task(AttributedDocument):
published = DateTimeField()
parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
output: Output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script: Script = EmbeddedDocumentField(Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))

View File

@@ -1,16 +1,17 @@
from mongoengine import Document, StringField
from mongoengine import Document, StringField, DynamicField
from database import Database, strict
from database.fields import SafeDictField
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.company import Company
class User(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
"db_alias": Database.backend,
"strict": strict,
}
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company)
@@ -18,4 +19,4 @@ class User(DbModelMixin, Document):
family_name = StringField(user_set_allowed=True)
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = SafeDictField(default=dict, exclude_by_default=True)
preferences = DynamicField(default="", exclude_by_default=True)

View File

@@ -1,8 +1,14 @@
import copy
import re
from typing import Union
from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
from mongoengine.queryset.visitor import (
QueryCompilerVisitor,
SimplificationVisitor,
QCombination,
QNode,
)
class RegexWrapper(object):
@@ -17,17 +23,16 @@ class RegexWrapper(object):
class RegexMixin(object):
def to_query(self, document):
def to_query(self: Union["RegexMixin", QNode], document):
query = self.accept(SimplificationVisitor())
query = query.accept(RegexQueryCompilerVisitor(document))
return query
def _combine(self, other, operation):
def _combine(self: Union["RegexMixin", QNode], other, operation):
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
if getattr(other, "empty", True):
return self
if self.empty:

View File

@@ -95,21 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
res[field] = None
continue
if desc:
if callable(desc):
desc(value)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(
value, desc
):
raise ParseCallError(
"expecting %s" % desc.__name__, field=field
)
if issubclass(desc, Document) and not desc.objects(id=value).only(
"id"
):
if issubclass(desc, Document):
if not desc.objects(id=value).only("id"):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
elif callable(desc):
try:
desc(value)
except TypeError:
raise ParseCallError(f"expecting {desc.__name__}", field=field)
except Exception as ex:
raise ParseCallError(str(ex), field=field)
res[field] = value
return res

View File

@@ -10,7 +10,11 @@ from pathlib import Path
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
HERE = Path(__file__).parent
HERE = Path(__file__).resolve().parent
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
session.mount('http://', adapter)
def apply_mappings_to_host(host: str):
@@ -20,10 +24,6 @@ def apply_mappings_to_host(host: str):
es_server = host
url = f"{es_server}/_template/{f.stem}"
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
session.mount('http://', adapter)
session.delete(url)
r = session.post(
url,

View File

@@ -0,0 +1,27 @@
from furl import furl
from config import config
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
log = config.logger(__file__)
class MissingElasticConfiguration(Exception):
"""
Exception when cluster configuration is not found in config files
"""
pass
def init_es_data():
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
log.info(res)

View File

@@ -1,7 +1,7 @@
{
"template": "events-*",
"settings": {
"number_of_shards": 5
"number_of_shards": 1
},
"mappings": {
"_default_": {

View File

@@ -1,220 +0,0 @@
import importlib.util
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import attr
from furl import furl
from mongoengine.connection import get_db
from semantic_version import Version
import database.utils
from bll.queue import QueueBLL
from config import config
from database import Database
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
from database.model.user import User
from database.model.version import Version as DatabaseVersion
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
from service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
migration_dir = (Path(__file__) / "../../migration/mongodb").resolve()
class MissingElasticConfiguration(Exception):
"""
Exception when cluster configuration is not found in config files
"""
pass
def init_es_data():
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
log.info(res)
def _ensure_company():
company_id = config.get("apiserver.default_company")
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
return company_id
def _ensure_default_queue(company):
"""
If no queue is present for the company then
create a new one and mark it as a default
"""
queue = Queue.objects(company=company).only("id").first()
if queue:
return
QueueBLL.create(company, name="default", system_tags=["default"])
def _ensure_auth_user(user_data, company_id):
ensure_credentials = {"key", "secret"}.issubset(user_data.keys())
if ensure_credentials:
user = AuthUser.objects(
credentials__match=Credentials(
key=user_data["key"], secret=user_data["secret"]
)
).first()
if user:
return user.id
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_data.get("id", f"__{user_data['name']}__"),
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
if ensure_credentials
else None,
)
user.save()
return user.id
def _ensure_user(user: FixedUser, company_id: str):
if User.objects(id=user.user_id).first():
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
_ensure_auth_user(user_data=data, company_id=company_id)
given_name, _, family_name = user.name.partition(" ")
User(
id=user.user_id,
company=company_id,
name=user.name,
given_name=given_name,
family_name=family_name,
).save()
def _apply_migrations():
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
raise ValueError(f"Failed parsing migration version from file: {ex}")
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
migration_log = log.getChild("mongodb_migration")
for script_version in sorted(new_scripts.keys()):
script = new_scripts[script_version]
spec = importlib.util.spec_from_file_location(script.stem, str(script))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for alias, func_name in dbs.items():
func = getattr(module, func_name, None)
if not func:
continue
try:
migration_log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
except Exception:
migration_log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError("Migration failed, aborting. Please restore backup.")
DatabaseVersion(
id=database.utils.id(),
num=script.stem,
created=datetime.utcnow(),
desc="Applied on server startup",
).save()
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))
def init_mongo_data():
try:
_apply_migrations()
_ensure_uuid()
company_id = _ensure_company()
_ensure_default_queue(company_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
for user in users:
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_auth_user(user, company_id)
if FixedUser.enabled():
log.info("Fixed users mode is enabled")
for user in FixedUser.from_config():
try:
_ensure_user(user, company_id)
except Exception as ex:
log.error(f"Failed creating fixed user {user['name']}: {ex}")
except Exception as ex:
log.exception("Failed initializing mongodb")

View File

@@ -0,0 +1,65 @@
from pathlib import Path
from config import config
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations
from .pre_populate import PrePopulate
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
log = config.logger(__package__)
def init_mongo_data():
try:
empty_dbs = _apply_migrations(log)
_ensure_uuid()
company_id = _ensure_company(log)
_ensure_default_queue(company_id)
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
if not zip_file or not Path(zip_file).is_file():
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
user_id = _ensure_backend_user(
"__allegroai__", company_id, "Allegro.ai"
)
PrePopulate.import_from_zip(zip_file, user_id=user_id)
fixed_mode = FixedUser.enabled()
for user, credentials in config.get("secure.credentials", {}).items():
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"key": credentials.user_key,
"secret": credentials.user_secret,
}
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
if fixed_mode:
log.info("Fixed users mode is enabled")
FixedUser.validate()
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, company_id, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
except Exception as ex:
log.exception("Failed initializing mongodb")

View File

@@ -0,0 +1,86 @@
import importlib.util
from datetime import datetime
from logging import Logger
from pathlib import Path
from mongoengine.connection import get_db
from semantic_version import Version
import database.utils
from database import Database
from database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
def _apply_migrations(log: Logger) -> bool:
"""
Apply migrations as found in the migration dir.
Returns a boolean indicating whether the database was empty prior to migration.
"""
log = log.getChild(Path(__file__).stem)
log.info(f"Started mongodb migrations")
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
empty_dbs = not any(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
raise ValueError(f"Failed parsing migration version from file: {ex}")
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
for script_version in sorted(new_scripts):
script = new_scripts[script_version]
if empty_dbs:
log.info(f"Skipping migration {script.name} (empty databases)")
else:
spec = importlib.util.spec_from_file_location(script.stem, str(script))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for alias, func_name in dbs.items():
func = getattr(module, func_name, None)
if not func:
continue
try:
log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
except Exception:
log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError(
"Migration failed, aborting. Please restore backup."
)
DatabaseVersion(
id=database.utils.id(),
num=script.stem,
created=datetime.utcnow(),
desc="Applied on server startup",
).save()
log.info("Finished mongodb migrations")
return empty_dbs

View File

@@ -0,0 +1,153 @@
import importlib
from collections import defaultdict
from datetime import datetime
from os.path import splitext
from typing import List, Optional, Any, Type, Set, Dict
from zipfile import ZipFile, ZIP_BZIP2
import mongoengine
from tqdm import tqdm
class PrePopulate:
@classmethod
def export_to_zip(
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
):
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
cls._export(zfile, experiments, projects)
@classmethod
def import_from_zip(cls, filename: str, user_id: str = None):
with ZipFile(filename) as zfile:
cls._import(zfile, user_id)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[List[str]]
) -> List[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
resolved = {i.id for i in items}
missing = ids - resolved
for name_candidate in missing:
results = list(cls.objects(name=name_candidate))
if not results:
print(f"ERROR: no match for `{name_candidate}`")
exit(1)
elif len(results) > 1:
print(f"ERROR: more than one match for `{name_candidate}`")
exit(1)
items.append(results[0])
return items
@classmethod
def _resolve_entities(
cls, experiments: List[str] = None, projects: List[str] = None
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
from database.model.project import Project
from database.model.task.task import Task
entities = defaultdict(set)
if projects:
print("Reading projects...")
entities[Project].update(cls._resolve_type(Project, projects))
print("--> Reading project experiments...")
objs = Task.objects(
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
)
entities[Task].update(o for o in objs if o.id not in (experiments or []))
if experiments:
print("Reading experiments...")
entities[Task].update(cls._resolve_type(Task, experiments))
print("--> Reading experiments projects...")
objs = Project.objects(
id__in=list(set(filter(None, (p.project for p in entities[Task]))))
)
project_ids = {p.id for p in entities[Project]}
entities[Project].update(o for o in objs if o.id not in project_ids)
return entities
@classmethod
def _cleanup_task(cls, task):
from database.model.task.task import TaskStatus
task.completed = None
task.started = None
if task.execution:
task.execution.model = None
task.execution.model_desc = None
task.execution.model_labels = None
if task.output:
task.output.model = None
task.status = TaskStatus.created
task.comment = "Auto generated by Allegro.ai"
task.created = datetime.utcnow()
task.last_iteration = 0
task.last_update = task.created
task.status_changed = task.created
task.status_message = ""
task.status_reason = ""
task.user = ""
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
from database.model.task.task import Task
if entity_cls == Task:
cls._cleanup_task(entity)
@classmethod
def _export(
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
):
entities = cls._resolve_entities(experiments, projects)
for cls_, items in entities.items():
if not items:
continue
filename = f"{cls_.__module__}.{cls_.__name__}.json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with writer.open(filename, "w") as f:
f.write("[\n".encode("utf-8"))
last = len(items) - 1
for i, item in enumerate(items):
cls._cleanup_entity(cls_, item)
f.write(item.to_json().encode("utf-8"))
if i != last:
f.write(",".encode("utf-8"))
f.write("\n".encode("utf-8"))
f.write("]\n".encode("utf-8"))
@staticmethod
def _import(reader: ZipFile, user_id: str = None):
for file_info in reader.filelist:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
with reader.open(file_info) as f:
for item in tqdm(
f.readlines(),
desc=f"Writing {cls_.__name__.lower()}s into database",
unit="doc",
):
item = (
item.decode("utf-8")
.strip()
.lstrip("[")
.rstrip("]")
.rstrip(",")
.strip()
)
if not item:
continue
doc = cls_.from_json(item)
if user_id is not None and hasattr(doc, "user"):
doc.user = user_id
doc.save(force_insert=True)

View File

@@ -0,0 +1,79 @@
from datetime import datetime
from logging import Logger
import attr
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.user import User
from service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
ensure_credentials = {"key", "secret"}.issubset(user_data)
if ensure_credentials:
user = AuthUser.objects(
credentials__match=Credentials(
key=user_data["key"], secret=user_data["secret"]
)
).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
user_id = user_data.get("id", f"__{user_data['name']}__")
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])] if not revoke else []
if ensure_credentials
else None,
)
user.save()
return user.id
def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
given_name, _, family_name = user_name.partition(" ")
User(
id=user_id,
company=company_id,
name=user_name,
given_name=given_name,
family_name=family_name,
).save()
return user_id
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
if User.objects(id=user.user_id).first():
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
given_name, _, family_name = user.name.partition(" ")
User(
id=user.user_id,
company=company_id,
name=user.name,
given_name=given_name,
family_name=family_name,
).save()

View File

@@ -0,0 +1,40 @@
from logging import Logger
from uuid import uuid4
from bll.queue import QueueBLL
from config import config
from config.info import get_default_company
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings, SettingKeys
log = config.logger(__file__)
def _ensure_company(log: Logger):
company_id = get_default_company()
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
return company_id
def _ensure_default_queue(company):
"""
If no queue is present for the company then
create a new one and mark it as a default
"""
queue = Queue.objects(company=company).only("id").first()
if queue:
return
QueueBLL.create(company, name="default", system_tags=["default"])
def _ensure_uuid():
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))

View File

@@ -0,0 +1,20 @@
import json
from pymongo.database import Database, Collection
def migrate_auth(db: Database):
collection: Collection = db["user"]
if "name_1_company_1" in [doc["name"] for doc in collection.list_indexes()]:
collection.drop_index("name_1_company_1")
def migrate_backend(db: Database):
collection: Collection = db["user"]
users = collection.find(
{"preferences": {"$exists": True, "$ne": None, "$type": "object"}}
)
for doc in users:
collection.update_one(
{"_id": doc["_id"]}, {"$set": {"preferences": json.dumps(doc["preferences"])}}
)

View File

@@ -0,0 +1,46 @@
import hashlib
from pymongo.database import Database, Collection
from service_repo.auth.fixed_user import FixedUser
def _get_ids():
if not FixedUser.enabled():
return
return {
hashlib.md5(f"{user.username}:{user.password}".encode()).hexdigest(): user.user_id
for user in FixedUser.from_config()
}
def _switch_uuid(collection: Collection, uuid_field: str, uuids: dict):
docs = list(collection.find({uuid_field: {"$in": [uuids]}}))
if not docs:
return
replaced_uuids = [doc[uuid_field] for doc in docs]
for doc in docs:
doc[uuid_field] = uuids[doc[uuid_field]]
collection.insert_many(docs)
collection.delete_many({uuid_field: {"$in": replaced_uuids}})
def migrate_auth(db: Database):
uuids = _get_ids()
if not uuids:
return
collection = db["user"]
collection.drop_index("name_1_company_1")
_switch_uuid(collection=collection, uuid_field="_id", uuids=uuids)
def migrate_backend(db: Database):
uuids = _get_ids()
if not uuids:
return
for name in ("project", "task", "model"):
_switch_uuid(collection=db[name], uuid_field="user", uuids=uuids)

View File

@@ -0,0 +1,58 @@
from collections import Collection
from typing import Sequence
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
for collection_name in db.list_collection_names():
if collection_name not in names:
continue
collection: Collection = db[collection_name]
collection.drop_indexes()
def migrate_auth(db: Database):
"""
Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
_drop_all_indices_from_collections(db, ["user"])
def migrate_backend(db: Database):
"""
1. Sort tags and system tags
2. Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
fields = ("tags", "system_tags")
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
for collection_name in ("task", "model", "project", "queue"):
collection = db[collection_name]
for doc in collection.find(filter=query, projection=fields):
update = {
field: sorted(doc[field])
for field in fields
if doc.get(field)
}
if update:
collection.update_one({"_id": doc["_id"]}, {"$set": update})
_drop_all_indices_from_collections(
db,
[
"company",
"model",
"project",
"queue",
"settings",
"task",
"task__trash",
"user",
"versions",
],
)

View File

@@ -1,31 +1,30 @@
six
Flask>=0.12.2
elasticsearch>=5.0.0,<6.0.0
pyhocon>=0.3.35
requests>=2.13.0
arrow>=0.10.0
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
Flask-Cors>=3.0.5
Flask-Compress>=1.4.0
mongoengine==0.16.2
jsonmodels>=2.3
pyjwt>=1.3.0
gunicorn>=19.7.1
Jinja2==2.10
python-rapidjson>=0.6.3
jsonschema>=2.6.0
dpath>=1.4.2
funcsigs==1.0.2
luqum>=0.7.2
typing>=3.6.4
attrs>=19.1.0
nested_dict>=1.61
related>=0.7.2
validators>=0.12.4
fastjsonschema>=2.8
boltons>=19.1.0
semantic_version>=2.6.0,<3
dpath>=1.4.2,<2.0
elasticsearch>=5.0.0,<6.0.0
fastjsonschema>=2.8
Flask-Compress>=1.4.0
Flask-Cors>=3.0.5
Flask>=0.12.2
funcsigs==1.0.2
furl>=2.0.0
redis>=2.10.5
gunicorn>=19.7.1
humanfriendly==4.18
Jinja2==2.10
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.7.2
mongoengine==0.19.1
nested_dict>=1.61
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt>=1.3.0
pymongo==3.10.1
python-rapidjson>=0.6.3
redis>=2.10.5
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.0,<3
six
tqdm
validators>=0.12.4

View File

@@ -171,6 +171,30 @@
critical
]
}
event_type_enum {
type: string
enum: [
training_stats_scalar
training_stats_vector
training_debug_image
plot
log
]
}
task_metric {
type: object
required: [task, metric]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
}
}
task_log_event {
description: """A log event associated with a task."""
type: object
@@ -234,6 +258,7 @@
properties {
added { type: integer }
errors { type: integer }
errors_info { type: object }
}
}
}
@@ -319,6 +344,84 @@
}
}
}
"2.7" {
description: "Get the debug image events for the requested amount of iterations per each task's metric"
request {
type: object
required: [
metrics
]
properties {
metrics {
type: array
items { "$ref": "#/definitions/task_metric" }
description: "List metrics for which the envents will be retreived"
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
navigate_earlier {
type: boolean
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest iterations. The default is False"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
metrics {
type: array
items: { type: object }
description: "Debug image events grouped by task metrics and iterations"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_task_metrics{
"2.7": {
description: "For each task, get a list of metrics for which the requested event type was reported"
request {
type: object
required: [
tasks
]
properties {
tasks {
type: array
items { type: string }
description: "Task IDs"
}
event_type {
"description": "Event type"
"$ref": "#/definitions/event_type_enum"
}
}
}
response {
type: object
properties {
metrics {
type: array
items { type: object }
description: "List of task with their metrics"
}
}
}
}
}
get_task_log {
"1.5" {
@@ -427,6 +530,59 @@
}
}
}
// "2.7" {
// description: "Get 'log' events for this task"
// request {
// type: object
// required: [
// task
// ]
// properties {
// task {
// type: string
// description: "Task ID"
// }
// batch_size {
// type: integer
// description: "The amount of log events to return"
// }
// navigate_earlier {
// type: boolean
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
// }
// refresh {
// type: boolean
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
// }
// scroll_id {
// type: string
// description: "Scroll ID of previous call (used for getting more results)"
// }
// }
// }
// response {
// type: object
// properties {
// events {
// type: array
// items { type: object }
// description: "Log items list"
// }
// returned {
// type: integer
// description: "Number of log events returned"
// }
// total {
// type: number
// description: "Total number of log events available for this query"
// }
// scroll_id {
// type: string
// description: "Scroll ID for getting more results"
// }
// }
// }
// }
}
get_task_events {
"2.1" {
@@ -455,7 +611,7 @@
}
batch_size {
type: integer
description: "Number of events to return each time"
description: "Number of events to return each time (default 500)"
}
event_type {
type: string

View File

@@ -159,6 +159,11 @@
description: "Get only models whose name matches this pattern (python regular expression syntax)"
type: string
}
user {
description: "List of user IDs used to filter results by the model's creating user"
type: array
items { type: string }
}
ready {
description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects."
type: boolean
@@ -261,7 +266,7 @@
type: string
}
uri {
description: "URI for the model"
description: "URI for the model. Exactly one of uri or override_model_id is a required."
type: string
}
name {
@@ -283,7 +288,7 @@
items {type: string}
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task."
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
type: string
}
iteration {
@@ -324,7 +329,6 @@
required: [
uri
name
labels
]
properties {
uri {

View File

@@ -0,0 +1,43 @@
_description: "This service provides organization level operations"
get_tags {
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
}

View File

@@ -69,6 +69,17 @@ info {
}
}
}
"2.8": ${info."2.1"} {
response {
type: object
properties {
uid {
description: "Server UID"
type: string
}
}
}
}
}
endpoints {
"2.1" {
@@ -86,6 +97,7 @@ endpoints {
}
}
report_stats_option {
allow_roles = [ "*" ]
"2.4" {
description: "Get or set the report statistics option per-company"
request {
@@ -117,6 +129,10 @@ report_stats_option {
description: "If enabled, returns Id of the user who enabled the option"
type: string
}
current_version {
description: "Returns the current server version"
type: string
}
}
}
}

View File

@@ -254,6 +254,15 @@ _definitions {
enum: [
training
testing
inference
data_processing
application
monitor
controller
optimizer
service
qc
custom
]
}
last_metrics_event {
@@ -475,7 +484,11 @@ get_all {
minimum: 1
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
description: """List of field names to order by. When search_text is used,
'@text_score' can be used as a field representing the text score of returned documents.
Use '-' prefix to specify descending order. Optional, recommended when using page.
If the first order field is a hyper parameter or metric then string values are ordered
according to numeric ordering rules where applicable"""
type: array
items { type: string }
}
@@ -550,6 +563,89 @@ get_all {
}
}
}
get_types {
"2.8" {
description: "Get the list of task types used in the specified projects"
request {
type: object
properties {
projects {
description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed"
type: array
items: {type: string}
}
}
}
response {
type: object
properties {
types {
description: "Unique list of the task types used in the requested projects"
type: array
items: {type: string}
}
}
}
}
}
clone {
"2.5" {
description: "Clone an existing task"
request {
type: object
required: [ task ]
properties {
task {
description: "ID of the task"
type: string
}
new_task_name {
description: "The name of the cloned task. If not provided then taken from the original task"
type: string
}
new_task_comment {
description: "The comment of the cloned task. If not provided then taken from the original task"
type: string
}
new_task_tags {
description: "The user-defined tags of the cloned task. If not provided then taken from the original task"
type: array
items { type: string }
}
new_task_system_tags {
description: "The system tags of the cloned task. If not provided then empty"
type: array
items { type: string }
}
new_task_parent {
description: "The parent of the cloned task. If not provided then taken from the original task"
type: string
}
new_task_project {
description: "The project of the cloned task. If not provided then taken from the original task"
type: string
}
execution_overrides {
description: "The execution params for the cloned task. The params not specified are taken from the original task"
"$ref": "#/definitions/execution"
}
validate_references {
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
type: boolean
}
}
}
response {
type: object
properties {
id {
description: "ID of the new task"
type: string
}
}
}
}
}
create {
"2.1" {
description: "Create a new task"
@@ -847,6 +943,11 @@ reset {
properties.force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'"
}
properties.clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
} ${_references.status_change_request}
response {
type: object
@@ -1304,4 +1405,40 @@ ping {
additionalProperties: false
}
}
}
add_or_update_artifacts {
"2.6" {
description: """ Update an existing artifact (search by key/mode) or add a new one """
request {
type: object
required: [ task, artifacts ]
properties {
task {
description: "Task ID"
type: string
}
artifacts {
description: "Artifacts to add or update"
type: array
items { "$ref": "#/definitions/artifact" }
}
}
}
response {
type: object
properties {
added {
description: "Keys of artifacts added"
type: array
items { type: string }
}
updated {
description: "Keys of artifacts updated"
type: array
items { type: string }
}
}
}
}
}

View File

@@ -145,6 +145,19 @@ get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
}
"2.8": ${get_all."2.1"} {
request {
type: object
properties {
active_in_projects {
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
type: array
items { type: string }
}
}
}
}
}
get_all {

View File

@@ -1,3 +1,4 @@
import atexit
from argparse import ArgumentParser
from flask import Flask, request, Response
@@ -9,13 +10,15 @@ import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from init_data import init_es_data, init_mongo_data
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from timing_context import TimingContext
from updates import check_updates_thread
from utilities import json
from utilities.threads_manager import ThreadsManager
app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors"))
@@ -41,6 +44,13 @@ check_updates_thread.start()
StatisticsReporter.start()
def graceful_shutdown():
ThreadsManager.terminating = True
atexit.register(graceful_shutdown)
@app.before_first_request
def before_app_first_request():
pass

View File

@@ -21,6 +21,8 @@ JSON_CONTENT_TYPE = "application/json"
class DataContainer(object):
""" Data container that supports raw data (dict or a list of batched dicts) and a data model """
null_schema_validator: SchemaValidator = SchemaValidator(None)
def __init__(self, data=None, batched_data=None):
if data and batched_data:
raise ValueError("data and batched data are not supported simultaneously")
@@ -28,7 +30,7 @@ class DataContainer(object):
self._data = None
self._data_model = None
self._data_model_cls = None
self._schema_validator: SchemaValidator = SchemaValidator(None)
self._schema_validator: SchemaValidator = self.null_schema_validator
# use setter to properly initialize data
self.data = data
self.batched_data = batched_data

View File

@@ -5,27 +5,45 @@ from typing import Sequence, TypeVar
import attr
from config import config
from config.info import get_default_company
T = TypeVar("T", bound="FixedUser")
class FixedUsersError(Exception):
pass
@attr.s(auto_attribs=True)
class FixedUser:
username: str
password: str
name: str
company: str = get_default_company()
def __attrs_post_init__(self):
self.user_id = hashlib.md5(f"{self.username}:{self.password}".encode()).hexdigest()
self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest()
@classmethod
def enabled(cls):
return config.get("apiserver.auth.fixed_users.enabled", False)
@classmethod
def validate(cls):
if not cls.enabled():
return
users = cls.from_config()
if len({user.username for user in users}) < len(users):
raise FixedUsersError(
"Duplicate user names found in fixed users configuration"
)
@classmethod
@lru_cache()
def from_config(cls) -> Sequence[T]:
return [cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])]
return [
cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
]
@classmethod
@lru_cache()

View File

@@ -1,5 +1,6 @@
from enum import Enum
from typing import Callable, Sequence, Text
from boltons.iterutils import remap
from jsonmodels import models
from jsonmodels.errors import FieldNotSupported
@@ -87,7 +88,14 @@ class Endpoint(object):
Provided data_model schema if available
"""
try:
return data_model.to_json_schema()
res = data_model.to_json_schema()
def visit(path, key, value):
if isinstance(value, Enum):
value = str(value)
return key, value
return remap(res, visit=visit)
except (FieldNotSupported, TypeError):
return str(data_model.__name__)

View File

@@ -9,6 +9,7 @@ import jsonmodels.models
import timing_context
from apierrors import APIError
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
from api_version import __version__ as _api_version_
from config import config
from service_repo.base import PartialVersion
from .apicall import APICall
@@ -34,7 +35,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.4")
_max_version = PartialVersion(".".join(_api_version_.split(".")[:2]))
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -166,7 +167,7 @@ class ServiceRepo(object):
return
assert isinstance(endpoint, Endpoint)
call.actual_endpoint_version: PartialVersion = endpoint.min_version
call.actual_endpoint_version = endpoint.min_version
call.requires_authorization = endpoint.authorize
return endpoint

View File

@@ -52,7 +52,7 @@ def validate_all(call: APICall, endpoint: Endpoint):
def validate_role(endpoint, call):
try:
if not endpoint.allows(call.identity.role):
if endpoint.authorize and not endpoint.allows(call.identity.role):
raise errors.forbidden.RoleNotAllowed(role=call.identity.role, allowed=endpoint.allow_roles)
except MissingIdentity:
pass

View File

@@ -2,12 +2,16 @@ import itertools
from collections import defaultdict
from operator import itemgetter
import six
from apierrors import errors
from apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@@ -23,10 +27,10 @@ event_bll = EventBLL()
def add(call: APICall, company_id, req_model):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, batch_errors = event_bll.add_events(
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=len(batch_errors))
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = 1
@@ -36,13 +40,13 @@ def add_batch(call: APICall, company_id, req_model):
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=len(batch_errors))
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
def get_task_log_v1_5(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
@@ -90,6 +94,29 @@ def get_task_log_v1_7(call, company_id, req_model):
)
# uncomment this once the front end is ready
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
# def get_task_log(call, company_id, req_model: LogEventsRequest):
# task_id = req_model.task
# task_bll.assert_exists(company_id, task_id, allow_public=True)
#
# res = event_bll.log_events_iterator.get_task_events(
# company_id=company_id,
# task_id=task_id,
# batch_size=req_model.batch_size,
# navigate_earlier=req_model.navigate_earlier,
# refresh=req_model.refresh,
# state_id=req_model.scroll_id,
# )
#
# call.result.data = dict(
# events=res.events,
# returned=len(res.events),
# total=res.total_events,
# scroll_id=res.next_scroll_id,
# )
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
task_id = call.data["task"]
@@ -211,7 +238,7 @@ def vector_metrics_iter_histogram(call, company_id, req_model):
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, _):
task_id = call.data["task"]
batch_size = call.data.get("batch_size")
batch_size = call.data.get("batch_size", 500)
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
@@ -299,7 +326,7 @@ def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
if isinstance(task_ids, six.string_types):
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
@@ -481,7 +508,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images(call, company_id, req_model):
def get_debug_images_v1_8(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@@ -507,6 +534,53 @@ def get_debug_images(call, company_id, req_model):
)
@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
tasks = set(m.task for m in req_model.metrics)
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
result = event_bll.debug_images_iterator.get_task_events(
company_id=company_id,
metrics=[(m.task, m.metric) for m in req_model.metrics],
iter_count=req_model.iters,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
)
call.result.data_model = DebugImageResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
metric=metric,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, metric, iterations) in result.metric_events
],
)
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
task_bll.assert_exists(
call.identity.company, task_ids=req_model.tasks, allow_public=True
)
res = event_bll.metrics.get_tasks_metrics(
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]

View File

@@ -12,6 +12,7 @@ from apimodels.models import (
PublishModelResponse,
ModelTaskPublishResponse,
)
from bll.organization import OrgBLL
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@@ -29,51 +30,34 @@ from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"project",
"task",
"parent",
),
)
org_bll = OrgBLL()
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call):
assert isinstance(call, APICall)
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
models = Model.get_many(
company=call.identity.company,
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call):
assert isinstance(call, APICall)
def get_by_task_id(call: APICall, company_id, _):
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=call.identity.company)
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["output"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@@ -84,13 +68,11 @@ def get_by_task_id(call):
model_id = task.output.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
@@ -98,31 +80,27 @@ def get_by_task_id(call):
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
company=call.identity.company,
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@@ -146,13 +124,18 @@ create_fields = {
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
return fields
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call, company_id, _):
assert isinstance(call, APICall)
def update_for_task(call: APICall, company_id, _):
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
@@ -195,7 +178,9 @@ def update_for_task(call, company_id, _):
if task.output and task.output.model:
# model exists, update
res = _update_model(call, model_id=task.output.model).to_struct()
res = _update_model(
call, company_id, model_id=task.output.model
).to_struct()
res.update({"id": task.output.model, "created": False})
call.result.data = res
return
@@ -218,6 +203,7 @@ def update_for_task(call, company_id, _):
**fields,
)
model.save()
_update_org_tags(company_id, fields)
TaskBLL.update_statistics(
task_id=task_id,
@@ -234,48 +220,46 @@ def update_for_task(call, company_id, _):
request_data_model=CreateModelRequest,
response_data_model=CreateModelResponse,
)
def create(call, company, req_model):
assert isinstance(call, APICall)
assert isinstance(req_model, CreateModelRequest)
identity = call.identity
def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company = ""
company_id = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company, project=project)
validate_id(Project, company=company_id, project=project)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(call, req_data)
validate_task(company_id, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
# create and save model
model = Model(
id=database.utils.id(),
user=identity.user,
company=company,
user=call.identity.user,
company=company_id,
created=datetime.utcnow(),
**fields,
)
model.save()
_update_org_tags(company_id, fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, fields):
def prepare_update_fields(call, company_id, fields: dict):
fields = fields.copy()
if "uri" in fields:
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(call, fields)
validate_task(company_id, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -290,33 +274,36 @@ def prepare_update_fields(call, fields):
invalid_keys = find_other_types(labels.keys(), str)
if invalid_keys:
raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys)
raise errors.bad_request.ValidationError(
"labels keys must be strings", keys=invalid_keys
)
invalid_values = find_other_types(labels.values(), int)
if invalid_values:
raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values)
raise errors.bad_request.ValidationError(
"labels values must be integers", values=invalid_values
)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
return fields
def validate_task(call, fields):
Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall):
identity = call.identity
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, fields)
fields = prepare_update_fields(call, company_id, fields)
for key in fields:
field = getattr(model, key, None)
@@ -331,47 +318,44 @@ def edit(call: APICall):
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get('task')
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
if fields:
updated = model.update(upsert=False, **fields)
if updated:
_update_org_tags(company_id, fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, model_id=None):
identity = call.identity
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
# get model by id
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
data = prepare_update_fields(call, call.data)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(
call.identity.company, model.id, data
)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
_update_org_tags(company_id, updated_fields)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -379,8 +363,8 @@ def _update_model(call: APICall, model_id=None):
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call):
call.result.data_model = _update_model(call)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint(
@@ -388,31 +372,29 @@ def update(call):
request_data_model=PublishModelRequest,
response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
updated, published_task_data = TaskBLL.model_set_ready(
model_id=req_model.model,
company_id=company,
company_id=company_id,
publish_task=req_model.publish_task,
force_publish_task=req_model.force_publish_task
force_publish_task=req_model.force_publish_task,
)
call.result.data_model = PublishModelResponse(
updated=updated,
published_task=ModelTaskPublishResponse(
**published_task_data
) if published_task_data else None
published_task=ModelTaskPublishResponse(**published_task_data)
if published_task_data
else None,
)
@endpoint("models.delete", required_fields=["model"])
def update(call):
assert isinstance(call, APICall)
identity = call.identity
def update(call: APICall, company_id, _):
model_id = call.data["model"]
force = call.data.get("force", False)
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
@@ -445,4 +427,6 @@ def update(call):
)
del_count = Model.objects(**query).delete()
if del_count:
org_bll.update_org_tags(company_id, reset=True)
call.result.data = dict(deleted=del_count > 0)

View File

@@ -0,0 +1,13 @@
from apimodels.organization import TagsRequest
from bll.organization import OrgBLL
from service_repo import endpoint, APICall
org_bll = OrgBLL()
@endpoint("organization.get_tags", request_data_model=TagsRequest)
def get_tags(call: APICall, company, request: TagsRequest):
filter_ = request.filter.system_tags if request.filter else None
call.result.data = org_bll.get_tags(
company, include_system=request.include_system, filter_=filter_
)

View File

@@ -33,8 +33,7 @@ create_fields = {
}
get_all_query_options = Project.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
)
@@ -58,10 +57,10 @@ def get_by_id(call):
call.result.data = {"project": project_dict}
def make_projects_get_all_pipelines(project_ids, specific_state=None):
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
archived = EntityVisibility.archived.value
def ensure_system_tags():
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
@@ -73,14 +72,20 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
"then": [],
"else": "$system_tags",
}
}
},
"status": {"$ifNull": ["$status", "unknown"]},
}
}
status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
ensure_system_tags(),
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
"$group": {
"_id": {
@@ -149,11 +154,12 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
# only count run time for these types of tasks
{
"$match": {
"type": {"$in": ["training", "testing", "annotation"]},
"type": {"$in": ["training", "testing"]},
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_system_tags(),
ensure_valid_fields(),
{
# for each project
"$group": group_step
@@ -192,7 +198,7 @@ def get_all_ex(call: APICall):
ids = [project["id"] for project in projects]
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
ids, specific_state=specific_state
call.identity.company, ids, specific_state=specific_state
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@@ -202,7 +208,7 @@ def get_all_ex(call: APICall):
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(*status_count_pipeline):
for result in Task.aggregate(status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
@@ -216,7 +222,7 @@ def get_all_ex(call: APICall):
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(*runtime_pipeline)
for result in Task.aggregate(runtime_pipeline)
}
def safe_get(obj, path, default=None):
@@ -268,7 +274,7 @@ def create(call):
with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -305,7 +311,7 @@ def update(call: APICall):
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)

View File

@@ -58,7 +58,9 @@ def get_all(call: APICall):
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(call, request.tags, request.system_tags)
tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True
)
queue = queue_bll.create(
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
)
@@ -73,7 +75,7 @@ def create(call: APICall, company_id, request: CreateRequest):
)
def update(call: APICall, company_id, req_model: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data)
conform_tag_fields(call, data, validate=True)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
@@ -212,7 +214,9 @@ def get_queue_metrics(
dates=data["date"],
avg_waiting_times=data["avg_waiting_time"],
queue_lengths=data["queue_length"],
) if data else QueueMetrics(queue=queue)
)
if data
else QueueMetrics(queue=queue)
for queue, data in queue_dicts.items()
]
)

View File

@@ -10,8 +10,8 @@ from config.info import get_version, get_build_number, get_commit_number
from database.errors import translate_errors_context
from database.model import Company
from database.model.company import ReportStatsOption
from database.model.settings import Settings, SettingKeys
from service_repo import ServiceRepo, APICall, endpoint
from version import __version__ as current_version
@endpoint("server.get_stats")
@@ -61,6 +61,12 @@ def info(call: APICall):
}
@endpoint("server.info", min_version="2.8")
def info_2_8(call: APICall):
info(call)
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
@endpoint(
"server.report_stats_option",
request_data_model=ReportStatsOptionRequest,
@@ -79,7 +85,7 @@ def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest)
stats_option = ReportStatsOption(
enabled=enabled,
enabled_time=datetime.utcnow(),
enabled_version=current_version,
enabled_version=get_version(),
enabled_user=call.identity.user,
)
updated = query.update(defaults__stats_option=stats_option)
@@ -87,7 +93,8 @@ def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest)
raise errors.server_error.InternalError(
f"Failed setting report_stats to {enabled}"
)
result = ReportStatsOptionResponse(**stats_option.to_mongo())
data = stats_option.to_mongo()
data["current_version"] = get_version()
result = ReportStatsOptionResponse(**data)
call.result.data_model = result

View File

@@ -1,18 +1,17 @@
from copy import deepcopy
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Callable, Type, TypeVar
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
import attr
import dpath
import mongoengine
import six
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
from apierrors import errors, APIError
from apimodels.base import UpdateResponse
from apimodels.base import UpdateResponse, IdResponse
from apimodels.tasks import (
StartedResponse,
ResetResponse,
@@ -27,10 +26,23 @@ from apimodels.tasks import (
EnqueueRequest,
EnqueueResponse,
DequeueResponse,
CloneRequest,
AddOrUpdateArtifactsRequest,
AddOrUpdateArtifactsResponse,
GetTypesRequest,
ResetRequest,
)
from bll.event import EventBLL
from bll.organization import OrgBLL
from bll.queue import QueueBLL
from bll.task import TaskBLL, ChangeStatusRequest, update_project_time, split_by
from bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
split_by,
ParameterKeyEscaper,
)
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from bll.util import SetFieldsResolver
from database.errors import translate_errors_context
from database.model.model import Model
@@ -50,19 +62,13 @@ from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()
TaskBLL.start_non_responsive_tasks_watchdog()
NonResponsiveTasksWatchdog.start()
def set_task_status_from_call(
@@ -94,41 +100,79 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
req_model.task, company_id=company_id, allow_public=True
)
task_dict = task.to_proper_dict()
conform_output_tags(call, task_dict)
unprepare_from_saved(call, task_dict)
call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall):
default_prefix = "execution.parameters."
def escape_paths(paths, prefix=default_prefix):
escaped_paths = []
for path in paths:
if path == prefix:
raise errors.bad_request.ValidationError(
"invalid task field", path=path
)
escaped_paths.append(
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
if path.startswith(prefix)
else path
)
return escaped_paths
projection = Task.get_projection(call.data)
if projection:
Task.set_projection(call.data, escape_paths(projection))
ordering = Task.get_ordering(call.data)
if ordering:
ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix))
Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join(
company=call.identity.company,
company=company_id,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
conform_output_tags(call, tasks)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=call.identity.company,
company=company_id,
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
conform_output_tags(call, tasks)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest):
call.result.data = {
"types": list(task_bll.get_types(company_id, project_ids=request.projects))
}
@endpoint(
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
@@ -220,6 +264,45 @@ create_fields = {
}
def prepare_for_save(call: APICall, fields: dict):
conform_tag_fields(call, fields, validate=True)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
try:
path = f"script/{field}"
value = dpath.get(fields, path)
if isinstance(value, str):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {ParameterKeyEscaper.escape(k): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
if isinstance(tasks_data, dict):
tasks_data = [tasks_data]
conform_output_tags(call, tasks_data)
for task_data in tasks_data:
parameters = safe_get(task_data, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {
ParameterKeyEscaper.unescape(k): v for k, v in parameters.items()
}
dpath.set(task_data, "execution/parameters", parameters)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None
):
@@ -239,28 +322,10 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
conform_tag_fields(call, fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
try:
path = "script/%s" % field
value = dpath.get(fields, path)
if isinstance(value, six.string_types):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
parameters = {k.strip(): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
return prepare_for_save(call, fields)
def _validate_and_get_task_from_call(call: APICall, **kwargs):
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_call"):
@@ -270,7 +335,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
with TimingContext("code", "validate"):
task_bll.validate(task)
return task
return task, fields
@endpoint("tasks.validate", request_data_model=CreateRequest)
@@ -278,15 +343,44 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
_validate_and_get_task_from_call(call)
@endpoint("tasks.create", request_data_model=CreateRequest)
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
)
@endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
task = _validate_and_get_task_from_call(call)
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(), TimingContext("mongo", "save_task"):
task.save()
_update_org_tags(company_id, fields)
update_project_time(task.project)
call.result.data = {"id": task.id}
call.result.data_model = IdResponse(id=task.id)
@endpoint(
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
)
def clone_task(call: APICall, company_id, request: CloneRequest):
task = task_bll.clone_task(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
name=request.new_task_name,
comment=request.new_task_comment,
parent=request.new_task_parent,
project=request.new_task_project,
tags=request.new_task_tags,
system_tags=request.new_task_system_tags,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
)
call.result.data_model = IdResponse(id=task.id)
def prepare_update_fields(call: APICall, task, call_data):
@@ -296,8 +390,7 @@ def prepare_update_fields(call: APICall, task, call_data):
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
conform_tag_fields(call, fields)
return fields, valid_fields
return prepare_for_save(call, fields), valid_fields
@endpoint(
@@ -322,9 +415,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
partial_update_dict=partial_update_dict,
injected_update=dict(last_update=datetime.utcnow()),
)
update_project_time(updated_fields.get("project"))
conform_output_tags(call, updated_fields)
if updated_count:
_update_org_tags(company_id, updated_fields)
update_project_time(updated_fields.get("project"))
unprepare_from_saved(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -355,9 +449,7 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
@endpoint("tasks.update_batch")
def update_batch(call: APICall):
identity = call.identity
def update_batch(call: APICall, company_id, _):
items = call.batched_data
if items is None:
raise errors.bad_request.BatchContainsNoItems()
@@ -367,7 +459,7 @@ def update_batch(call: APICall):
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=identity.company, query=Q(id__in=list(items))
company=company_id, query=Q(id__in=list(items))
)
}
@@ -385,7 +477,7 @@ def update_batch(call: APICall):
continue
partial_update_dict.update(last_update=now)
update_op = UpdateOne(
{"_id": id, "company": identity.company}, {"$set": partial_update_dict}
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
)
bulk_ops.append(update_op)
@@ -393,7 +485,8 @@ def update_batch(call: APICall):
if bulk_ops:
res = Task._get_collection().bulk_write(bulk_ops)
updated = res.modified_count
if updated:
org_bll.update_org_tags(company_id, reset=True)
call.result.data = {"updated": updated}
@@ -448,8 +541,10 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fields.update(last_update=now)
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project"))
conform_output_tags(call, fields)
if updated:
_update_org_tags(company_id, fixed_fields)
update_project_time(fields.get("project"))
unprepare_from_saved(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
@@ -575,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
@endpoint(
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, req_model: UpdateRequest):
def reset(call: APICall, company_id, request: ResetRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, requires_write_access=True
request.task, company_id=company_id, requires_write_access=True
)
force = req_model.force
force = request.force
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@@ -598,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
else:
if dequeued:
api_results.update(dequeued=dequeued)
updates.update(unset__execution__queue=1)
cleaned_up = cleanup_task(task, force)
api_results.update(attr.asdict(cleaned_up))
@@ -606,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
unset__output__result=1,
unset__output__model=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
)
if request.clear_all:
updates.update(
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
updates.update(
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
res = ResetResponse(
**ChangeStatusRequest(
task=task,
@@ -674,8 +782,7 @@ class CleanupResult(object):
deleted_models = attr.ib(type=int)
def cleanup_task(task, force=False):
# type: (Task, bool) -> CleanupResult
def cleanup_task(task: Task, force: bool = False):
"""
Validate task deletion and delete/modify all its output.
:param task: task object
@@ -702,7 +809,7 @@ def cleanup_task(task, force=False):
else:
updated_models = 0
event_bll.delete_task_events(task.company, task.id)
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
return CleanupResult(
deleted_models=deleted_models,
@@ -733,6 +840,15 @@ def get_outputs_for_deletion(task, force=False):
else:
models.draft.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = [m.id for m in models.draft]
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
)
busy_models = [t.execution.model for t in dependent_tasks]
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
with TimingContext("mongo", "get_task_children"):
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
published_tasks = [
@@ -793,7 +909,7 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
task.switch_collection(collection_name)
task.delete()
org_bll.update_org_tags(company_id, reset=True)
call.result.data = dict(deleted=True, **attr.asdict(result))
@@ -837,3 +953,18 @@ def ping(_, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
)
@endpoint(
"tasks.add_or_update_artifacts",
min_version="2.6",
request_data_model=AddOrUpdateArtifactsRequest,
response_data_model=AddOrUpdateArtifactsResponse,
)
def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
added, updated = TaskBLL.add_or_update_artifacts(
task_id=request.task, company_id=company_id, artifacts=request.artifacts
)
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)

View File

@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, Tuple
from typing import Tuple
import dpath
from boltons.iterutils import remap
@@ -7,10 +7,8 @@ from mongoengine import Q
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.users import (
CreateRequest,
SetPreferencesRequest,
)
from apimodels.users import CreateRequest, SetPreferencesRequest
from bll.project import ProjectBLL
from bll.user import UserBLL
from config import config
from database.errors import translate_errors_context
@@ -19,12 +17,13 @@ from database.model.company import Company
from database.model.user import User
from database.utils import parse_from_call
from service_repo import APICall, endpoint
from utilities.json import loads, dumps
log = config.logger(__file__)
get_all_query_options = User.QueryParameterOptions(list_fields=("id",))
project_bll = ProjectBLL()
def get_user(call, user_id, only=None):
def get_user(call, company_id, user_id, only=None):
"""
Get user object by the user's ID
:param call: API call
@@ -36,7 +35,7 @@ def get_user(call, user_id, only=None):
# allow system users to get info for all users
query = dict(id=user_id)
else:
query = dict(id=user_id, company=call.identity.company)
query = dict(id=user_id, company=company_id)
with translate_errors_context("retrieving user"):
user = User.objects(**query)
@@ -50,47 +49,53 @@ def get_user(call, user_id, only=None):
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call):
assert isinstance(call, APICall)
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
call.result.data = {"user": get_user(call, user_id)}
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
def get_all_ex(call):
assert isinstance(call, APICall)
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
)
res = User.get_many_with_join(company=company_id, query_dict=call.data)
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
active_in_projects = call.data.get("active_in_projects", None)
if active_in_projects is not None:
active_users = project_bll.get_active_users(
company_id, active_in_projects, call.data.get("id")
)
active_users.discard(None)
if not active_users:
call.result.data = {"users": []}
return
data = data.copy()
data["id"] = list(active_users)
res = User.get_many_with_join(company=company_id, query_dict=data)
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
def get_all(call):
assert isinstance(call, APICall)
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
company=call.identity.company,
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
company=company_id, parameters=call.data, query_dict=call.data
)
call.result.data = {"users": res}
@endpoint("users.get_current_user")
def get_current_user(call):
assert isinstance(call, APICall)
def get_current_user(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
projection = (
{"company.name"}
.union(User.get_fields())
@@ -98,7 +103,7 @@ def get_current_user(call):
)
res = User.get_many_with_join(
query=Q(id=call.identity.user),
company=call.identity.company,
company=company_id,
override_projection=projection,
)
@@ -128,13 +133,11 @@ def create(call: APICall):
@endpoint("users.delete", required_fields=["user"])
def delete(call):
assert isinstance(call, APICall)
def delete(call: APICall):
UserBLL.delete(call.data["user"])
def update_user(user_id, company_id, data):
# type: (str, str, Dict) -> Tuple[int, Dict]
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
"""
Update user.
:param user_id: user ID to update
@@ -152,30 +155,29 @@ def update_user(user_id, company_id, data):
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
assert isinstance(call, APICall)
user_id = call.data["user"]
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
def get_user_preferences(call):
def get_user_preferences(call: APICall, company_id):
user_id = call.identity.user
return get_user(call, user_id, ["preferences"]).get("preferences", {})
preferences = get_user(call, company_id, user_id, only=["preferences"]).get(
"preferences"
)
if preferences and isinstance(preferences, str):
preferences = loads(preferences)
return preferences or {}
@endpoint("users.get_preferences")
def get_preferences(call):
assert isinstance(call, APICall)
return {"preferences": get_user_preferences(call)}
def get_preferences(call: APICall, company_id, _):
return {"preferences": get_user_preferences(call, company_id)}
@endpoint(
"users.set_preferences", request_data_model=SetPreferencesRequest
)
def set_preferences(call, company_id, req_model):
# type: (APICall, str, SetPreferencesRequest) -> Dict
assert isinstance(call, APICall)
changes = req_model.preferences
@endpoint("users.set_preferences", request_data_model=SetPreferencesRequest)
def set_preferences(call: APICall, company_id, request: SetPreferencesRequest):
changes = request.preferences
def invalid_key(_, key, __):
if not isinstance(key, str):
@@ -188,7 +190,7 @@ def set_preferences(call, company_id, req_model):
remap(changes, visit=invalid_key)
base_preferences = get_user_preferences(call)
base_preferences = get_user_preferences(call, company_id)
new_preferences = deepcopy(base_preferences)
for key, value in changes.items():
try:
@@ -205,9 +207,11 @@ def set_preferences(call, company_id, req_model):
updated, fields = 0, {}
else:
with translate_errors_context("updating user preferences"):
fields = dict(preferences=new_preferences)
updated = User.objects(id=call.identity.user, company=company_id).update(
upsert=False, **fields
upsert=False, preferences=dumps(new_preferences)
)
return {"updated": updated, "fields": fields if updated else {}}
return {
"updated": updated,
"fields": {"preferences": new_preferences} if updated else {},
}

View File

@@ -1,5 +1,7 @@
from typing import Union, Sequence, Tuple
from apierrors import errors
from database.model.base import GetMixin
from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion
@@ -19,13 +21,13 @@ def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
def conform_tag_fields(call: APICall, document: dict):
def conform_tag_fields(call: APICall, document: dict, validate=False):
"""
Upgrade old client tags in place
"""
if "tags" in document:
tags, system_tags = conform_tags(
call, document["tags"], document.get("system_tags")
call, document["tags"], document.get("system_tags"), validate
)
if tags != document.get("tags"):
document["tags"] = tags
@@ -34,16 +36,18 @@ def conform_tag_fields(call: APICall, document: dict):
def conform_tags(
call: APICall, tags: Sequence, system_tags: Sequence
call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if validate:
validate_tags(tags, system_tags)
if call.requested_endpoint_version < PartialVersion("2.3"):
tags, system_tags = _upgrade_tags(call, tags, system_tags)
return _get_unique_values(tags), _get_unique_values(system_tags)
return tags, system_tags
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
@@ -55,9 +59,12 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
return tags, system_tags
def _get_unique_values(values: Sequence) -> Sequence:
"""Get unique values from the given sequence"""
if not values:
return values
return list(set(values))
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)):
unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
]
if unsupported:
raise errors.bad_request.FieldsValueError(
"unsupported tag prefix", values=unsupported
)

View File

@@ -54,6 +54,10 @@ class TestService(TestCase, TestServiceInterface):
)
return object_id
@staticmethod
def update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str:
return self._create_temp_helper(
service=service,

View File

@@ -1,14 +1,14 @@
import operator
from time import sleep
from typing import Sequence
from typing import Sequence, Mapping
from tests.automated import TestService
class TestEntityOrdering(TestService):
test_comment = "Entity ordering test"
only_fields = ["id", "started", "comment"]
only_fields = ["id", "started", "comment", "execution.parameters"]
def setUp(self, **kwargs):
super().setUp(**kwargs)
@@ -27,6 +27,9 @@ class TestEntityOrdering(TestService):
# sort by the same field that we use for the search
self._assertGetTasksWithOrdering(order_by="comment")
# sort by parameter which type is not part of db schema
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
def test_order_with_paging(self):
order_field = "started"
# all results in one page
@@ -52,23 +55,33 @@ class TestEntityOrdering(TestService):
def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
return self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
order_by=[order_by] if isinstance(order_by, str) else order_by,
comment=self.test_comment,
page=page,
page_size=page_size,
).tasks
def _assertSorted(self, vals: Sequence, ascending=True):
def _assertSorted(self, vals: Sequence, ascending=True, is_numeric=False):
"""
Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end
"""
if None in vals:
first_null_idx = vals.index(None)
none_tail = vals[first_null_idx:]
vals = vals[:first_null_idx]
self.assertTrue(all(val is None for val in none_tail))
self.assertTrue(all(val is not None for val in vals))
empty = [None, "", [], {}]
empty_value = None
idx = 0
for idx, val in enumerate(vals):
if val in empty:
empty_value = val
break
if idx < len(vals) - 1:
none_tail = vals[idx:]
vals = vals[:idx]
self.assertTrue(all(val == empty_value for val in none_tail))
self.assertTrue(all(val != empty_value for val in vals))
if is_numeric:
vals = list(map(int, vals))
if ascending:
cmp = operator.le
@@ -76,10 +89,18 @@ class TestEntityOrdering(TestService):
cmp = operator.ge
self.assertTrue(all(cmp(i, j) for i, j in zip(vals, vals[1:])))
def _get_value_for_path(self, data: Mapping, field_path: Sequence[str]):
val = None
for name in field_path:
val = data.get(name)
data = val if isinstance(val, dict) else {}
return val
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
order_by=[order_by] if isinstance(order_by, str) else order_by,
comment=self.test_comment,
**kwargs,
).tasks
@@ -87,12 +108,21 @@ class TestEntityOrdering(TestService):
if order_by:
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [t.get(field_name) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
self._assertSorted(
field_vals,
ascending=not order_by.startswith("-"),
is_numeric=field_name.startswith("execution.parameters.")
)
def _create_tasks(self):
tasks = [self._temp_task() for _ in range(10)]
for _, task in zip(range(5), tasks):
tasks = [
self._temp_task(
**(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}}))
)
for i in range(20)
]
for idx, task in zip(range(5), tasks):
self.api.tasks.started(task=task)
sleep(0.1)
return tasks

View File

@@ -0,0 +1,36 @@
from tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.8"):
super().setUp(version=version)
def test_tags(self):
tag1 = "Orgtest tag1"
tag2 = "Orgtest tag2"
system_tag = "Orgtest system tag"
model = self.create_temp(
"models", name="test_org", uri="file:///a", tags=[tag1]
)
task = self.create_temp(
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
)
data = self.api.organization.get_tags()
self.assertTrue(tag1 in data.tags)
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
data = self.api.organization.get_tags(include_system=True)
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
self.assertTrue(system_tag in data.system_tags)
data = self.api.organization.get_tags(
filter={"system_tags": ["__$not", system_tag]}
)
self.assertTrue(tag1 in data.tags)
self.assertFalse(tag2 in data.tags)
self.api.models.delete(model=model)
data = self.api.organization.get_tags()
self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags)

View File

@@ -208,25 +208,21 @@ class TestTags(TestService):
self.api.tasks.stopped(task=task_id)
def _temp_queue(self, **kwargs):
self._update_missing(kwargs, name="Test tags")
self.update_missing(kwargs, name="Test tags")
return self.create_temp("queues", **kwargs)
def _temp_project(self, **kwargs):
self._update_missing(kwargs, name="Test tags", description="test")
self.update_missing(kwargs, name="Test tags", description="test")
return self.create_temp("projects", **kwargs)
def _temp_model(self, **kwargs):
self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
self.update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def _temp_task(self, **kwargs):
self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
self.update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
return self.create_temp("tasks", **kwargs)
@staticmethod
def _update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def _send(self, service, action, **kwargs):
api = kwargs.pop("api", self.api)
return AttrDict(

Some files were not shown because too many files have changed in this diff Show More