Compare commits

68 Commits

Author SHA1 Message Date
allegroai
5bdbcfcd8d Update README and docker-compose files for v0.16.0 2020-08-10 23:48:38 +03:00
allegroai
a2e2052b30 Version bump 2020-08-10 08:56:50 +03:00
allegroai
0146ded4f4 Fix empty projection handling 2020-08-10 08:56:43 +03:00
allegroai
dccf9dd8f8 Fix incorrect formatted timestamp in events.download_task_log 2020-08-10 08:55:01 +03:00
allegroai
7816b402bb Enhance ES7 initialization and migration support
Support older task hyper-parameter migration on pre-population
2020-08-10 08:53:41 +03:00
allegroai
cd4ce30f7c Add support for field exclusion in get_all endpoints
Add support for ephemeral worker tags (valid while worker has not timed out)
2020-08-10 08:48:48 +03:00
allegroai
8c7e230898 Add support for Task hyper-parameter sections and meta-data
Add new Task configuration section
2020-08-10 08:45:25 +03:00
allegroai
42ba696518 Support order parameter in events.get_task_log 2020-08-10 08:37:41 +03:00
allegroai
3f84e60a1f Add debug.ping endpoint
Optimize exhausted scrolls by using a fixed empty scroll
2020-08-10 08:35:34 +03:00
allegroai
baba8b5b73 Move to ElasticSearch 7
Add initial support for project ordering
Add support for sortable task duration (used by the UI in the experiment's table)
Add support for project name in worker's current task info
Add support for results and artifacts in pre-populates examples
Add demo server features
2020-08-10 08:30:40 +03:00
Allegro AI
77397c4f21 Update docker-compose.yml 2020-07-09 13:21:44 +03:00
allegroai
8678091d8f Fix documentation, remove sudo from docker-compose up (issue #48) 2020-07-06 22:07:59 +03:00
allegroai
aa22170ab4 Fix support for example projects and experiments in demo server 2020-07-06 22:06:42 +03:00
allegroai
901ec37290 Improve pre-populate on server startup (including sync lock) 2020-07-06 22:05:36 +03:00
allegroai
21f2ea8b17 Add events.get_task_log for improved log retrieval support 2020-07-06 21:54:25 +03:00
allegroai
8219e3d4e2 Fix trains-agent-services default ubuntu docker to support unicode in tty 2020-07-06 21:52:32 +03:00
allegroai
3ed71a61d5 Add models.get_frameworks endpoint 2020-07-06 21:50:43 +03:00
allegroai
18a88a8e8f Update AWS AMIs 2020-06-24 23:15:47 +03:00
allegroai
318a72987c Update GCP images for v0.15.1 2020-06-22 13:00:30 +03:00
allegroai
5ce202cc99 Update AWS AMIs for v0.15.1 2020-06-22 00:58:11 +03:00
allegroai
d09528bc26 Version bump to v0.15.1 2020-06-21 23:58:07 +03:00
allegroai
42d2a41dbe Update docker compose files 2020-06-21 23:57:58 +03:00
allegroai
82be1840b0 Add fileserver default cache timeout for downloaded files 2020-06-21 23:55:52 +03:00
allegroai
27352c5cb6 Fix last metrics values for the multiple iterations in the same events batch 2020-06-21 23:54:53 +03:00
allegroai
1ea6408d41 Support tags-per-project in tags related services 2020-06-21 23:54:05 +03:00
allegroai
5e095af3aa Fix server unable to create fixed users due to incorrect access to user_data["key"] 2020-06-21 23:52:01 +03:00
allegroai
ab3dceed92 Fix docker-compose mongodb setup on Windows 10 2020-06-16 23:59:59 +03:00
Allegro AI
3bf5126d84 Update README.md 2020-06-03 03:51:11 +03:00
allegroai
ab2ab7b23a Update GCP Images for v0.15.0 2020-06-02 16:50:52 +03:00
allegroai
c9184d125b Update AWS AMIs for v0.15.0 2020-06-02 16:17:03 +03:00
allegroai
0c0fdb72b9 Update docker-compose.yml 2020-06-02 13:20:04 +03:00
Allegro AI
86378053d4 Update docker-compose.yml 2020-06-02 01:29:55 +03:00
Allegro AI
b1cbba0cf1 Update README.md 2020-06-02 00:46:01 +03:00
Allegro AI
f31526042d Update README.md 2020-06-02 00:36:35 +03:00
Allegro AI
3f8d5bc346 Update README.md 2020-06-02 00:21:32 +03:00
allegroai
11d76e7d8c Update AWS AMIs for v0.15.0 2020-06-01 23:07:38 +03:00
allegroai
e76c0fbc63 Version bump to 0.15.0 2020-06-01 22:20:58 +03:00
allegroai
fdc9956da3 Update trains-agent-services docker image 2020-06-01 21:53:33 +03:00
allegroai
f4addaa653 Add new services mode agent container to the docker-compose 2020-06-01 21:02:49 +03:00
allegroai
667964cc82 Add clear_all flag to tasks.reset 2020-06-01 13:07:35 +03:00
allegroai
e1309e30b7 Fix UPLOAD_FOLDER handling when provided as env var or when fileserver is run by gunicorn 2020-06-01 13:05:45 +03:00
allegroai
9403942ef7 Add support for additional task types as well as tasks.get_types to obtain actual types used globally or per project 2020-06-01 13:05:12 +03:00
allegroai
84a75d9e70 Add server uid to server.info response in API v2.8 2020-06-01 13:01:31 +03:00
allegroai
c85ab66ae6 Add organization.get_tags to obtain the set of all used task, model, queue and project tags 2020-06-01 13:00:35 +03:00
allegroai
bf7f0f646b Sort hyper parameters numeric values as numbers and not strings 2020-06-01 12:27:56 +03:00
allegroai
dcdf2a3d58 Fix task can't be cloned if input model was deleted 2020-06-01 12:23:29 +03:00
allegroai
f8d8fc40a6 Support filtering users by activity in projects 2020-06-01 11:55:40 +03:00
allegroai
45d434a123 When clearing a task do not delete draft models used by other tasks 2020-06-01 11:51:43 +03:00
allegroai
1834abe5bc Better handling of execution parameter paths 2020-06-01 11:49:35 +03:00
allegroai
d6321588f3 Fix role checked for endpoints not requiring authorization 2020-06-01 11:43:55 +03:00
allegroai
c17b10ff1d Revoke built-in webserver system-role credentials (used by the WebApp) in case we're running in fixed-mode 2020-06-01 11:41:43 +03:00
allegroai
b125a56f86 Make sure configuration path loaded from an environment variable name is lower-case 2020-06-01 11:40:34 +03:00
allegroai
c43ce3a17b Update 0.15 mongo migration to drop indices (so new ones will be automatically created) 2020-06-01 11:36:22 +03:00
allegroai
b0b09616a8 Fix single bad event causes events.add_batch to skip remaining events 2020-06-01 11:33:39 +03:00
allegroai
ede5586ccc Extract non-responsive tasks watchdog from main tasks logic 2020-06-01 11:31:36 +03:00
allegroai
a1dcdffa53 Update pymongo and mongoengine versions 2020-06-01 11:29:50 +03:00
allegroai
35a11db58e Support task log retrieval with no scroll 2020-06-01 11:27:36 +03:00
allegroai
d9bdebefc7 Update AWS AMIs 2020-05-14 17:54:30 +03:00
allegroai
f29884f05a Version bump to v0.14.2 2020-05-14 17:53:56 +03:00
allegroai
0f72d662f8 Update GCP documentation 2020-05-04 17:31:11 +03:00
allegroai
6202219034 Update README 2020-05-03 11:08:21 +03:00
allegroai
bb3218f65d Update GCP installation instructions 2020-04-06 12:59:29 +03:00
allegroai
cbcaa7c789 Add MongoDB performance optimization 2020-04-01 19:20:53 +03:00
allegroai
427322a424 Update schema 2020-04-01 19:16:34 +03:00
allegroai
0e7d7d36a9 Update docs for GCP Custom Images 2020-03-30 15:51:58 +03:00
allegroai
06032a6d66 Update documentation 2020-03-20 10:51:43 +02:00
allegroai
b48f4eb2eb Make sure time intervals are calculated in ms 2020-03-20 10:50:56 +02:00
Allegro AI
383b2666c4 Update AWS AMIs 2020-03-16 21:57:07 +02:00
133 changed files with 6788 additions and 2229 deletions

View File

@@ -1,12 +1,16 @@
# Trains Server
## Auto-Magical Experiment Manager & Version Control for AI
## Auto-Magical Experiment Manager & Version Control for AI - ε Devops Included!
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![PyPI status](https://img.shields.io/badge/status-beta-yellow.svg)](https://img.shields.io/badge/status-beta-yellow.svg)
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
## Introduction
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
@@ -60,14 +64,15 @@ For example, to see if port `8080` is in use:
Launch **trains-server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
- Pre-built Docker Image
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
- Kubernetes
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
## Connecting Trains to your trains-server
@@ -95,12 +100,32 @@ you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in yo
for example http://localhost:8080.
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).
## Trains-Agent Services <a name="services"></a>
As of version 0.15 of **trains-server**, dockerized deployment includes a **Trains-Agent Services** container running as
part of the docker container collection.
Trains-Agent Services is an extension of Trains-Agent that provides the ability to launch long-lasting jobs
that previously had to be executed on local / dedicated machines. It allows a single agent to
launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
increased data transparency)
Trains-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
Every task launched by Trains-Agent Services will be registered as a new node in the system,
providing tracking and transparency capabilities.
You can also run the Trains-Agent Services manually, see details in [trains-agent services mode](https://github.com/allegroai/trains-agent#trains-agent-services-mode-)
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
## Advanced Functionality
**trains-server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
* [Web login authentication](https://allegro.ai/docs/faq/faq/#web-auth)
* [Non-responsive experiments watchdog](https://allegro.ai/docs/faq/faq/#watchdog)
## Restarting trains-server
@@ -149,18 +174,29 @@ To upgrade your existing **trains-server** deployment:
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
```
1. Configure the Trains-Agent Services (not supported on Windows installation).
If `TRAINS_HOST_IP` is not provided, Trains-Agent Services will use the external
public address of the **trains-server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
the Trains-Agent Services will not be able to access any private repositories for running service tasks.
```bash
export TRAINS_HOST_IP=server_host_ip_here
export TRAINS_AGENT_GIT_USER=git_username_here
export TRAINS_AGENT_GIT_PASS=git_password_here
```
1. Spin up the docker containers, it will automatically pull the latest **trains-server** build
```bash
docker-compose -f docker-compose.yml pull
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/docs/faq/faq/#common-docker-upgrade-errors).**
## Community & Support
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
If you have any questions, look to the Trains [FAQ](https://allegro.ai/docs/faq/faq/), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).

View File

@@ -15,6 +15,8 @@ services:
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/data/fileserver:/mnt/fileserver
- /opt/trains/config:/opt/trains/config
depends_on:
- redis
- mongo
@@ -38,15 +40,11 @@ services:
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
@@ -56,10 +54,10 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
ports:
- "9200:9200"
mongo:

View File

@@ -22,6 +22,9 @@ services:
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:
@@ -37,15 +40,11 @@ services:
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
@@ -55,10 +54,10 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- c:/opt/trains/data/elastic:/usr/share/elasticsearch/data
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
ports:
- "9200:9200"
@@ -73,6 +72,8 @@ services:
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/data/fileserver:/mnt/fileserver
- c:/opt/trains/config:/opt/trains/config
ports:
- "8081:8081"
@@ -84,7 +85,8 @@ services:
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- mongodata:/data
- c:/opt/trains/data/mongo/db:/data/db
- c:/opt/trains/data/mongo/configdb:/data/configdb
ports:
- "27017:27017"
@@ -115,6 +117,3 @@ services:
networks:
backend:
driver: bridge
volumes:
mongodata:

View File

@@ -10,6 +10,7 @@ services:
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/config:/opt/trains/config
- /opt/trains/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
@@ -22,8 +23,10 @@ services:
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
TRAINS__apiserver__pre_populate__enabled: "true"
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
ports:
- "8008:8008"
networks:
@@ -39,15 +42,11 @@ services:
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
@@ -57,10 +56,10 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
ports:
- "9200:9200"
@@ -75,6 +74,7 @@ services:
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/data/fileserver:/mnt/fileserver
- /opt/trains/config:/opt/trains/config
ports:
- "8081:8081"
@@ -108,13 +108,43 @@ services:
container_name: trains-webserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- /opt/trains/logs:/var/log/trains
depends_on:
- apiserver
ports:
- "8080:80"
agent-services:
networks:
- backend
container_name: trains-agent-services
image: allegroai/trains-agent-services:latest
restart: unless-stopped
privileged: true
environment:
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
TRAINS_API_HOST: http://apiserver:8008
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
TRAINS_WORKER_ID: "trains-services"
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/trains/agent:/root/.trains
depends_on:
- apiserver
networks:
backend:
driver: bridge

View File

@@ -26,7 +26,7 @@ The minimum recommended amount of RAM is 8GB. For example, **t3.large** or **t3a
To upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
### Upgrading AMIs to v0.12
### Note on upgrading AMIs to v0.12
This upgrade includes the automatically updated AMI in Version 0.12. It also includes an additional REDIS docker to the **trains-server** setup.
@@ -50,45 +50,102 @@ To upgrade the AMI:
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
### Latest version AMI - v0.14.1 (auto update)<a name="autoupdate"></a>
### Latest version AMI - v0.15.1 (auto update)<a name="autoupdate"></a>
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
* **eu-north-1** : ami-033fd0d9163e0a36e
* **ap-south-1** : ami-0cdd7f9880336f9b5
* **eu-west-3** : ami-085f508eef9f650d5
* **eu-west-2** : ami-0936deb94a193502d
* **eu-west-1** : ami-0c20b3620f1c6fff7
* **ap-northeast-2** : ami-0707c9790f8a224c4
* **ap-northeast-1** : ami-04595f0745c090328
* **sa-east-1** : ami-03898299742d43ad4
* **ca-central-1** : ami-0502dfea5223d572a
* **ap-southeast-1** : ami-02aa1f9308404e464
* **ap-southeast-2** : ami-0b66189b90df79b38
* **eu-central-1** : ami-0eb919d8234c49cdc
* **us-east-2** : ami-02fb63fca1b9f0d4b
* **us-west-1** : ami-01fdda7351725a689
* **us-west-2** : ami-004a2f40cdc095870
* **us-east-1** : ami-0a8acd1172ffebc7e
* **eu-north-1** : ami-0f30c84b905d354b9
* **ap-south-1** : ami-050e7acec52c8c74e
* **eu-west-3** : ami-03911c5b5bc77ef75
* **eu-west-2** : ami-0a5ed8aa2573ccc70
* **eu-west-1** : ami-0a53c65e922ec0611
* **ap-northeast-2** : ami-08cd017a37b8e8aab
* **ap-northeast-1** : ami-056b3ca1ad5af9322
* **sa-east-1** : ami-01ddc9325bafb400c
* **ca-central-1** : ami-0fc3cbbd982b18b45
* **ap-southeast-1** : ami-04c7a358df7002ef5
* **ap-southeast-2** : ami-0eeaf54231b4ae22a
* **eu-central-1** : ami-00b8e44041f8175fd
* **us-east-2** : ami-0ac7deebb3f738f6d
* **us-west-1** : ami-06bc07deb8b8c44d6
* **us-west-2** : ami-01ba85ffe79a422f1
* **us-east-1** : ami-04cf5a66cb4928ac3
### v0.15.1 (static update)
* **eu-north-1** : ami-0cd314e267426d1b7
* **ap-south-1** : ami-086182cbe29151f96
* **eu-west-3** : ami-0062366012182815b
* **eu-west-2** : ami-022b8f2e32a9d18d0
* **eu-west-1** : ami-0d8cf60446e09aa3d
* **ap-northeast-2** : ami-0d4c168a815b56889
* **ap-northeast-1** : ami-0daf7887db1053ae4
* **sa-east-1** : ami-020a759a3ba4ff22b
* **ca-central-1** : ami-0c10b5e04b707f3e3
* **ap-southeast-1** : ami-0f61bb3529a165fcd
* **ap-southeast-2** : ami-032dcdc82749c66c5
* **eu-central-1** : ami-08f364f32d2eb3bae
* **us-east-2** : ami-0b7efc3591803eba4
* **us-west-1** : ami-08b2df27b0ada6faf
* **us-west-2** : ami-0693029c4bad28816
* **us-east-1** : ami-0200954fa9c2819ff
### v0.15.0 (static update)
* **eu-north-1** : ami-0bef15c03eab64c0c
* **ap-south-1** : ami-06ac6248e583e2cd2
* **eu-west-3** : ami-0541d86ef47a5714e
* **eu-west-2** : ami-01381ef4c4ed22482
* **eu-west-1** : ami-064626a0dd38b21f1
* **ap-northeast-2** : ami-0a2490a7a3a8aa675
* **ap-northeast-1** : ami-063f1de819a2524b8
* **sa-east-1** : ami-07980486741b94987
* **ca-central-1** : ami-0ced3b8b21ded839e
* **ap-southeast-1** : ami-0c493c5093fde8741
* **ap-southeast-2** : ami-0320a727eccb8dc6c
* **eu-central-1** : ami-0aa85cfc78674c526
* **us-east-2** : ami-01791485051e1880c
* **us-west-1** : ami-0d8eade4d5888ea73
* **us-west-2** : ami-02ceaef72cdf60f7e
* **us-east-1** : ami-0fc3f9d1d0eba1d62
### v0.14.2 (static update)
* **eu-north-1** : ami-006d491e9e8869248
* **ap-south-1** : ami-0e55ec221687f98e7
* **eu-west-3** : ami-06ad9cf3c05c83e91
* **eu-west-2** : ami-0d05839268e748cff
* **eu-west-1** : ami-0d14c297789ce0d7a
* **ap-northeast-2** : ami-0d7fd775f0e76cc6f
* **ap-northeast-1** : ami-0c0a6e1daeb3f7a9c
* **sa-east-1** : ami-01e0c5e30e94ec887
* **ca-central-1** : ami-07a31896832734897
* **ap-southeast-1** : ami-0886d5b2d4b7fccd5
* **ap-southeast-2** : ami-0397d5a2db3c356fe
* **eu-central-1** : ami-0629f26eea22f5c17
* **us-east-2** : ami-0499c3d7bb45a1a6e
* **us-west-1** : ami-02fa8a961a4daf9f0
* **us-west-2** : ami-05c711cfab4342468
* **us-east-1** : ami-0b97d99a08012c726
### v0.14.1 (static update)
* **eu-north-1** : ami-0ccdf4700a6c989ad
* **ap-south-1** : ami-0f6a8de6441d64a68
* **eu-west-3** : ami-0c51b9a8e7b3371cc
* **eu-west-2** : ami-099c094598f72bcb5
* **eu-west-1** : ami-0af20d5e4ab764212
* **ap-northeast-2** : ami-011455e8d852e02d6
* **ap-northeast-1** : ami-0211827ee11d6ed9c
* **sa-east-1** : ami-07509b07aa4554dc2
* **ca-central-1** : ami-07153c171d97e460e
* **ap-southeast-1** : ami-042d61c497063675b
* **ap-southeast-2** : ami-0dcf27f88bd2dd622
* **eu-central-1** : ami-0ae29f89d9bcb1a95
* **us-east-2** : ami-053144df2cea2bd97
* **us-west-1** : ami-0f703537206ee05f1
* **us-west-2** : ami-007c954572c86a583
* **us-east-1** : ami-07c59cbc7541f58e9
* **eu-north-1** : ami-036defe1885dced2e
* **ap-south-1** : ami-0b403aa1da6a5dc17
* **eu-west-3** : ami-0d30c2d330d1255c4
* **eu-west-2** : ami-06f0e8d075e50a029
* **eu-west-1** : ami-0da721d874f282b6d
* **ap-northeast-2** : ami-03bffe94675dd5f8c
* **ap-northeast-1** : ami-0f96520d646423673
* **sa-east-1** : ami-0c2f706a3b7d97282
* **ca-central-1** : ami-0da74525dcfd74e32
* **ap-southeast-1** : ami-066368a21cf6d232b
* **ap-southeast-2** : ami-0bfd09170067f7318
* **eu-central-1** : ami-06aa99b1c41492986
* **us-east-2** : ami-065c1880f59d03272
* **us-west-1** : ami-0b7f6b896f5058eba
* **us-west-2** : ami-0041e10ca68eef29a
* **us-east-1** : ami-0b7125e4305bbd7eb
### v0.14.0 (static update)
* **eu-north-1** : ami-02de71586ec496e38

76
docs/install_gcp.md Normal file
View File

@@ -0,0 +1,76 @@
# Deploying Trains Server on Google Cloud Platform
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
#### Default Trains Server Service ports
The service port numbers on our Trains Server GCP Custom Image are:
- Web application: `8080`
- API Server: `8008`
- File Server: `8081`
#### Default Trains Server Storage paths
The persistent storage configuration:
- MongoDB: `/opt/trains/data/mongo/`
- ElasticSearch: `/opt/trains/data/elastic/`
- File Server: `/mnt/fileserver/`
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
## Importing the Custom Image to your GCP account
In order to launch an instance using the Trains Server GCP Custom Image, you'll need to import the image to your custom images list.
**Note:** there's **no need** to upload the image file to Google Cloud Storage - we already provide links to image files stored in Google Storage
To import the image to your custom images list:
1. In the Cloud Console, go to the [Images](https://console.cloud.google.com/compute/images) page.
1. At the top of the page, click **Create image**.
1. In the **Name** field, specify a unique name for the image.
1. Optionally, specify an image family for your new image, or configure specific encryption settings for the image.
1. Click the **Source** menu and select **Cloud Storage file**.
1. Enter the Trains Server image bucket path (see [Trains Server GCP Custom Image](#released-versions)), for example:
`allegro-files/trains-server/trains-server.tar.gz`
1. Click the **Create** button to import the image. The process can take several minutes depending on the size of the boot disk image.
For more information see [Import the image to your custom images list](https://cloud.google.com/compute/docs/import/import-existing-image#import_image) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
## Launching an instance with a Custom Image
For instructions on launching an instance using a GCP Custom Image, see the [Manually importing virtual disks](https://cloud.google.com/compute/docs/import/import-existing-image#overview) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
For more information on Custom Images, see [Custom Images](https://cloud.google.com/compute/docs/images#custom_images) in the Compute Engine Documentation.
The minimum recommended requirements for Trains Server are:
- 2 vCPUs
- 7.5GB RAM
## Upgrading
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
## Network and Security
Please make sure your instance is properly secured.
If not specifically set, a GCP instance will use default firewall rules that allow public access to various ports.
If your instance is open for public access, we recommend you follow best practices for access management, including:
- Allow access only to the specific ports used by Trains Server (see [Default Trains Server Service ports](#default-trains-server-service-ports)). Remember to allow access to port `443` if `https` access is configured for your instance.
- Configure Trains Server to use fixed user names and passwords (see [Can I add web login authentication to trains-server?](./faq.md#web-auth))
## Released versions
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
### Latest version image
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz
### All released images
- v0.15.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-1.tar.gz
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz

View File

@@ -1,6 +1,6 @@
# Launching the **trains-server** Docker in Linux or macOS
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux users:
@@ -16,20 +16,20 @@ To launch **trains-server** on Linux or macOS:
1. Verify the Docker CE installation. Execute the command:
sudo docker run hello-world
docker run hello-world
The expected is output is:
Hello from Docker!
This message shows that your installation appears to be working correctly.
To generate this message, Docker took the following steps:
1. The Docker client contacted the Docker daemon.
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
@@ -42,12 +42,12 @@ To launch **trains-server** on Linux or macOS:
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
sudo service docker restart
macOS:
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
sysctl -w vm.max_map_count=262144
1. Remove any previous installation of **trains-server**.
@@ -64,28 +64,28 @@ To launch **trains-server** on Linux or macOS:
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/config
sudo mkdir -p /opt/trains/data/fileserver
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
1. Grant access to the Dockers.
Linux:
sudo chown -R 1000:1000 /opt/trains
macOS:
sudo chown -R $(whoami):staff /opt/trains
1. Download the **trains-server** docker-compose YAML file.
cd /opt/trains
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
1. Run `docker-compose` with the downloaded configuration file.
sudo docker-compose -f docker-compose.yml up
docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
@@ -94,4 +94,4 @@ To launch **trains-server** on Linux or macOS:
## Next Step
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).

View File

@@ -1,4 +1,5 @@
import logging
import os
from functools import reduce
from os import getenv
from os.path import expandvars
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
EXTRA_CONFIG_PATH_SEP = ":"
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
class BasicConfig:
NotSet = object()
@@ -46,7 +50,23 @@ class BasicConfig:
path = ".".join((self.prefix, Path(name).stem))
return logging.getLogger(path)
def _read_env_paths(self, key):
@staticmethod
def _read_extra_env_config_values():
""" Loads extra configuration from environment-injected values """
result = ConfigTree()
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
return result
@staticmethod
def _read_env_paths(key):
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
if value is None:
return
@@ -64,12 +84,17 @@ class BasicConfig:
def _load(self, verbose=True):
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
extra_config_values = self._read_extra_env_config_values()
configs = [
self._read_recursive(path, verbose=verbose)
for path in [self.folder] + extra_config_paths
]
self._config = reduce(
lambda config, path: ConfigTree.merge_configs(
config, self._read_recursive(path, verbose=verbose), copy_trees=True
lambda last, config: ConfigTree.merge_configs(
last, config, copy_trees=True
),
[self.folder] + extra_config_paths,
configs + [extra_config_values],
ConfigTree(),
)

View File

@@ -1,6 +1,9 @@
download {
# Add response headers requesting no caching for served files
disable_browser_caching: false
# Cache timeout to be set for downloaded files
cache_timeout_sec: 300
}
cors {

View File

@@ -10,12 +10,14 @@ from flask_cors import CORS
from config import config
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
if os.environ.get("TRAINS_UPLOAD_FOLDER"):
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER")
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get("fileserver.download.cache_timeout_sec", 5 * 60)
@app.route("/", methods=["POST"])
@@ -57,12 +59,13 @@ def main():
parser.add_argument(
"--upload-folder",
"-u",
default="/mnt/fileserver",
default=DEFAULT_UPLOAD_FOLDER,
help="Upload folder (default %(default)s)",
)
args = parser.parse_args()
app.config["UPLOAD_FOLDER"] = args.upload_folder
if app.config.get("UPLOAD_FOLDER") is None:
app.config["UPLOAD_FOLDER"] = args.upload_folder
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)

View File

@@ -1 +1 @@
__version__ = "2.7.0"
__version__ = "2.9.0"

View File

@@ -47,6 +47,7 @@ _error_codes = {
128: ('invalid_task_output', 'invalid task output'),
129: ('task_publish_in_progress', 'Task publish in progress'),
130: ('task_not_found', 'task not found'),
131: ('events_not_added', 'events not added'),
# Models
200: ('model_error', 'general task error'),

View File

@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator
from apierrors import errors
from utilities.json import loads, dumps
def make_default(field_cls, default_value):
@@ -206,10 +207,10 @@ class DomainField(fields.StringField):
raise errors.bad_request.InvalidDomainName()
class StringEnum(Enum):
def __str__(self):
return self.value
class JsonSerializableMixin:
def to_json(self: ModelBase):
return dumps(self.to_struct())
# noinspection PyMethodParameters
def _generate_next_value_(name, start, count, last_values):
return name
@classmethod
def from_json(cls: Type[ModelBase], s):
return cls(**loads(s))

View File

@@ -1,7 +1,8 @@
from jsonmodels import models, fields
from jsonmodels.validators import Length
from mongoengine.base import BaseDocument
from apimodels import DictField
from apimodels import DictField, ListField
class MongoengineFieldsDict(DictField):
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
"""
mongoengine_update_operators = (
'inc',
'dec',
'push',
'push_all',
'pop',
'pull',
'pull_all',
'add_to_set',
"inc",
"dec",
"push",
"push_all",
"pop",
"pull",
"pull_all",
"add_to_set",
)
@staticmethod
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
@classmethod
def _normalize_mongo_field_path(cls, path, value):
parts = path.split('__')
parts = path.split("__")
if len(parts) > 1:
if parts[0] == 'set':
if parts[0] == "set":
parts = parts[1:]
elif parts[0] == 'unset':
elif parts[0] == "unset":
parts = parts[1:]
value = None
elif parts[0] in cls.mongoengine_update_operators:
return None, None
return '.'.join(parts), cls._normalize_mongo_value(value)
return ".".join(parts), cls._normalize_mongo_value(value)
def parse_value(self, value):
value = super(MongoengineFieldsDict, self).parse_value(value)
@@ -62,3 +63,7 @@ class PagedRequest(models.Base):
class IdResponse(models.Base):
id = fields.StringField(required=True)
class MakePublicRequest(models.Base):
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])

View File

@@ -1,17 +1,19 @@
from typing import Sequence
from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from jsonmodels.validators import Length, Min, Max
from apimodels import ListField, IntField, ActualEnumField
from bll.event.event_metrics import EventType
from bll.event.scalar_key import ScalarKeyEnum
from utilities.stringenum import StringEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=10000)
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@@ -21,7 +23,7 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
items_types=str, validators=[Length(minimum_value=1, maximum_value=10)]
)
@@ -40,6 +42,19 @@ class DebugImagesRequest(Base):
scroll_id: str = StringField()
class LogOrderEnum(StringEnum):
asc = auto()
desc = auto()
class LogEventsRequest(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)

View File

@@ -6,6 +6,10 @@ from apimodels.base import UpdateResponse
from apimodels.tasks import PublishResponse as TaskPublishResponse
class GetFrameworksRequest(models.Base):
projects = fields.ListField(items_types=[str])
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)

View File

@@ -0,0 +1,11 @@
from jsonmodels import fields, models
class Filter(models.Base):
tags = fields.ListField([str])
system_tags = fields.ListField([str])
class TagsRequest(models.Base):
include_system = fields.BoolField(default=False)
filter = fields.EmbeddedField(Filter)

View File

@@ -1,5 +1,8 @@
from jsonmodels import models, fields
from apimodels import ListField
from apimodels.organization import TagsRequest
class ProjectReq(models.Base):
project = fields.StringField()
@@ -10,7 +13,5 @@ class GetHyperParamReq(ProjectReq):
page_size = fields.IntField(default=500)
class GetHyperParamResp(models.Base):
parameters = fields.ListField(str)
remaining = fields.IntField()
total = fields.IntField()
class ProjectTagsRequest(TagsRequest):
projects = ListField(str)

View File

@@ -1,7 +1,9 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum
from jsonmodels.validators import Enum, Length
from apimodels import DictField, ListField
from apimodels.base import UpdateResponse
@@ -92,6 +94,10 @@ class PingRequest(TaskRequest):
pass
class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
@@ -99,7 +105,10 @@ class CloneRequest(TaskRequest):
new_task_system_tags = ListField([str])
new_task_parent = StringField()
new_task_project = StringField()
new_hyperparams = DictField()
new_configuration = DictField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
class AddOrUpdateArtifactsRequest(TaskRequest):
@@ -109,3 +118,76 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsResponse(models.Base):
added = ListField([str])
updated = ListField([str])
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
class MultiTaskRequest(models.Base):
tasks = ListField([str], validators=Length(minimum_value=1))
class GetHyperParamsRequest(MultiTaskRequest):
pass
class HyperParamItem(models.Base):
section = StringField(required=True, validators=Length(minimum_value=1))
name = StringField(required=True, validators=Length(minimum_value=1))
value = StringField(required=True)
type = StringField()
description = StringField()
class ReplaceHyperparams(object):
none = "none"
section = "section"
all = "all"
class EditHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
replace_hyperparams = StringField(
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
class HyperParamKey(models.Base):
section = StringField(required=True, validators=Length(minimum_value=1))
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
class GetConfigurationsRequest(MultiTaskRequest):
names = ListField([str])
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
class Configuration(models.Base):
name = StringField(required=True, validators=Length(minimum_value=1))
value = StringField(required=True)
type = StringField()
description = StringField()
class EditConfigurationRequest(TaskRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))

View File

@@ -1,4 +1,3 @@
import json
from enum import Enum
import six
@@ -13,13 +12,14 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apimodels import make_default, ListField, EnumField
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base):
worker = StringField(required=True)
tags = ListField(str)
class RegisterRequest(WorkerRequest):
@@ -61,26 +61,21 @@ class IdNameEntry(Base):
name = StringField()
class WorkerEntry(Base):
class WorkerEntry(Base, JsonSerializableMixin):
key = StringField() # not required due to migration issues
id = StringField(required=True)
user = EmbeddedField(IdNameEntry)
company = EmbeddedField(IdNameEntry)
ip = StringField()
task = EmbeddedField(IdNameEntry)
project = EmbeddedField(IdNameEntry)
queue = StringField() # queue from which current task was taken
queues = ListField(str) # list of queues this worker listens to
register_time = DateTimeField(required=True)
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
def to_json(self):
return json.dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**json.loads(s))
tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):

View File

@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping
import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField
from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps
class VariantScrollState(Base):
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
def to_json(self):
return dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**loads(s))
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
@@ -65,7 +56,12 @@ class DebugImagesResult(object):
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
STATE_EXPIRATION_SECONDS = 3600
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
@property
def _max_workers(self):
@@ -76,7 +72,7 @@ class DebugImagesIterator:
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=self.STATE_EXPIRATION_SECONDS,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
@@ -92,27 +88,31 @@ class DebugImagesIterator:
if not self.es.indices.exists(es_index):
return DebugImagesResult()
unique_metrics = set(metrics)
state = self.cache_manager.get_state(state_id) if state_id else None
if not state:
state = DebugImageEventsScrollState(
id=database.utils.id(),
metrics=self._init_metric_states(es_index, list(unique_metrics)),
)
else:
state_metrics = set((m.task, m.name) for m in state.metrics)
if state_metrics != unique_metrics:
raise errors.bad_request.InvalidScrollId(
"while getting debug images events", scroll_id=state_id
)
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, es_index, state)
for metric_state in state.metrics:
self._reinit_outdated_metric_states(company_id, es_index, state_)
for metric_state in state_.metrics:
metric_state.reset()
res = DebugImagesResult(next_scroll_id=state.id)
try:
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(self._max_workers) as pool:
res.metric_events = list(
pool.map(
@@ -125,10 +125,8 @@ class DebugImagesIterator:
state.metrics,
)
)
finally:
self.cache_manager.set_state(state)
return res
return res
def _reinit_outdated_metric_states(
self, company_id, es_index, state: DebugImageEventsScrollState
@@ -210,7 +208,11 @@ class DebugImagesIterator:
"size": 0,
"query": {
"bool": {
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
"must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
}
},
"aggs": {
@@ -253,7 +255,7 @@ class DebugImagesIterator:
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return []
@@ -300,6 +302,7 @@ class DebugImagesIterator:
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
@@ -370,7 +373,7 @@ class DebugImagesIterator:
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_term": "desc" if navigate_earlier else "asc"},
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
@@ -389,7 +392,7 @@ class DebugImagesIterator:
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return metric.task, metric.name, []

View File

@@ -3,9 +3,8 @@ from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence
from typing import Sequence, Set, Tuple, Optional
import attr
import six
from elasticsearch import helpers
from mongoengine import Q
@@ -16,12 +15,14 @@ import es_factory
from apierrors import errors
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task, TaskStatus
from redis_manager import redman
from timing_context import TimingContext
from tools import safe_get
from utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
@@ -29,15 +30,9 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s(auto_attribs=True)
class TaskEventsResult(object):
total_events: int = 0
next_scroll_id: str = None
events: list = attr.ib(factory=list)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
@@ -47,12 +42,28 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:
return self._metrics
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
@staticmethod
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
"""Verify that task exists and can be updated"""
if not task_ids:
return set()
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
@@ -62,19 +73,34 @@ class EventBLL(object):
task_last_events = nested_dict(
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked_tasks,
)
for event in events:
# remove spaces from event type
if "type" not in event:
raise errors.BadRequest("Event must have a 'type' field", event=event)
event_type = event.get("type")
if event_type is None:
errors_per_type["Event must have a 'type' field"] += 1
continue
event_type = event["type"].replace(" ", "_")
event_type = event_type.replace(" ", "_")
if event_type not in EVENT_TYPES:
raise errors.BadRequest(
"Invalid event type {}".format(event_type),
event=event,
types=EVENT_TYPES,
)
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
task_id = event.get("task")
if task_id is None:
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_id not in valid_tasks:
errors_per_type["Invalid task id"] += 1
continue
event["type"] = event_type
@@ -110,7 +136,6 @@ class EventBLL(object):
es_action = {
"_op_type": "index", # overwrite if exists with same ID
"_index": index_name,
"_type": "event",
"_source": event,
}
@@ -120,89 +145,74 @@ class EventBLL(object):
else:
es_action["_id"] = dbutils.id()
task_id = event.get("task")
if task_id is not None:
es_action["_routing"] = task_id
task_ids.add(task_id)
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
task_ids.add(task_id)
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
else:
es_action["_routing"] = task_id
actions.append(es_action)
if task_ids:
# verify task_ids
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
extra_msg = None
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id")
if len(res) < len(task_ids):
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, ids=invalid_task_ids
added = 0
if actions:
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
errors_in_bulk = []
added = 0
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_in_bulk.append(info)
if not updated:
remaining_tasks.add(task_id)
continue
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
# who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
added = min(added, len(actions))
return added, errors_in_bulk
if not added:
raise errors.bad_request.EventsNotAdded(**errors_per_type)
errors_count = sum(errors_per_type.values())
return added, errors_count, errors_per_type
def _update_last_scalar_events_for_task(self, last_events, event):
"""
@@ -220,9 +230,25 @@ class EventBLL(object):
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
last_event = last_events[metric_hash][variant_hash]
event_iter = event.get("iter", 0)
event_timestamp = event.get("timestamp", 0)
value = event.get("value")
if value is not None and (
(event_iter, event_timestamp)
>= (
last_event.get("iter", event_iter),
last_event.get("timestamp", event_timestamp),
)
):
event_data = {
k: event[k]
for k in ("value", "metric", "variant", "iter", "timestamp")
if k in event
}
event_data["min_value"] = min(value, last_event.get("min_value", value))
event_data["max_value"] = max(value, last_event.get("max_value", value))
last_events[metric_hash][variant_hash] = event_data
def _update_last_metric_events_for_task(self, last_events, event):
"""
@@ -265,7 +291,13 @@ class EventBLL(object):
flatten_nested_items(
last_scalar_events,
nesting=2,
include_leaves=["value", "metric", "variant"],
include_leaves=[
"value",
"min_value",
"max_value",
"metric",
"variant",
],
)
)
@@ -290,6 +322,9 @@ class EventBLL(object):
batch_size=10000,
scroll_id=None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "task_log_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
@@ -310,14 +345,9 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search(
index=es_index, body=es_req, scroll="1h", routing=task_id
)
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
@@ -345,7 +375,7 @@ class EventBLL(object):
"terms": {
"field": "iter",
"size": num_last_iterations,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
}
}
},
@@ -361,7 +391,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return []
@@ -381,6 +411,9 @@ class EventBLL(object):
size: int = 500,
scroll_id: str = None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
@@ -390,13 +423,11 @@ class EventBLL(object):
if not self.es.indices.exists(es_index):
return TaskEventsResult()
query = {"bool": defaultdict(list)}
must = []
if last_iterations_per_plot is None:
must = query["bool"]["must"]
must.append({"terms": {"task": tasks}})
else:
should = query["bool"]["should"]
should = []
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
es_index, task_id, last_iterations_per_plot, event_type
@@ -419,32 +450,41 @@ class EventBLL(object):
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
routing = ",".join(tasks)
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_plots"):
es_res = self.es.search(
index=es_index,
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
index=es_index, body=es_req, ignore=404, scroll="1h",
)
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
# scroll id may be missing when queering a totally empty DB
next_scroll_id = es_res.get("_scroll_id")
total_events = es_res["hits"]["total"]
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
"""
Return events and next scroll id from the scrolled query
Release the scroll once it is exhausted
"""
total_events = safe_get(es_res, "hits/total/value", default=0)
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.es.clear_scroll(scroll_id=next_scroll_id)
next_scroll_id = self.empty_scroll
return events, total_events, next_scroll_id
def get_task_events(
self,
company_id,
@@ -457,6 +497,8 @@ class EventBLL(object):
size=500,
scroll_id=None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
@@ -470,20 +512,16 @@ class EventBLL(object):
if not self.es.indices.exists(es_index):
return TaskEventsResult()
query = {"bool": defaultdict(list)}
if metric or variant:
must = query["bool"]["must"]
if metric:
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
must = []
if metric:
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
if last_iter_count is None:
must = query["bool"]["must"]
must.append({"terms": {"task": task_ids}})
else:
should = query["bool"]["should"]
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
es_index, task_id, event_type, last_iter_count
@@ -502,27 +540,23 @@ class EventBLL(object):
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
routing = ",".join(task_ids)
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.search(
index=es_index,
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
index=es_index, body=es_req, ignore=404, scroll="1h",
)
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
@@ -558,7 +592,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -590,14 +624,14 @@ class EventBLL(object):
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
"order": {"_term": "asc"},
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
"order": {"_term": "asc"},
"order": {"_key": "asc"},
},
"aggs": {
"last_value": {
@@ -627,7 +661,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
metrics = []
max_timestamp = 0
@@ -674,7 +708,7 @@ class EventBLL(object):
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
vectors = []
iterations = []
@@ -695,7 +729,7 @@ class EventBLL(object):
"terms": {
"field": "iter",
"size": iters,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
}
}
},
@@ -705,7 +739,7 @@ class EventBLL(object):
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return []
@@ -727,8 +761,6 @@ class EventBLL(object):
es_index = EventMetrics.get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
es_res = self.es.delete_by_query(
index=es_index, body=es_req, routing=task_id, refresh=True
)
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
return es_res.get("deleted", 0)

View File

@@ -1,12 +1,11 @@
import itertools
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.thread import ThreadPoolExecutor
from enum import Enum
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Callable, Iterable
from typing import Sequence, Tuple
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -16,7 +15,7 @@ from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task
from timing_context import TimingContext
from utilities import safe_get
from tools import safe_get
log = config.logger(__file__)
@@ -30,14 +29,18 @@ class EventType(Enum):
class EventMetrics:
MAX_TASKS_COUNT = 50
MAX_METRICS_COUNT = 200
MAX_VARIANTS_COUNT = 500
MAX_METRICS_COUNT = 100
MAX_VARIANTS_COUNT = 100
MAX_AGGS_ELEMENTS_COUNT = 50
MAX_SAMPLE_BUCKETS = 6000
def __init__(self, es: Elasticsearch):
self.es = es
@property
def _max_concurrency(self):
return config.get("services.events.max_metrics_concurrency", 4)
@staticmethod
def get_index_name(company_id, event_type):
event_type = event_type.lower().replace(" ", "_")
@@ -51,15 +54,48 @@ class EventMetrics:
The amount of points in each histogram should not exceed
the requested samples
"""
es_index = self.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
return self._run_get_scalar_metrics_as_parallel(
company_id,
task_ids=[task_id],
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average,
return self._get_scalar_average_per_iter_core(
task_id, es_index, samples, ScalarKey.resolve(key)
)
def _get_scalar_average_per_iter_core(
self,
task_id: str,
es_index: str,
samples: int,
key: ScalarKey,
run_parallel: bool = True,
) -> dict:
intervals = self._get_task_metric_intervals(
es_index=es_index, task_id=task_id, samples=samples, field=key.field
)
if not intervals:
return {}
interval_groups = self._group_task_metric_intervals(intervals)
get_scalar_average = partial(
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
)
if run_parallel:
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(get_scalar_average, interval_groups)
)
else:
metrics = itertools.chain.from_iterable(
get_scalar_average(group) for group in interval_groups
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
return ret
def compare_scalar_metrics_average_per_iter(
self,
company_id,
@@ -72,159 +108,115 @@ class EventMetrics:
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
if len(task_ids) > self.MAX_TASKS_COUNT:
raise errors.BadRequest(
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
len(task_ids),
)
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name"),
override_projection=("id", "name", "company"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
ret = self._run_get_scalar_metrics_as_parallel(
company_id,
task_ids=task_ids,
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average_per_task,
)
companies = {t.company for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
for metric_data in ret.values():
for variant_data in metric_data.values():
for task_id, task_data in variant_data.items():
task_data["name"] = task_name_by_id[task_id]
return ret
TaskMetric = Tuple[str, str, str]
MetricInterval = Tuple[int, Sequence[TaskMetric]]
MetricData = Tuple[str, dict]
def _split_metrics_by_max_aggs_count(
self, task_metrics: Sequence[TaskMetric]
) -> Iterable[Sequence[TaskMetric]]:
"""
Return task metrics in groups where amount of task metrics in each group
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
variants while always preserving all their tasks in the same group
"""
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
yield task_metrics
return
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
groups = []
for group in tm_grouped.values():
groups.append(group)
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
yield list(itertools.chain(*groups))
groups = []
if groups:
yield list(itertools.chain(*groups))
return
def _run_get_scalar_metrics_as_parallel(
self,
company_id: str,
task_ids: Sequence[str],
samples: int,
key: ScalarKey,
get_func: Callable[
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
],
) -> dict:
"""
Group metrics per interval length and execute get_func for each group in parallel
:param company_id: id of the company
:params task_ids: ids of the tasks to collect data for
:param samples: maximum number of samples per metric
:param get_func: callable that given metric names for the same interval
performs histogram aggregation for the metrics and return the aggregated data
"""
es_index = self.get_index_name(company_id, "training_stats_scalar")
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
intervals = self._get_metric_intervals(
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
es_index=es_index,
samples=samples,
key=ScalarKey.resolve(key),
run_parallel=False,
)
if not intervals:
return {}
intervals = list(
itertools.chain.from_iterable(
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
for i, tms in intervals
)
)
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
intervals,
)
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
res = defaultdict(lambda: defaultdict(dict))
for task_id, task_data in task_metrics:
task_name = task_name_by_id[task_id]
for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name
res[metric_key][variant_key][task_id] = variant_data
return ret
return res
def _get_metric_intervals(
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
MetricInterval = Tuple[str, str, int, int]
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@classmethod
def _group_task_metric_intervals(
cls, intervals: Sequence[MetricInterval]
) -> Sequence[MetricIntervalGroup]:
"""
Group task metric intervals so that the following conditions are meat:
- All the metrics in the same group have the same interval (with 10% rounding)
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
"""
metric_interval_groups = []
interval_group = []
group_interval_upper_bound = 0
group_max_interval = 0
group_samples = 0
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
if (
interval > group_interval_upper_bound
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
):
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
interval_group = []
group_max_interval = interval
group_interval_upper_bound = interval + int(interval * 0.1)
group_samples = 0
interval_group.append((metric, variant))
group_samples += size
group_max_interval = max(group_max_interval, interval)
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
return metric_interval_groups
def _get_task_metric_intervals(
self, es_index, task_id: str, samples: int, field: str = "iter"
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
amount of points does not exceed sample.
Return metric variants grouped by interval value with 10% rounding
For samples==0 return empty list
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
default_intervals = [(1, [])]
if not samples:
return default_intervals
es_req = {
"size": 0,
"query": {"terms": {"task": task_ids}},
"query": {"term": {"task": task_id}},
"aggs": {
"tasks": {
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": {
"metrics": {
"variants": {
"terms": {
"field": "metric",
"size": self.MAX_METRICS_COUNT,
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
},
"aggs": {
"count": {"value_count": {"field": field}},
"min_index": {"min": {"field": field}},
"max_index": {"max": {"field": field}},
},
}
"count": {"value_count": {"field": field}},
"min_index": {"min": {"field": field}},
"max_index": {"max": {"field": field}},
},
}
},
@@ -233,88 +225,75 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = self.es.search(
index=es_index, body=es_req, routing=",".join(task_ids)
)
es_res = self.es.search(index=es_index, body=es_req)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return default_intervals
return []
intervals = [
(
task["key"],
metric["key"],
variant["key"],
self._calculate_metric_interval(variant, samples),
)
for task in aggs_result["tasks"]["buckets"]
for metric in task["metrics"]["buckets"]
return [
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
for metric in aggs_result["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
]
metric_intervals = []
upper_border = 0
interval_metrics = None
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
if not interval_metrics or interval > upper_border:
interval_metrics = []
metric_intervals.append((interval, interval_metrics))
upper_border = interval + int(interval * 0.1)
interval_metrics.append((task, metric, variant))
return metric_intervals
@staticmethod
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
def _build_metric_interval(
metric: str, variant: str, data: dict, samples: int
) -> Tuple[str, str, int, int]:
"""
Calculate index interval per metric_variant variant so that the
total amount of intervals does not exceeds the samples
Return the interval and resulting amount of intervals
"""
count = safe_get(metric_variant, "count/value")
if not count or count < samples:
return 1
count = safe_get(data, "count/value", default=0)
if count < samples:
return metric, variant, 1, count
min_index = safe_get(metric_variant, "min_index/value", default=0)
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
return max(1, int(max_index - min_index + 1) // samples)
min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(data, "max_index/value", default=min_index)
return (
metric,
variant,
max(1, int(max_index - min_index + 1) // samples),
samples,
)
MetricData = Tuple[str, dict]
def _get_scalar_average(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
metrics_interval: MetricIntervalGroup,
task_id: str,
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
Note: the function works with a single task only
"""
assert len(task_ids) == 1
interval, task_metrics = metrics_interval
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {
"field": "metric",
"size": self.MAX_METRICS_COUNT,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
},
"aggs": aggregation,
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
aggs_result = self._query_aggregation_for_task_metrics(
es_index, aggs=aggs, task_id=task_id, metrics=metrics
)
if not aggs_result:
@@ -335,61 +314,6 @@ class EventMetrics:
]
return metrics
def _get_scalar_average_per_task(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
"""
interval, task_metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": {
"variants": {
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
"aggs": {
"tasks": {
"terms": {
"field": "task",
"size": self.MAX_TASKS_COUNT,
},
"aggs": aggregation,
}
},
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
)
if not aggs_result:
return {}
metrics = [
(
metric["key"],
{
variant["key"]: {
task["key"]: key.get_iterations_data(task)
for task in variant["tasks"]["buckets"]
}
for variant in metric["variants"]["buckets"]
},
)
for metric in aggs_result["metrics"]["buckets"]
]
return metrics
@staticmethod
def _add_aggregation_average(aggregation):
average_agg = {"avg_val": {"avg": {"field": "value"}}}
@@ -398,69 +322,55 @@ class EventMetrics:
for key, value in aggregation.items()
}
def _query_aggregation_for_metrics_and_tasks(
def _query_aggregation_for_task_metrics(
self,
es_index: str,
aggs: dict,
task_ids: Sequence[str],
task_metrics: Sequence[TaskMetric],
task_id: str,
metrics: Sequence[Tuple[str, str]],
) -> dict:
"""
Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics
"""
if task_metrics:
condition = {
"should": [
self._build_metric_terms(task, metric, variant)
for task, metric, variant in task_metrics
]
}
else:
condition = {"must": [{"terms": {"task": task_ids}}]}
must = [{"term": {"task": task_id}}]
if metrics:
should = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
}
for metric, variant in metrics
]
must.append({"bool": {"should": should}})
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"bool": condition},
"query": {"bool": {"must": must}},
"aggs": aggs,
"version": True,
}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search(
index=es_index, body=es_req, routing=",".join(task_ids)
)
es_res = self.es.search(index=es_index, body=es_req)
return es_res.get("aggregations")
@staticmethod
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
"""
Build query term for a metric + variant
"""
return {
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
}
def get_tasks_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence[Tuple]:
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
"""
es_index = EventMetrics.get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index):
return [(tid, []) for tid in task_ids]
return {}
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_concurrency) as pool:
with ThreadPoolExecutor(self._max_concurrency) as pool:
res = pool.map(
partial(
self._get_task_metrics, es_index=es_index, event_type=event_type,
@@ -488,7 +398,7 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
return [
metric["key"]

View File

@@ -0,0 +1,113 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from bll.event.event_metrics import EventMetrics
from database.errors import translate_errors_context
from timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = "log"
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
es_index=es_index,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
es_index,
task_id: str,
batch_size: int,
navigate_earlier: bool,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = self.es.search(index=es_index, body=es_req)
hits = es_result["hits"]["hits"]
if not hits or len(hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
already_present_ids = set(ev["_id"] for ev in events)
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[
*events,
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
],
hits_total,
)

View File

@@ -4,7 +4,7 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from apimodels import StringEnum
from utilities.stringenum import StringEnum
from bll.util import extract_properties_to_lists
from config import config
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": interval,
"fixed_interval": f"{interval}ms",
"min_doc_count": 1,
}
}
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": interval,
"fixed_interval": f"{interval}ms",
"min_doc_count": 1,
"format": "strict_date_time",
}

View File

@@ -0,0 +1,18 @@
from typing import Optional, Sequence
from mongoengine import Q
from database.model.model import Model
from database.utils import get_company_or_none_constraint
class ModelBLL:
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")

View File

@@ -0,0 +1,193 @@
from collections import defaultdict
from enum import Enum
from itertools import chain
from typing import Sequence, Union, Type, Dict
from mongoengine import Q
from redis import Redis
from config import config
from database.model.base import GetMixin
from database.model.model import Model
from database.model.task.task import Task
from redis_manager import redman
from utilities import json
log = config.logger(__file__)
_settings_prefix = "services.organization"
class _TagsCache:
_tags_field = "tags"
_system_tags_field = "system_tags"
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
self.db_cls = db_cls
self.redis = redis
@property
def _tags_cache_expiration_seconds(self):
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
def _get_tags_from_db(
self,
company: str,
field: str,
project: str = None,
filter_: Dict[str, Sequence[str]] = None,
) -> set:
query = Q(company=company)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project=project)
return self.db_cls.objects(query).distinct(field)
def _get_tags_cache_key(
self,
company: str,
field: str,
project: str = None,
filter_: Dict[str, Sequence[str]] = None,
):
"""
Project None means 'from all company projects'
The key is built in the way that scanning company keys for 'all company projects'
will not return the keys related to the particular company projects and vice versa.
So that we can have a fine grain control on what redis keys to invalidate
"""
filter_str = None
if filter_:
filter_str = "_".join(
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
)
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
return "_".join(filter(None, key_parts))
def get_tags(
self,
company: str,
include_system: bool = False,
filter_: Dict[str, Sequence[str]] = None,
project: str = None,
) -> dict:
"""
Get tags and optionally system tags for the company
Return the dictionary of tags per tags field name
The function retrieves both cached values from Redis in one call
and re calculates any of them if missing in Redis
"""
fields = [self._tags_field]
if include_system:
fields.append(self._system_tags_field)
redis_keys = [
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
for f in fields
]
cached = self.redis.mget(redis_keys)
ret = {}
for field, tag_data, key in zip(fields, cached, redis_keys):
if tag_data is not None:
tags = json.loads(tag_data)
else:
tags = list(self._get_tags_from_db(company, field, project, filter_))
self.redis.setex(
key,
time=self._tags_cache_expiration_seconds,
value=json.dumps(tags),
)
ret[field] = set(tags)
return ret
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
"""
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
"""
fields = [
field
for field, update in (
(self._tags_field, tags),
(self._system_tags_field, system_tags),
)
if update is not None
]
if not fields:
return
self._delete_redis_keys(company, projects=[project], fields=fields)
def reset_tags(self, company: str, projects: Sequence[str]):
self._delete_redis_keys(
company,
projects=projects,
fields=(self._tags_field, self._system_tags_field),
)
def _delete_redis_keys(
self, company: str, projects: [Sequence[str]], fields: Sequence[str]
):
redis_keys = list(
chain.from_iterable(
self.redis.keys(
self._get_tags_cache_key(company, field=f, project=p) + "*"
)
for f in fields
for p in set(projects) | {None}
)
)
if redis_keys:
self.redis.delete(*redis_keys)
class Tags(Enum):
Task = "task"
Model = "model"
class OrgBLL:
def __init__(self, redis=None):
self.redis = redis or redman.connection("apiserver")
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def get_tags(
self,
company: str,
entity: Tags,
include_system: bool = False,
filter_: Dict[str, Sequence[str]] = None,
projects: Sequence[str] = None,
) -> dict:
tags_cache = self._get_tags_cache_for_entity(entity)
if not projects:
return tags_cache.get_tags(
company, include_system=include_system, filter_=filter_
)
ret = defaultdict(set)
for project in projects:
project_tags = tags_cache.get_tags(
company, include_system=include_system, filter_=filter_, project=project
)
for field, tags in project_tags.items():
ret[field] |= tags
return ret
def update_tags(
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company, project, tags, system_tags)
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.reset_tags(company, projects=projects)
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
return self._task_tags if entity == Tags.Task else self._model_tags

View File

@@ -0,0 +1 @@
from .project_bll import ProjectBLL

View File

@@ -0,0 +1,33 @@
from typing import Sequence, Optional
from mongoengine import Q
from config import config
from database.model.model import Model
from database.model.task.task import Task
from timing_context import TimingContext
log = config.logger(__file__)
class ProjectBLL:
@classmethod
def get_active_users(
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
) -> set:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if project_ids:
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res

View File

@@ -18,7 +18,6 @@ log = config.logger(__file__)
class QueueMetrics:
class EsKeys:
DOC_TYPE = "metrics"
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
@@ -66,7 +65,6 @@ class QueueMetrics:
entries = [e for e in queue.entries if e.added]
return dict(
_index=es_index,
_type=self.EsKeys.DOC_TYPE,
_source={
self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id,
@@ -93,7 +91,6 @@ class QueueMetrics:
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
doc_type=self.EsKeys.DOC_TYPE,
body=es_req,
)
@@ -109,7 +106,7 @@ class QueueMetrics:
"dates": {
"date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD,
"interval": f"{interval}s",
"fixed_interval": f"{interval}s",
"min_doc_count": 1,
},
"aggs": {

View File

@@ -1,15 +1,21 @@
from typing import Optional, TypeVar, Generic, Type
from contextlib import contextmanager
from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
import database
from timing_context import TimingContext
T = TypeVar("T")
def _do_nothing(_: T):
return
class RedisCacheManager(Generic[T]):
"""
Class for store/retreive of state objects from redis
Class for store/retrieve of state objects from redis
self.state_class - class of the state
self.redis - instance of redis
@@ -42,3 +48,32 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
@contextmanager
def get_or_create_state(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
):
"""
Try to retrieve state with the given id from the Redis cache if yes then validates it
If no then create a new one with randomly generated id
Yield the state and write it back to redis once the user code block exits
:param state_id: id of the state to retrieve
:param init_state: user callback to init the newly created state
If not passed then no init except for the id generation is done
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
try:
yield state
finally:
self.set_state(state)

View File

@@ -19,7 +19,7 @@ from config.info import get_deployment_type
from database.model import Company, User
from database.model.queue import Queue
from database.model.task.task import Task
from utilities import safe_get
from tools import safe_get
from utilities.json import dumps
from utilities.threads_manager import ThreadsManager
from version import __version__ as current_version
@@ -237,7 +237,6 @@ class StatisticsReporter:
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
return worker_bll.es_client.search(
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)
@@ -280,7 +279,7 @@ class StatisticsReporter:
]
return {
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
for group in Task.aggregate(*pipeline)
for group in Task.aggregate(pipeline)
}

View File

@@ -4,5 +4,4 @@ from .utils import (
update_project_time,
validate_status_change,
split_by,
ParameterKeyEscaper,
)

View File

@@ -0,0 +1,229 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict
from boltons import iterutils
from apierrors import errors
from apimodels.tasks import (
HyperParamKey,
HyperParamItem,
ReplaceHyperparams,
Configuration,
)
from bll.task import TaskBLL
from config import config
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
from utilities.parameter_key_escaper import ParameterKeyEscaper
log = config.logger(__file__)
task_bll = TaskBLL()
class HyperParams:
_properties_section = "properties"
@classmethod
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
for task in tasks
}
@classmethod
def _get_params_list(
cls, items: Dict[str, Dict[str, ParamsItem]]
) -> Sequence[dict]:
ret = list(chain.from_iterable(v.values() for v in items.values()))
return [
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
]
@classmethod
def _normalize_params(cls, params: Sequence) -> bool:
"""
Lower case properties section and return True if it is the only section
"""
for p in params:
if p.section.lower() == cls._properties_section:
p.section = cls._properties_section
return all(p.section == cls._properties_section for p in params)
@classmethod
def delete_params(
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return task.update(**delete_cmds, last_update=datetime.utcnow())
@classmethod
def edit_params(
cls,
company_id: str,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{section}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[f"set__hyperparams__{section}__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
sections = iterutils.bucketize(items, key=attrgetter("section"))
return {
ParameterKeyEscaper.escape(section): {
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
for param in params
}
for section, params in sections.items()
}
@classmethod
def get_configurations(
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
) -> Dict[str, dict]:
only = ["id"]
if names:
only.extend(
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
)
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {
"configuration": [
c.to_proper_dict()
for c in sorted(task.configuration.values(), key=attrgetter("name"))
]
}
for task in tasks
}
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
) -> Dict[str, list]:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
}
@classmethod
def edit_configuration(
cls,
company_id: str,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration=Sequence[str]
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
@staticmethod
def _get_task_for_update(
company: str, id: str, allow_all_statuses: bool = False
) -> Task:
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=id)
if allow_all_statuses:
return task
if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status
)
return task

View File

@@ -0,0 +1,89 @@
from datetime import timedelta, datetime
from time import sleep
from apierrors import errors
from bll.task import ChangeStatusRequest
from config import config
from database.model.task.task import TaskStatus, Task
from utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
class NonResponsiveTasksWatchdog:
threads = ThreadsManager()
class _Settings:
"""
Retrieves watchdog settings from the config file
The properties are not cached so that the updates in
the config file are reflected
"""
_prefix = "services.tasks.non_responsive_tasks_watchdog"
@property
def enabled(self):
return config.get(f"{self._prefix}.enabled", True)
@property
def watch_interval_sec(self):
return config.get(f"{self._prefix}.watch_interval_sec", 900)
@property
def threshold_sec(self):
return config.get(f"{self._prefix}.threshold_sec", 7200)
settings = _Settings()
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start(cls):
sleep(cls.settings.watch_interval_sec)
while not ThreadsManager.terminating:
watch_interval = cls.settings.watch_interval_sec
if cls.settings.enabled:
try:
stopped = cls.cleanup_tasks(
threshold_sec=cls.settings.threshold_sec
)
log.info(f"{stopped} non-responsive tasks stopped")
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
sleep(watch_interval)
@classmethod
def cleanup_tasks(cls, threshold_sec):
relevant_status = (TaskStatus.in_progress,)
threshold = timedelta(seconds=threshold_sec)
ref_time = datetime.utcnow() - threshold
log.info(
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
)
tasks = list(
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
"id", "name", "status", "project", "last_update"
)
)
log.info(f"{len(tasks)} non-responsive tasks found")
if not tasks:
return 0
err_count = 0
for task in tasks:
log.info(
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
)
try:
ChangeStatusRequest(
task=task,
new_status=TaskStatus.stopped,
status_reason="Forced stop (non-responsive)",
status_message="Forced stop (non-responsive)",
force=True,
).execute()
except errors.bad_request.FailedChangingTaskStatus:
err_count += 1
return len(tasks) - err_count

View File

@@ -0,0 +1,201 @@
import itertools
from typing import Sequence, Tuple
import dpath
from apierrors import errors
from database.model.task.task import Task
from tools import safe_get
from utilities.parameter_key_escaper import ParameterKeyEscaper
hyperparams_default_section = "Args"
hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
if default_section is None:
return None, full_name
section, _, name = full_name.partition("/")
if section != tf_define_section:
return default_section, full_name
if not name:
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
return section, name
def _get_full_param_name(param: dict) -> str:
section = param.get("section")
if section != tf_define_section:
return param["name"]
return "/".join((section, param["name"]))
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
"""
removed = 0
if not data:
return removed
if with_sections:
for section, section_data in list(data.items()):
removed += _remove_legacy_params(section_data)
if not section_data:
"""If section is empty after removing legacy params then delete it"""
del data[section]
else:
for key, param in list(data.items()):
if param.get("type") == hyperparams_legacy_type:
removed += 1
del data[key]
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
"""
if not data:
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
return [
param for param in data.values() if param.get("type") == hyperparams_legacy_type
]
def params_prepare_for_save(fields: dict, previous_task: Task = None):
"""
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
and previous_task
and previous_task[new_params_field]
):
previous_data = previous_task.to_proper_dict().get(new_params_field)
removed = _remove_legacy_params(
previous_data, with_sections=default_section is not None
)
if not legacy_params and not removed:
# if we only need to delete legacy fields from the db
# but they are not there then there is no point to proceed
continue
fields_update = {new_params_field: previous_data}
params_unprepare_from_saved(fields_update)
fields.update(fields_update)
for full_name, value in legacy_params.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_params_field, section, name)))
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
ParameterKeyEscaper.escape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
}
dpath.set(fields, param_field, escaped_params)
def params_unprepare_from_saved(fields, copy_to_legacy=False):
"""
Unescape all section and param names for hyper params and configuration
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
)
def _process_path(path: str):
"""
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
)
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]
for prefix in (
"hyperparams.",
"-hyperparams.",
"configuration.",
"-configuration.",
):
paths = [
_process_path(path) if path.startswith(prefix) else path for path in paths
]
return paths

View File

@@ -1,10 +1,11 @@
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import dpath
import pymongo.results
import six
from mongoengine import Q
@@ -14,6 +15,7 @@ import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from bll.organization import OrgBLL, Tags
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
@@ -27,25 +29,38 @@ from database.model.task.task import (
TaskSystemTags,
ArtifactModes,
Artifact,
external_task_types,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from utilities.dicts import deep_merge
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
from utilities.parameter_key_escaper import ParameterKeyEscaper
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change
log = config.logger(__file__)
org_bll = OrgBLL()
class TaskBLL(object):
threads = ThreadsManager("TaskBLL")
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
@@ -70,25 +85,24 @@ class TaskBLL(object):
@staticmethod
def get_by_id(
company_id,
task_id,
required_status=None,
required_dataset=None,
only_fields=None,
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
only_fields = [only_fields]
else:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
qs = Task.objects(id=task_id, company=company_id)
if only_fields:
qs = (
qs.only(only_fields)
if isinstance(only_fields, string_types)
else qs.only(*only_fields)
)
qs = qs.only(
"status", "input"
) # make sure all fields we rely on here are also returned
task = qs.first()
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -96,17 +110,12 @@ class TaskBLL(object):
if required_status and not task.status == required_status:
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
if required_dataset and required_dataset not in (
entry.dataset for entry in task.input.view.entries
):
raise errors.bad_request.InvalidId(
"not in input view", dataset=required_dataset
)
return task
@staticmethod
def assert_exists(company_id, task_ids, only=None, allow_public=False):
def assert_exists(
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
ids = set(task_ids)
@@ -117,14 +126,13 @@ class TaskBLL(object):
return_dicts=False,
)
if only:
res = q.only(*only)
count = len(res)
else:
count = q.count()
res = q.first()
if count != len(ids):
q = q.only(*only)
if q.count() != len(ids):
raise errors.bad_request.InvalidTaskId(ids=task_ids)
return res
if return_tasks:
return list(q)
@staticmethod
def create(call: APICall, fields: dict):
@@ -166,17 +174,32 @@ class TaskBLL(object):
project: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
task = cls.get_by_id(company_id=company_id, task_id=task_id)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
params_dict["execution"] = {}
for legacy_param in ("parameters", "configuration"):
legacy_value = execution_overrides.pop(legacy_param, None)
if legacy_value is not None:
params_dict["execution"] = legacy_value
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
params_prepare_for_save(params_dict, previous_task=task)
artifacts = execution_dict.get("artifacts")
if artifacts:
execution_dict["artifacts"] = [
@@ -203,27 +226,59 @@ class TaskBLL(object):
if task.output
else None,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
cls.validate(new_task)
new_task.save()
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
return new_task
@classmethod
def validate(cls, task: Task):
assert isinstance(task, Task)
if task.parent and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
def validate(
cls,
task: Task,
validate_model=True,
validate_parent=True,
validate_project=True,
):
if (
validate_parent
and task.parent
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if task.project and not Project.get_for_writing(
company=task.company, id=task.project
if (
validate_project
and task.project
and not Project.get_for_writing(company=task.company, id=task.project)
):
raise errors.bad_request.InvalidProjectId(id=task.project)
cls.validate_execution_model(task)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
@@ -263,7 +318,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = Task.aggregate(*pipeline)
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
@@ -312,10 +367,12 @@ class TaskBLL(object):
return "__".join((op, "last_metrics") + path)
for path, value in last_scalar_values:
extra_updates[op_path("set", *path)] = value
if path[-1] == "value":
if path[-1] == "min_value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
elif path[-1] == "max_value":
extra_updates[op_path("max", *path[:-1], "max_value")] = value
else:
extra_updates[op_path("set", *path)] = value
if last_events is not None:
@@ -327,7 +384,7 @@ class TaskBLL(object):
metric_stats = {
dbutils.hash_field_name(metric_key): MetricEventStats(
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
)
for metric_key, metric_data in last_events.items()
}
@@ -575,81 +632,35 @@ class TaskBLL(object):
return [a.key for a in added], [a.key for a in updated]
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start_non_responsive_tasks_watchdog(cls):
log = config.logger("non_responsive_tasks_watchdog")
relevant_status = (TaskStatus.in_progress,)
threshold = timedelta(
seconds=config.get(
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
)
)
watch_interval = config.get(
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec", 900
)
sleep(watch_interval)
while not ThreadsManager.terminating:
try:
ref_time = datetime.utcnow() - threshold
log.info(
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
)
tasks = list(
Task.objects(
status__in=relevant_status, last_update__lt=ref_time
).only("id", "name", "status", "project", "last_update")
)
if tasks:
log.info(f"Stopping {len(tasks)} non-responsive tasks")
for task in tasks:
log.info(
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
)
ChangeStatusRequest(
task=task,
new_status=TaskStatus.stopped,
status_reason="Forced stop (non-responsive)",
status_message="Forced stop (non-responsive)",
force=True,
).execute()
log.info(f"Done")
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
sleep(watch_interval)
@staticmethod
def get_aggregated_project_execution_parameters(
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[str]]:
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": company_id,
"execution.parameters": {"$exists": True, "$gt": {}},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
}
},
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
{"$unwind": "$parameters"},
{"$group": {"_id": "$parameters.k"}},
{"$sort": {"_id": 1}},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{
"$group": {
"_id": 1,
@@ -666,7 +677,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = next(Task.aggregate(*pipeline), None)
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
@@ -675,7 +686,12 @@ class TaskBLL(object):
if result:
total = int(result.get("total", -1))
results = [
ParameterKeyEscaper.unescape(r["_id"])
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))

View File

@@ -3,7 +3,6 @@ from typing import TypeVar, Callable, Tuple, Sequence
import attr
import six
from boltons.dictutils import OneToOne
from apierrors import errors
from database.errors import translate_errors_context
@@ -172,26 +171,3 @@ def split_by(
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
class ParameterKeyEscaper:
_mapping = OneToOne({".": "%2E", "$": "%24"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
return "%".join(map(cls._unescape, value.split("%%")))

View File

@@ -35,14 +35,21 @@ class SetFieldsResolver:
SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = set_fields
self.fields = {
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
self.orig_fields = {}
self.fields = {}
self.add_fields(**set_fields)
def add_fields(self, **set_fields: Any):
self.orig_fields.update(set_fields)
self.fields.update(
{
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
)
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None:

View File

@@ -21,6 +21,7 @@ from config import config
from database.errors import translate_errors_context
from database.model.auth import User
from database.model.company import Company
from database.model.project import Project
from database.model.queue import Queue
from database.model.task.task import Task
from redis_manager import redman
@@ -33,8 +34,8 @@ log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es if es is not None else es_factory.connect("workers")
self.redis = redis if redis is not None else redman.connection("workers")
self.es_client = es or es_factory.connect("workers")
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@property
@@ -49,6 +50,7 @@ class WorkerBLL:
ip: str = "",
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@@ -58,6 +60,7 @@ class WorkerBLL:
:param ip: the real ip of the worker
:param queues: queues reported as being monitored by the worker
:param timeout: registration expiration timeout in seconds
:param tags: a list of tags for this worker
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
:return: worker entry instance
"""
@@ -91,6 +94,7 @@ class WorkerBLL:
register_time=now,
register_timeout=timeout,
last_activity_time=now,
tags=tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
@@ -113,12 +117,15 @@ class WorkerBLL:
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
:param company_id: worker's company ID
:param user_id: user_id ID under which this worker is running
:param ip: worker IP
:param report: the report itself
:param tags: tags for this worker
:raise bad_request.InvalidTaskId: the reported task was not found
:return: worker entry instance
"""
@@ -129,6 +136,9 @@ class WorkerBLL:
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
if report.machine_stats:
self._log_stats_to_es(
company_id=company_id,
@@ -146,6 +156,7 @@ class WorkerBLL:
if not report.task:
entry.task = None
entry.project = None
else:
with translate_errors_context():
query = dict(id=report.task, company=company_id)
@@ -160,6 +171,12 @@ class WorkerBLL:
raise bad_request.InvalidTaskId(**query)
entry.task = IdNameEntry(id=task.id, name=task.name)
entry.project = None
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.last_report_time = now
except APIError:
raise
@@ -223,7 +240,7 @@ class WorkerBLL:
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(*projection)
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
task_ids = task_ids.union(
filter(
@@ -369,7 +386,6 @@ class WorkerBLL:
def make_doc(category, metric, variant, value) -> dict:
return dict(
_index=es_index,
_type="stat",
_source=dict(
timestamp=timestamp,
worker=worker,

View File

@@ -25,7 +25,6 @@ class WorkerStats:
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)
@@ -53,7 +52,7 @@ class WorkerStats:
res = self._search_company_stats(company_id, es_req)
if not res["hits"]["total"]:
if not res["hits"]["total"]["value"]:
raise bad_request.WorkerStatsNotFound(
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
)
@@ -87,7 +86,7 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
"interval": f"{request.interval}s",
"fixed_interval": f"{request.interval}s",
"min_doc_count": 1,
},
"aggs": {
@@ -216,7 +215,7 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
"interval": f"{interval}s",
"fixed_interval": f"{interval}s",
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}

View File

@@ -1,5 +1,6 @@
import logging
import os
import platform
from functools import reduce
from os import getenv
from os.path import expandvars
@@ -15,7 +16,7 @@ from pyparsing import (
DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
EXTRA_CONFIG_PATH_SEP = ":"
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';'
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
@@ -57,7 +58,7 @@ class BasicConfig:
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
@@ -77,7 +78,7 @@ class BasicConfig:
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
]
if invalid:
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
return [path for path in paths if path.is_dir()]
def _load(self, verbose=True):

View File

@@ -26,6 +26,17 @@
check_max_version: false
}
pre_populate {
enabled: false
zip_files: ["/path/to/export.zip"]
fail_on_error: false
# artifacts_path: "/mnt/fileserver"
}
# time in seconds to take an exclusive lock to init es and mongodb
# not including the pre_populate
db_init_timout: 120
mongo {
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
@@ -34,11 +45,16 @@
aggregate {
allow_disk_use: true
}
}
pre_populate {
enabled: false
zip_file: "/path/to/export.zip"
fail_on_error: false
elastic {
probing {
# settings for inital probing of elastic connection
max_retries: 4
timeout: 30
}
upgrade_monitoring {
v16_migration_verification: true
}
}

View File

@@ -1,21 +1,21 @@
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
hosts: [{host: "127.0.0.1", port: 9211}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
max_retries: 3
retry_on_timeout: true
}
index_version: "1"
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
hosts: [{host:"127.0.0.1", port:9211}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
max_retries: 3
retry_on_timeout: true
}
index_version: "1"

View File

@@ -13,17 +13,21 @@
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
revoke_in_fixed_mode: true
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}
}
}

View File

@@ -0,0 +1,16 @@
fixed_users {
guest {
enabled: false
default_company: "025315a9321f49f8be07f5ac48fbcf92"
name: "Guest"
username: "guest"
password: "guest"
# Allow access only to the following endpoints when using user/pass credentials
allow_endpoints: [
"auth.login"
]
}
}

View File

@@ -6,4 +6,8 @@ ignore_iteration {
# max number of concurrent queries to ES when calculating events metrics
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
max_metrics_concurrency: 4
events_retrieval {
state_expiration_sec: 3600
}

View File

@@ -0,0 +1,3 @@
tags_cache {
expiration_seconds: 3600
}

View File

@@ -0,0 +1,8 @@
# Order of featured projects, by name or ID
featured_order: [
# {id: "<project-id>"}
# OR
# {name: "<project-name>"}
# OR
# {name_regex: "<python-regex>"}
]

View File

@@ -1,4 +1,6 @@
non_responsive_tasks_watchdog {
enabled: true
# In-progress tasks older than this value in seconds will be stopped by the watchdog
threshold_sec: 7200

View File

@@ -41,3 +41,7 @@ def get_deployment_type() -> str:
def get_default_company():
return config.get("apiserver.default_company")
missed_es_upgrade = False
es_connection_error = False

View File

@@ -79,6 +79,10 @@ def get_entries():
return _entries
def get_hosts():
return [entry.host for entry in get_entries()]
def get_aliases():
return [entry.alias for entry in get_entries()]

View File

@@ -14,6 +14,9 @@ from mongoengine import (
DictField,
DynamicField,
)
from mongoengine.fields import key_not_string, key_starts_with_dollar
NoneType = type(None)
class LengthRangeListField(ListField):
@@ -125,17 +128,39 @@ def contains_empty_key(d):
return True
class SafeMapField(MapField):
class DictValidationMixin:
"""
DictField validation in MongoEngine requires default alias and permissions to access DB version:
https://github.com/MongoEngine/mongoengine/issues/2239
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
"""
def _safe_validate(self: DictField, value):
if not isinstance(value, dict):
self.error("Only dictionaries may be used in a DictField")
if key_not_string(value):
msg = "Invalid dictionary key - documents must have only string keys"
self.error(msg)
if key_starts_with_dollar(value):
self.error(
'Invalid dictionary key name - keys may not startswith "$" characters'
)
super(DictField, self).validate(value)
class SafeMapField(MapField, DictValidationMixin):
def validate(self, value):
super(SafeMapField, self).validate(value)
self._safe_validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a MapField")
class SafeDictField(DictField):
class SafeDictField(DictField, DictValidationMixin):
def validate(self, value):
super(SafeDictField, self).validate(value)
self._safe_validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField")
@@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
SortedListField that does not raise an error in case items are not comparable
(in which case they will be sorted by their string representation)
"""
def to_mongo(self, *args, **kwargs):
try:
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
@@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
if self._ordering is not None:
def key(v): return str(itemgetter(self._ordering)(v))
def key(v):
return str(itemgetter(self._ordering)(v))
else:
key = str
return sorted(value, key=key, reverse=self._order_reverse)

View File

@@ -32,6 +32,8 @@ class Role(object):
""" Company user """
annotator = "annotator"
""" Annotator with limited access"""
guest = "guest"
""" Guest user. Read Only."""
@classmethod
def get_system_roles(cls) -> set:
@@ -43,6 +45,7 @@ class Role(object):
class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
last_used = DateTimeField()

View File

@@ -1,14 +1,15 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from boltons.iterutils import first
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
from apierrors import errors
from apierrors.base import BaseError
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
@@ -34,7 +35,12 @@ class AuthDocument(Document):
class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
def to_proper_dict(
self: Union["ProperDictMixin", Document],
strip_private=True,
only=None,
extra_dict=None,
) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
@@ -71,6 +77,8 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_field_collation_overrides = {}
class QueryParameterOptions(object):
def __init__(
self,
@@ -91,11 +99,48 @@ class GetMixin(PropsMixin):
self.list_fields = list_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_default = "in"
_ops = {"not": "nin"}
_next = _default
def __init__(self, legacy=False):
self._legacy = legacy
def key(self, v):
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"]
elif v.startswith(self.op_prefix):
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
return None
next_ = self._next
self._next = self._default
return next_
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
get_all_query_options = QueryParameterOptions()
@classmethod
def get(
cls, company, id, *, _only=None, include_public=False, **kwargs
cls: Union["GetMixin", Document],
company,
id,
*,
_only=None,
include_public=False,
**kwargs,
) -> "GetMixin":
q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public)
@@ -162,17 +207,7 @@ class GetMixin(PropsMixin):
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
exclude = [t for t in data if t.startswith("-")]
include = list(set(data).difference(exclude))
mongoengine_field = field.replace(".", "__")
if include:
dict_query[f"{mongoengine_field}__in"] = include
if exclude:
dict_query[f"{mongoengine_field}__nin"] = [
t[1:] for t in exclude
]
query &= cls.get_list_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
@@ -216,6 +251,47 @@ class GetMixin(PropsMixin):
return query & RegexQ(**dict_query)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
is included.
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
or by a preceding "__$not" value (operator)
"""
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
)
if not allow_empty:
return q
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
if allow_public:
@@ -272,6 +348,20 @@ class GetMixin(PropsMixin):
return []
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod
def split_projection(
cls, projection: Sequence[str]
) -> Tuple[Collection[str], Collection[str]]:
"""Return include and exclude lists based on passed projection and class definition"""
if projection:
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
else:
include, exclude = [], []
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
@classmethod
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters.pop("only_fields", None)
@@ -409,7 +499,27 @@ class GetMixin(PropsMixin):
)
@classmethod
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
:param query: Optional query object (mongoengine.Q).
:param projection: A list of projection fields.
:return: A list of documents matching the query.
"""
q = get_company_or_none_constraint()
_query = (q & query) if query else q
return cls._get_many_no_company(query=_query, override_projection=projection)
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
query: Q,
parameters=None,
override_projection=None,
):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
@@ -429,7 +539,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if search_text:
@@ -437,13 +549,14 @@ class GetMixin(PropsMixin):
if order_by:
# add ordering
qs = qs.order_by(*order_by)
if only:
if include:
# add projection
qs = qs.only(*only)
else:
exclude = set(cls.get_exclude_fields()).difference(only)
if exclude:
qs = qs.exclude(*exclude)
qs = qs.only(*include)
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
@@ -460,6 +573,8 @@ class GetMixin(PropsMixin):
"""
Fetch all documents matching a provided query. For the first order by field
the None values are sorted in the end regardless of the sorting order.
If the first order field is a user defined parameter (either from execution.parameters,
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
@@ -477,7 +592,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
query_sets = [cls.objects(query)]
if order_by:
@@ -500,20 +617,29 @@ class GetMixin(PropsMixin):
query_sets = [cls.objects(non_empty), cls.objects(empty)]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if only:
if include:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
else:
exclude = set(cls.get_exclude_fields())
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
query_sets = [qs.only(*include) for qs in query_sets]
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
@@ -524,7 +650,8 @@ class GetMixin(PropsMixin):
start -= qs_size
continue
ret.extend(
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
)
if len(ret) >= page_size:
break
@@ -593,7 +720,13 @@ class UpdateMixin(object):
return update_dict
@classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
def safe_update(
cls: Union["UpdateMixin", Document],
company_id,
id,
partial_update_dict,
injected_update=None,
):
update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict:
return 0, {}
@@ -610,7 +743,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
@classmethod
def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
cls: Union["DbModelMixin", Document],
pipeline: Sequence[dict],
allow_disk_use=None,
**kwargs,
) -> CommandCursor:
"""
Aggregate objects of this document class according to the provided pipeline.
@@ -625,7 +761,32 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
)
return cls.objects.aggregate(*pipeline, **kwargs)
return cls.objects.aggregate(pipeline, **kwargs)
@classmethod
def set_public(
cls: Type[Document],
company_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, unset__company=1)
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
return {"updated": cls.objects(id__in=ids).update(**update)}
def validate_id(cls, company, **kwargs):
@@ -647,5 +808,5 @@ def validate_id(cls, company, **kwargs):
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
)

View File

@@ -1,8 +1,9 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from mongoengine import Document, StringField, DateTimeField, BooleanField
from database import Database, strict
from database.fields import StrippedStringField, SafeDictField
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
@@ -12,46 +13,63 @@ from database.model.user import User
class Model(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"parent",
"project",
"task",
("company", "framework"),
("company", "name"),
("company", "user"),
{
'name': '%s.model.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$parent',
'$task',
'$project',
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'parent': 5,
'task': 3,
'project': 3,
}
}
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
"default_language": "english",
"weights": {
"name": 10,
"id": 10,
"comment": 10,
"parent": 5,
"task": 3,
"project": 3,
},
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"user",
"project",
"task",
"parent",
),
)
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field='Model', required=False)
parent = StringField(reference_field="Model", required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default='', user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default="", user_set_allowed=True)
framework = StringField()
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)

View File

@@ -1,11 +1,14 @@
from mongoengine import MapField, IntField
from database.fields import NoneType, UnionField, SafeMapField
class ModelLabels(MapField):
class ModelLabels(SafeMapField):
def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
super(ModelLabels, self).__init__(
field=UnionField(types=(int, NoneType)), *args, **kwargs
)
def validate(self, value):
super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)):
non_empty_values = list(filter(None, value.values()))
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
self.error("Same label id appears more than once in model labels")

View File

@@ -1,7 +1,7 @@
from mongoengine import StringField, DateTimeField, ListField
from mongoengine import StringField, DateTimeField, IntField
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import AttributedDocument
from database.model.base import GetMixin
@@ -17,12 +17,13 @@ class Project(AttributedDocument):
"db_alias": Database.backend,
"strict": strict,
"indexes": [
("company", "name"),
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
"default_language": "english",
"weights": {"name": 10, "id": 10, "description": 10},
}
},
],
}
@@ -35,7 +36,11 @@ class Project(AttributedDocument):
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True))
system_tags = ListField(StringField(required=True))
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
default_output_destination = StrippedStringField()
last_update = DateTimeField()
featured = IntField(default=9999)
logo_url = StringField()
logo_blob = StringField(exclude_by_default=True)
company_origin = StringField(exclude_by_default=True)

View File

@@ -4,11 +4,10 @@ from mongoengine import (
StringField,
DateTimeField,
EmbeddedDocumentListField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.company import Company
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()

View File

@@ -7,6 +7,10 @@ from database import Database, strict
from database.model import DbModelMixin
class SettingKeys:
server__uuid = "server.uuid"
class Settings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
@@ -47,7 +51,7 @@ class Settings(DbModelMixin, Document):
""" Adds a new key/value settings. Fails if key already exists. """
key = key.strip(sep)
try:
res = Settings(key=key, value=value).save(force_insert=True)
res = cls(key=key, value=value).save(force_insert=True)
return bool(res)
except NotUniqueError:
return False

View File

@@ -18,7 +18,7 @@ from database.fields import (
SafeSortedListField,
)
from database.model import AttributedDocument
from database.model.base import ProperDictMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
@@ -49,13 +49,13 @@ class TaskSystemTags(object):
development = "development"
class Script(EmbeddedDocument):
class Script(EmbeddedDocument, ProperDictMixin):
binary = StringField(default="python")
repository = StringField(required=True)
repository = StringField(default="")
tag = StringField()
branch = StringField()
version_num = StringField()
entry_point = StringField(required=True)
entry_point = StringField(default="")
working_dir = StringField()
requirements = SafeDictField()
diff = StringField()
@@ -84,7 +84,23 @@ class Artifact(EmbeddedDocument):
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class ParamsItem(EmbeddedDocument, ProperDictMixin):
section = StringField(required=True)
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
@@ -100,9 +116,29 @@ class Execution(EmbeddedDocument, ProperDictMixin):
class TaskType(object):
training = "training"
testing = "testing"
inference = "inference"
data_processing = "data_processing"
application = "application"
monitor = "monitor"
controller = "controller"
optimizer = "optimizer"
service = "service"
qc = "qc"
custom = "custom"
external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -110,6 +146,13 @@ class Task(AttributedDocument):
"created",
"started",
"completed",
"parent",
"project",
("company", "name"),
("company", "user"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
@@ -134,6 +177,12 @@ class Task(AttributedDocument):
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
id = StringField(primary_key=True)
name = StrippedStringField(
@@ -152,14 +201,19 @@ class Task(AttributedDocument):
published = DateTimeField()
parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
output: Output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script: Script = EmbeddedDocumentField(Script, default=Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)

View File

@@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
from database import Database, strict
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.company import Company
class User(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
"db_alias": Database.backend,
"strict": strict,
}
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company)

View File

@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
)
dst[path_part] = [
copy_path(path_parts[depth + 1:], s, d)
copy_path(path_parts[depth + 1 :], s, d)
for s, d in zip(src_part, dst[path_part])
]
@@ -96,6 +96,7 @@ class _ProxyManager:
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
exclusion_prefix = "-"
@property
def doc_projection(self):
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
[]
) # Projection information for reference fields (used in join queries)
for field in projection:
field_ = field.lstrip(self.exclusion_prefix)
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
if not field_.startswith(ref_field):
# Doesn't start with a reference field
continue
if field == ref_field:
if field_ == ref_field:
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
subfield = field_[len(ref_field) :]
if not subfield.startswith(SEP):
# Starts with something that looks like a reference field, but isn't
continue
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
ref_projection_info.append(
(
ref_field,
ref_field_cls,
("" if field_[0] == field[0] else self.exclusion_prefix)
+ subfield[1:],
)
)
break
else:
# Not a reference field, just add to the top-level projection
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
orig_field = field
if field.endswith(".*"):
field = field[:-2]
if not field:
if not field.lstrip(self.exclusion_prefix):
raise errors.bad_request.InvalidFields(
field=orig_field, object=doc_cls.__name__
)
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(
doc_cls, doc_projection.difference(ref_projection).union({"id"})
doc_cls, doc_projection.difference(ref_projection)
)
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
f
for f in doc_projection
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
not in doc_cls.get_fields()
]
if invalid_fields:
raise errors.bad_request.InvalidFields(
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
doc_projection.add(field)
doc_projection = list(doc_projection)
# If there are include fields (not only exclude) then add an id field
if (
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
and "id" not in doc_projection
):
doc_projection.append("id")
self._doc_projection = doc_projection
self._ref_projection = ref_projection
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
]
if items:
def do_projection(item):
ref_field_name, data, ids = item

View File

@@ -1,8 +1,14 @@
import copy
import re
from typing import Union
from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
from mongoengine.queryset.visitor import (
QueryCompilerVisitor,
SimplificationVisitor,
QCombination,
QNode,
)
class RegexWrapper(object):
@@ -17,17 +23,16 @@ class RegexWrapper(object):
class RegexMixin(object):
def to_query(self, document):
def to_query(self: Union["RegexMixin", QNode], document):
query = self.accept(SimplificationVisitor())
query = query.accept(RegexQueryCompilerVisitor(document))
return query
def _combine(self, other, operation):
def _combine(self: Union["RegexMixin", QNode], other, operation):
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
if getattr(other, "empty", True):
return self
if self.empty:

View File

@@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
res[field] = None
continue
if desc:
if callable(desc):
if issubclass(desc, Document):
if not desc.objects(id=value).only("id"):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
elif callable(desc):
try:
desc(value)
except TypeError:
raise ParseCallError(f"expecting {desc.__name__}", field=field)
except Exception as ex:
raise ParseCallError(str(ex), field=field)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(
value, desc
):
raise ParseCallError(
"expecting %s" % desc.__name__, field=field
)
if issubclass(desc, Document) and not desc.objects(id=value).only(
"id"
):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
res[field] = value
return res

View File

@@ -4,53 +4,54 @@ Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
import requests
from pathlib import Path
from typing import Optional, Sequence
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from elasticsearch import Elasticsearch
HERE = Path(__file__).resolve().parent
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
session.mount('http://', adapter)
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
def apply_mappings_to_host(host: str):
def _send_mapping(f):
def _send_template(f):
with f.open() as json_data:
data = json.load(json_data)
es_server = host
url = f"{es_server}/_template/{f.stem}"
session.delete(url)
r = session.post(
url,
headers={"Content-Type": "application/json"},
data=json.dumps(data),
)
return {"mapping": f.stem, "result": r.text}
template_name = f.stem
res = es.indices.put_template(template_name, body=data)
return {"mapping": template_name, "result": res}
p = HERE / "mappings"
return [
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
]
if key:
files = (p / key).glob("*.json")
else:
files = p.glob("**/*.json")
es = Elasticsearch(hosts=hosts, **(es_args or {}))
return [_send_template(f) for f in files]
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("hosts", nargs="+")
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
parser.add_argument(
"--hosts",
nargs="+",
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
)
return parser.parse_args()
def main():
for host in parse_args().hosts:
print(">>>>> Applying mapping to " + host)
res = apply_mappings_to_host(host)
print(res)
args = parse_args()
print(">>>>> Applying mapping to " + str(args.hosts))
res = apply_mappings_to_cluster(args.hosts, args.key)
print(res)
if __name__ == "__main__":

View File

@@ -1,8 +1,10 @@
from furl import furl
from time import sleep
from elasticsearch import Elasticsearch, exceptions
import es_factory
from config import config
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
from elastic.apply_mappings import apply_mappings_to_cluster
log = config.logger(__file__)
@@ -15,13 +17,48 @@ class MissingElasticConfiguration(Exception):
pass
def init_es_data():
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
class ElasticConnectionError(Exception):
"""
Exception when could not connect to elastic during init
"""
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
pass
def check_elastic_empty() -> bool:
"""
Check for elasticsearch connection
Use probing settings and not the default es cluster ones
so that we can handle correctly the connection rejects due to ES not fully started yet
:return:
"""
cluster_conf = es_factory.get_cluster_config("events")
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
timeout = config.get("apiserver.elastic.probing.timeout", 30)
for retry in range(max_retries):
try:
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
return True
except exceptions.ConnectionError:
if retry >= max_retries - 1:
raise ElasticConnectionError()
log.warn(
f"Could not connect to es server. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
)
sleep(timeout)
def init_es_data():
for name in es_factory.get_all_cluster_names():
cluster_conf = es_factory.get_cluster_config(name)
hosts_config = cluster_conf.get("hosts")
if not hosts_config:
raise MissingElasticConfiguration(f"for cluster '{name}'")
log.info(f"Applying mappings to ES host: {hosts_config}")
args = cluster_conf.get("args", {})
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
log.info(res)

View File

@@ -1,27 +0,0 @@
{
"template": "events-*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_default_": {
"_source": {
"enabled": true
},
"_routing": {
"required": true
},
"properties": {
"@timestamp": { "type": "date" },
"task": { "type": "keyword" },
"type": { "type": "keyword" },
"worker": { "type": "keyword" },
"timestamp": { "type": "date" },
"iter": { "type": "long" },
"metric": { "type": "keyword" },
"variant": { "type": "keyword" },
"value": { "type": "float" }
}
}
}
}

View File

@@ -0,0 +1,40 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"@timestamp": {
"type": "date"
},
"task": {
"type": "keyword"
},
"type": {
"type": "keyword"
},
"worker": {
"type": "keyword"
},
"timestamp": {
"type": "date"
},
"iter": {
"type": "long"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
}
}
}
}

View File

@@ -0,0 +1,15 @@
{
"index_patterns": "events-log-*",
"order": 1,
"mappings": {
"properties": {
"msg": {
"type": "text",
"index": false
},
"level": {
"type": "keyword"
}
}
}
}

View File

@@ -0,0 +1,12 @@
{
"index_patterns": "events-plot-*",
"order": 1,
"mappings": {
"properties": {
"plot_str": {
"type": "text",
"index": false
}
}
}
}

View File

@@ -0,0 +1,14 @@
{
"index_patterns": "events-training_debug_image-*",
"order": 1,
"mappings": {
"properties": {
"key": {
"type": "keyword"
},
"url": {
"type": "keyword"
}
}
}
}

View File

@@ -1,12 +0,0 @@
{
"template": "events-log-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"msg": { "type":"text", "index": false },
"level": { "type":"keyword" }
}
}
}
}

View File

@@ -1,11 +0,0 @@
{
"template": "events-plot-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"plot_str": { "type":"text", "index": false }
}
}
}
}

View File

@@ -1,12 +0,0 @@
{
"template": "events-training_debug_image-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"key": { "type": "keyword" },
"url": { "type": "keyword" }
}
}
}
}

View File

@@ -1,27 +0,0 @@
{
"template": "queue_metrics_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"metrics": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
}
}
}
}
}

View File

@@ -1,23 +0,0 @@
{
"template": "worker_stats_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"stat": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": { "type": "date" },
"worker": { "type": "keyword" },
"category": { "type": "keyword" },
"metric": { "type": "keyword" },
"variant": { "type": "keyword" },
"value": { "type": "float" },
"unit": { "type": "keyword" },
"task": { "type": "keyword" }
}
}
}
}

View File

@@ -0,0 +1,25 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
}
}
}
}

View File

@@ -0,0 +1,37 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"worker": {
"type": "keyword"
},
"category": {
"type": "keyword"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"unit": {
"type": "keyword"
},
"task": {
"type": "keyword"
}
}
}
}

View File

@@ -65,6 +65,10 @@ def connect(cluster_name):
return _instances[cluster_name]
def get_all_cluster_names():
return list(config.get("hosts.elastic"))
def get_cluster_config(cluster_name):
"""
Returns cluster config for the specified cluster path

View File

@@ -1,9 +1,11 @@
from pathlib import Path
from typing import Sequence, Union
from config import config
from config.info import get_default_company
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
@@ -11,59 +13,76 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
log = config.logger(__package__)
def _pre_populate(company_id: str, zip_file: str):
if not zip_file or not Path(zip_file).is_file():
msg = f"Invalid pre-populate zip file: {zip_file}"
if config.get("apiserver.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
log.info(f"Pre-populating using {zip_file}")
PrePopulate.import_from_zip(
zip_file,
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
)
def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
if isinstance(zip_files, str):
zip_files = [zip_files]
for p in map(Path, zip_files):
if p.is_file():
yield p
if p.is_dir():
yield from p.glob("*.zip")
log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
def pre_populate_data():
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
PrePopulate.update_featured_projects_order()
def init_mongo_data():
try:
empty_dbs = _apply_migrations(log)
_apply_migrations(log)
_ensure_uuid()
company_id = _ensure_company(log)
company_id = _ensure_company(get_default_company(), "trains", log)
_ensure_default_queue(company_id)
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
if not zip_file or not Path(zip_file).is_file():
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
fixed_mode = FixedUser.enabled()
user_id = _ensure_backend_user(
"__allegroai__", company_id, "Allegro.ai"
)
for user, credentials in config.get("secure.credentials", {}).items():
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"key": credentials.user_key,
"secret": credentials.user_secret,
}
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
PrePopulate.import_from_zip(zip_file, user_id=user_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
for user in users:
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_auth_user(user, company_id, log=log)
if FixedUser.enabled():
if fixed_mode:
log.info("Fixed users mode is enabled")
FixedUser.validate()
if FixedUser.guest_enabled():
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, company_id, log=log)
ensure_fixed_user(user, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
except Exception as ex:

View File

@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
def _apply_migrations(log: Logger) -> bool:
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
def get_last_server_version() -> Version:
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
return previous_versions[0] if previous_versions else Version("0.0.0")
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
Returns a boolean indicating whether the database was empty prior to migration.
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
empty_dbs = not any(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
try:
new_scripts = {
@@ -82,5 +89,3 @@ def _apply_migrations(log: Logger) -> bool:
).save()
log.info("Finished mongodb migrations")
return empty_dbs

View File

@@ -1,31 +1,384 @@
import hashlib
import importlib
import os
import re
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from io import BytesIO
from itertools import chain
from operator import attrgetter
from os.path import splitext
from typing import List, Optional, Any, Type, Set, Dict
from pathlib import Path
from typing import (
Optional,
Any,
Type,
Set,
Dict,
Sequence,
Tuple,
BinaryIO,
Union,
Mapping,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from tqdm import tqdm
from boltons.iterutils import chunked_iter
from furl import furl
from mongoengine import Q
from bll.event import EventBLL
from bll.task.param_utils import (
split_param_name,
hyperparams_default_section,
hyperparams_legacy_type,
)
from config import config
from config.info import get_default_company
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options
from tools import safe_get
from utilities import json
from .user import _ensure_backend_user
class PrePopulate:
@classmethod
def export_to_zip(
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
):
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
cls._export(zfile, experiments, projects)
event_bll = EventBLL()
events_file_suffix = "_events"
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
self.file = file
self.empty = True
def __enter__(self):
self._write("[")
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self._write("\n]")
def _write(self, data: str):
self.file.write(data.encode("utf-8"))
def write(self, line: str):
if not self.empty:
self._write(",")
self._write("\n" + line)
self.empty = False
@staticmethod
def _get_last_update_time(entity) -> datetime:
return getattr(entity, "last_update", None) or getattr(entity, "created")
@classmethod
def import_from_zip(cls, filename: str, user_id: str = None):
def _check_for_update(
cls, map_file: Path, entities: dict, metadata_hash: str
) -> Tuple[bool, Sequence[str]]:
if not map_file.is_file():
return True, []
files = []
try:
map_data = json.loads(map_file.read_text())
files = map_data.get("files", [])
for file in files:
if not Path(file).is_file():
return True, files
new_times = {
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
for item in chain.from_iterable(entities.values())
}
old_times = map_data.get("entities", {})
if set(new_times.keys()) != set(old_times.keys()):
return True, files
for id_, new_timestamp in new_times.items():
if new_timestamp != old_times[id_]:
return True, files
if metadata_hash != map_data.get("metadata_hash", ""):
return True, files
except Exception as ex:
print("Error reading map file. " + str(ex))
return True, files
return False, files
@classmethod
def _write_update_file(
cls,
map_file: Path,
entities: dict,
created_files: Sequence[str],
metadata_hash: str,
):
map_file.write_text(
json.dumps(
dict(
files=created_files,
entities={
entity.id: cls._get_last_update_time(entity)
for entity in chain.from_iterable(entities.values())
},
metadata_hash=metadata_hash,
)
)
)
@staticmethod
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
def is_fileserver_link(a: str) -> bool:
a = a.lower()
if a.startswith("https://files."):
return True
if a.startswith("http"):
parsed = urlparse(a)
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
"8081"
):
return True
return False
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
print(
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
)
return fileserver_links
@classmethod
def export_to_zip(
cls,
filename: str,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
artifacts_path: str = None,
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
) -> Sequence[str]:
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses")
file = Path(filename)
entities = cls._resolve_entities(
experiments=experiments, projects=projects, task_statuses=task_statuses
)
hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""
map_file = file.with_suffix(".map")
updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated:
print(f"There are no updates from the last export")
return old_files
for old in old_files:
old_path = Path(old)
if old_path.is_file():
old_path.unlink()
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
file.replace(file_with_hash)
created_files = [str(file_with_hash)]
artifacts = cls._filter_artifacts(artifacts)
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file))
cls._write_update_file(
map_file,
entities=entities,
created_files=created_files,
metadata_hash=metadata_hash,
)
return created_files
@classmethod
def import_from_zip(
cls,
filename: str,
artifacts_path: str,
company_id: Optional[str] = None,
user_id: str = "",
user_name: str = "",
):
metadata = None
with ZipFile(filename) as zfile:
cls._import(zfile, user_id)
try:
with zfile.open(cls.metadata_filename) as f:
metadata = json.loads(f.read())
meta_public = metadata.get("public")
if company_id is None and meta_public is not None:
company_id = "" if meta_public else get_default_company()
if not user_id:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
user_id, user_name = meta_user_id, meta_user_name
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
user_id = _ensure_backend_user(
user_id=user_id, user_name=user_name, company_id="",
)
cls._import(zfile, company_id, user_id, metadata)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
print(f"Unzipping artifacts into {artifacts_path}")
with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path)
@classmethod
def upgrade_zip(cls, filename) -> Sequence:
hash_ = hashlib.md5()
task_file = cls._get_base_filename(Task) + ".json"
temp_file = Path("temp.zip")
file = Path(filename)
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
for file_info in reader.filelist:
if file_info.orig_filename == task_file:
with reader.open(file_info) as f:
content = cls._upgrade_tasks(f)
else:
content = reader.read(file_info)
writer.writestr(file_info, content)
hash_.update(content)
base_file_name, _, old_hash = file.stem.rpartition("_")
new_hash = hash_.hexdigest()
if old_hash == new_hash:
print(f"The file {filename} was not updated")
temp_file.unlink()
return []
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
temp_file.replace(new_file)
upadated = [str(new_file)]
artifacts_file = file.with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
new_artifacts = new_file.with_suffix(cls.artifacts_ext)
artifacts_file.replace(new_artifacts)
upadated.append(str(new_artifacts))
print(f"File {str(file)} replaced with {str(new_file)}")
file.unlink()
return upadated
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
@classmethod
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
For each task the old execution.parameters and model.design are
converted to the new structure.
The fix is done on Task objects (not the dictionary) so that
the fields are serialized back in the same order as they were in the original file
"""
with BytesIO() as temp:
with cls.JsonLinesWriter(temp) as w:
for line in cls.json_lines(f):
task_data = Task.from_json(line).to_proper_dict()
cls._upgrade_task_data(task_data)
new_task = Task(**task_data)
w.write(new_task.to_json())
return temp.getvalue()
@classmethod
def update_featured_projects_order(cls):
featured_order = config.get("services.projects.featured_order", [])
def get_index(p: Project):
for index, entry in enumerate(featured_order):
if (
entry.get("id", None) == p.id
or entry.get("name", None) == p.name
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
):
return index
return 999
for project in Project.get_many_public(projection=["id", "name"]):
featured_index = get_index(project)
Project.objects(id=project.id).update(featured=featured_index)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[List[str]]
) -> List[Any]:
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
) -> Sequence[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
resolved = {i.id for i in items}
@@ -43,20 +396,24 @@ class PrePopulate:
@classmethod
def _resolve_entities(
cls, experiments: List[str] = None, projects: List[str] = None
cls,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
from database.model.project import Project
from database.model.task.task import Task
entities = defaultdict(set)
if projects:
print("Reading projects...")
entities[Project].update(cls._resolve_type(Project, projects))
print("--> Reading project experiments...")
objs = Task.objects(
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
query = Q(
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
system_tags__nin=[EntityVisibility.archived.value],
)
if task_statuses:
query &= Q(status__in=list(set(task_statuses)))
objs = Task.objects(query)
entities[Task].update(o for o in objs if o.id not in (experiments or []))
if experiments:
@@ -69,85 +426,289 @@ class PrePopulate:
project_ids = {p.id for p in entities[Project]}
entities[Project].update(o for o in objs if o.id not in project_ids)
model_ids = {
model_id
for task in entities[Task]
for model_id in (task.output.model, task.execution.model)
if model_id
}
if model_ids:
print("Reading models...")
entities[Model] = set(Model.objects(id__in=list(model_ids)))
return entities
@classmethod
def _cleanup_task(cls, task):
from database.model.task.task import TaskStatus
def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
if not tags:
return tags
return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
task.completed = None
task.started = None
if task.execution:
task.execution.model = None
task.execution.model_desc = None
task.execution.model_labels = None
if task.output:
task.output.model = None
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
task.status = TaskStatus.created
@classmethod
def _cleanup_task(cls, task: Task):
task.comment = "Auto generated by Allegro.ai"
task.created = datetime.utcnow()
task.last_iteration = 0
task.last_update = task.created
task.status_changed = task.created
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
task.output.destination = None
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
from database.model.task.task import Task
if entity_cls == Task:
cls._cleanup_task(entity)
elif entity_cls == Model:
cls._cleanup_model(entity)
elif entity == Project:
cls._cleanup_project(entity)
@classmethod
def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
try:
for item in items:
item.update(upsert=False, tags=sorted(item.tags + [tag]))
except AttributeError:
pass
@classmethod
def _export_task_events(
cls, task: Task, base_filename: str, writer: ZipFile, hash_
) -> Sequence[str]:
artifacts = []
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
print(f"Writing task events into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
scroll_id = None
while True:
res = cls.event_bll.get_task_events(
task.company, task.id, scroll_id=scroll_id
)
if not res.events:
break
scroll_id = res.next_scroll_id
for event in res.events:
if event.get("type") == "training_debug_image":
url = cls._get_fixed_url(event.get("url"))
if url:
event["url"] = url
artifacts.append(url)
w.write(json.dumps(event))
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
return artifacts
@staticmethod
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
if not (url and url.lower().startswith("s3://")):
return url
try:
fixed = furl(url)
fixed.scheme = "https"
fixed.host += ".s3.amazonaws.com"
return fixed.url
except Exception as ex:
print(f"Failed processing link {url}. " + str(ex))
return url
@classmethod
def _export_entity_related_data(
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
):
if entity_cls == Task:
return [
*cls._get_task_output_artifacts(entity),
*cls._export_task_events(entity, base_filename, writer, hash_),
]
if entity_cls == Model:
entity.uri = cls._get_fixed_url(entity.uri)
return [entity.uri] if entity.uri else []
return []
@classmethod
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
if not task.execution.artifacts:
return []
for a in task.execution.artifacts:
if a.mode == ArtifactModes.output:
a.uri = cls._get_fixed_url(a.uri)
return [
a.uri
for a in task.execution.artifacts
if a.mode == ArtifactModes.output and a.uri
]
@classmethod
def _export_artifacts(
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
):
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
for path in unique_paths:
path = path.lstrip("/")
full_path = os.path.join(artifacts_path, path)
if os.path.isfile(full_path):
writer.write(full_path, path)
else:
print(f"Artifact {full_path} not found")
@staticmethod
def _get_base_filename(cls_: type):
return f"{cls_.__module__}.{cls_.__name__}"
@classmethod
def _export(
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
):
entities = cls._resolve_entities(experiments, projects)
for cls_, items in entities.items():
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
Always do the export on sorted items since the order of items influence hash
"""
artifacts = []
now = datetime.utcnow()
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id"))
if not items:
continue
filename = f"{cls_.__module__}.{cls_.__name__}.json"
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with writer.open(filename, "w") as f:
f.write("[\n".encode("utf-8"))
last = len(items) - 1
for i, item in enumerate(items):
cls._cleanup_entity(cls_, item)
f.write(item.to_json().encode("utf-8"))
if i != last:
f.write(",".encode("utf-8"))
f.write("\n".encode("utf-8"))
f.write("]\n".encode("utf-8"))
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
if tag_entities:
cls._add_tag(items, now.strftime(cls.export_tag))
return artifacts
@staticmethod
def _import(reader: ZipFile, user_id: str = None):
for file_info in reader.filelist:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
def json_lines(file: BinaryIO):
for line in file:
clean = (
line.decode("utf-8")
.rstrip("\r\n")
.strip()
.lstrip("[")
.rstrip(",]")
.strip()
)
if not clean:
continue
yield clean
with reader.open(file_info) as f:
for item in tqdm(
f.readlines(),
desc=f"Writing {cls_.__name__.lower()}s into database",
unit="doc",
):
item = (
item.decode("utf-8")
.strip()
.lstrip("[")
.rstrip("]")
.rstrip(",")
.strip()
)
if not item:
continue
doc = cls_.from_json(item)
if user_id is not None and hasattr(doc, "user"):
doc.user = user_id
doc.save(force_insert=True)
@classmethod
def _import(
cls,
reader: ZipFile,
company_id: str = "",
user_id: str = None,
metadata: Mapping[str, Any] = None,
):
"""
Import entities and events from the zip file
Start from entities since event import will require the tasks already in DB
"""
event_file_ending = cls.events_file_suffix + ".json"
entity_files = (
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
event_files = (
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
)
for files, reader_func in (
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
(event_files, cls._import_events),
):
for file_info in files:
with reader.open(file_info) as f:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
reader_func(f, full_name, company_id, user_id)
@classmethod
def _import_entity(
cls,
f: BinaryIO,
full_name: str,
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
):
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database")
override_project_count = 0
for item in cls.json_lines(f):
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, Project):
override_project_name = metadata.get("project_name", None)
if override_project_name:
if override_project_count:
override_project_name = (
f"{override_project_name} {override_project_count + 1}"
)
override_project_count += 1
doc.name = override_project_name
doc.logo_url = metadata.get("logo_url", None)
doc.logo_blob = metadata.get("logo_blob", None)
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
)
doc.save()
if isinstance(doc, Task):
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
@classmethod
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked_tasks=True
)

View File

@@ -9,28 +9,34 @@ from database.model.user import User
from service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger):
ensure_credentials = {"key", "secret"}.issubset(user_data)
if ensure_credentials:
user = AuthUser.objects(
credentials__match=Credentials(
key=user_data["key"], secret=user_data["secret"]
)
).first()
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
key, secret = user_data.get("key"), user_data.get("secret")
if not (key and secret):
credentials = None
else:
creds = Credentials(key=key, secret=secret)
user = AuthUser.objects(credentials__match=creds).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
credentials = [] if revoke else [creds]
user_id = user_data.get("id", f"__{user_data['name']}__")
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_data.get("id", f"__{user_data['name']}__"),
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
if ensure_credentials
else None,
credentials=credentials,
)
user.save()
@@ -52,23 +58,15 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
if User.objects(id=user.user_id).first():
def ensure_fixed_user(user: FixedUser, log: Logger):
if User.objects(company=user.company, id=user.user_id).first():
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
data["role"] = Role.guest if user.is_guest else Role.user
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
given_name, _, family_name = user.name.partition(" ")
User(
id=user.user_id,
company=company_id,
name=user.name,
given_name=given_name,
family_name=family_name,
).save()
return _ensure_backend_user(user.user_id, user.company, user.name)

View File

@@ -3,21 +3,18 @@ from uuid import uuid4
from bll.queue import QueueBLL
from config import config
from config.info import get_default_company
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
from database.model.settings import Settings, SettingKeys
log = config.logger(__file__)
def _ensure_company(log: Logger):
company_id = get_default_company()
def _ensure_company(company_id, company_name, log: Logger):
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
@@ -37,4 +34,4 @@ def _ensure_default_queue(company):
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))

View File

@@ -0,0 +1,58 @@
from collections import Collection
from typing import Sequence
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
for collection_name in db.list_collection_names():
if collection_name not in names:
continue
collection: Collection = db[collection_name]
collection.drop_indexes()
def migrate_auth(db: Database):
"""
Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
_drop_all_indices_from_collections(db, ["user"])
def migrate_backend(db: Database):
"""
1. Sort tags and system tags
2. Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
fields = ("tags", "system_tags")
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
for collection_name in ("task", "model", "project", "queue"):
collection = db[collection_name]
for doc in collection.find(filter=query, projection=fields):
update = {
field: sorted(doc[field])
for field in fields
if doc.get(field)
}
if update:
collection.update_one({"_id": doc["_id"]}, {"$set": update})
_drop_all_indices_from_collections(
db,
[
"company",
"model",
"project",
"queue",
"settings",
"task",
"task__trash",
"user",
"versions",
],
)

View File

@@ -0,0 +1,36 @@
from pymongo.database import Database, Collection
from bll.task.param_utils import (
hyperparams_legacy_type,
hyperparams_default_section,
split_param_name,
)
from tools import safe_get
def migrate_backend(db: Database):
hyperparam_fields = ("execution.parameters", "hyperparams")
configuration_fields = ("execution.model_desc", "configuration")
collection: Collection = db["task"]
for doc in collection.find(projection=hyperparam_fields + configuration_fields):
set_commands = {}
for (old_field, new_field), default_section in zip(
(hyperparam_fields, configuration_fields),
(hyperparams_default_section, None),
):
legacy = safe_get(doc, old_field, separator=".")
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_field, section, name)))
# if safe_get(doc, new_path) is not None:
# continue
new_value = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_value["section"] = section
set_commands[".".join(new_path)] = new_value
if set_commands:
collection.update_one({"_id": doc["_id"]}, {"$set": set_commands})

View File

@@ -1,7 +1,8 @@
attrs>=19.1.0
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
elasticsearch>=5.0.0,<6.0.0
elasticsearch>=7.0.0,<8.0.0
fastjsonschema>=2.8
Flask-Compress>=1.4.0
Flask-Cors>=3.0.5
@@ -14,17 +15,17 @@ Jinja2==2.10
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.7.2
mongoengine==0.16.2
mongoengine==0.19.1
nested_dict>=1.61
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt>=1.3.0
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
pymongo==3.10.1
python-rapidjson>=0.6.3
redis>=2.10.5
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.0,<3
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4

View File

@@ -328,6 +328,11 @@ fixed_users_mode {
description: "Fixed users mode enabled"
type: boolean
}
server_errors {
description: "Server initialization errors"
type: object
additionalProperties: True
}
}
}
}

View File

@@ -0,0 +1,16 @@
_description: "debugging utilities"
ping {
authorize: false
"2.9" {
description: "Ping server"
request {
type: object
additionalProperties: true
}
response {
type: object
properties: {
}
}
}
}

View File

@@ -258,6 +258,7 @@
properties {
added { type: integer }
errors { type: integer }
errors_info { type: object }
}
}
}
@@ -362,7 +363,7 @@
}
navigate_earlier {
type: boolean
description: "If set then events are retreived from later iterations to earlier ones. Otherwise from earlier iterations to the later. The default is True"
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
}
refresh {
type: boolean
@@ -529,6 +530,56 @@
}
}
}
"2.9" {
description: "Get 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
batch_size {
type: integer
description: "The amount of log events to return"
}
navigate_earlier {
type: boolean
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order, unless order='asc'). Otherwise from the earliest to the latest ones (in timestamp ascending order, unless order='desc'). The default is True"
}
from_timestamp {
type: number
description: "Epoch time in UTC ms to use as the navigation start. Optional. If not provided, reference timestamp is determined by the 'navigate_earlier' parameter (if true, reference timestamp is the last timestamp and if false, reference timestamp is the first timestamp)"
}
order {
type: string
description: "If set, changes the order in which log events are returned based on the value of 'navigate_earlier'"
enum: [asc, desc]
}
}
}
response {
type: object
properties {
events {
type: array
items { type: object }
description: "Log items list"
}
returned {
type: integer
description: "Number of log events returned"
}
total {
type: number
description: "Total number of log events available for this query"
}
}
}
}
}
get_task_events {
"2.1" {
@@ -802,7 +853,7 @@
description: "Task ID"
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
type: integer
}
key {
@@ -840,7 +891,7 @@
]
properties {
tasks {
description: "List of task Task IDs"
description: "List of task Task IDs. Maximum amount of tasks is 10"
type: array
items {
type: string
@@ -848,7 +899,7 @@
}
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
description: "The amount of histogram points to return. Optional, the default value is 6000"
type: integer
}
key {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,48 @@
_description: "This service provides organization level operations"
get_tags {
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
tags {
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
}

View File

@@ -196,6 +196,52 @@ _definitions {
}
}
}
tags_request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
projects {
description: "The list of projects under which the tags are searched. If not passed or empty then all the projects are searched"
type: array
items { type: string }
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
tags {
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
tags_response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
create {
@@ -359,6 +405,11 @@ get_all_ex {
enum: [ active, archived ]
default: active
}
non_public {
description: "Return only non-public projects"
type: boolean
default: false
}
}
}
}
@@ -481,8 +532,8 @@ get_unique_metric_variants {
}
}
get_hyper_parameters {
"2.2" {
description: """Get a list of all hyper parameter names used in tasks within the given project."""
"2.9" {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
properties {
@@ -506,9 +557,9 @@ get_hyper_parameters {
type: object
properties {
parameters {
description: "A list of hyper parameter names"
description: "A list of parameter sections and names"
type: array
items { type: string }
items {type: object}
}
remaining {
description: "Remaining results"
@@ -522,3 +573,69 @@ get_hyper_parameters {
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
make_public {
"2.9" {
description: """Convert company projects to public"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}
make_private {
"2.9" {
description: """Convert public projects to private"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}

View File

@@ -69,6 +69,17 @@ info {
}
}
}
"2.8": ${info."2.1"} {
response {
type: object
properties {
uid {
description: "Server UID"
type: string
}
}
}
}
}
endpoints {
"2.1" {

View File

@@ -254,6 +254,15 @@ _definitions {
enum: [
training
testing
inference
data_processing
application
monitor
controller
optimizer
service
qc
custom
]
}
last_metrics_event {
@@ -288,7 +297,80 @@ _definitions {
"$ref": "#/definitions/last_metrics_event"
}
}
params_item {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. The combination of section and name should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
configuration_item {
type: object
properties {
name {
description: "Name of the parameter. Should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
param_key {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. If the name is ommitted then the corresponding operation is performed on the whole section"
type: string
}
}
}
section_params {
description: "Task section params"
type: object
additionalProperties {
"$ref": "#/definitions/params_item"
}
}
replace_hyperparams_enum {
type: string
enum: [
none,
section,
all
]
}
task {
type: object
properties {
@@ -409,9 +491,24 @@ _definitions {
"$ref": "#/definitions/last_metrics_variants"
}
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
}
}
}
get_by_id {
"2.1" {
description: "Gets task information"
@@ -475,7 +572,11 @@ get_all {
minimum: 1
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
description: """List of field names to order by. When search_text is used,
'@text_score' can be used as a field representing the text score of returned documents.
Use '-' prefix to specify descending order. Optional, recommended when using page.
If the first order field is a hyper parameter or metric then string values are ordered
according to numeric ordering rules where applicable"""
type: array
items { type: string }
}
@@ -550,6 +651,31 @@ get_all {
}
}
}
get_types {
"2.8" {
description: "Get the list of task types used in the specified projects"
request {
type: object
properties {
projects {
description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed"
type: array
items: {type: string}
}
}
}
response {
type: object
properties {
types {
description: "Unique list of the task types used in the requested projects"
type: array
items: {type: string}
}
}
}
}
}
clone {
"2.5" {
description: "Clone an existing task"
@@ -587,10 +713,28 @@ clone {
description: "The project of the cloned task. If not provided then taken from the original task"
type: string
}
new_task_hyperparams {
description: "The hyper params for the new task. If not provided then taken from the original task"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
new_task_configuration {
description: "The configuration for the new task. If not provided then taken from the original task"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
execution_overrides {
description: "The execution params for the cloned task. The params not specified are taken from the original task"
"$ref": "#/definitions/execution"
}
validate_references {
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
type: boolean
}
}
}
response {
@@ -656,6 +800,20 @@ create {
description: "Script info"
"$ref": "#/definitions/script"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
}
}
response {
@@ -717,6 +875,20 @@ validate {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
@@ -867,6 +1039,20 @@ edit {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
@@ -901,6 +1087,11 @@ reset {
properties.force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'"
}
properties.clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
} ${_references.status_change_request}
response {
type: object
@@ -1394,4 +1585,263 @@ add_or_update_artifacts {
}
}
}
}
make_public {
"2.9" {
description: """Convert company tasks to public"""
request {
type: object
properties {
ids {
description: "Ids of the tasks to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated"
type: integer
}
}
}
}
}
make_private {
"2.9" {
description: """Convert public tasks to private"""
request {
type: object
properties {
ids {
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated"
type: integer
}
}
}
}
}
get_hyper_params {
"2.9": {
description: "Get the list of task hyper parameters"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
params {
type: object
description: "Hyper parameters (keyed by task ID)"
}
}
}
}
}
edit_hyper_params {
"2.9" {
description: "Add or update task hyper parameters"
request {
type: object
required: [ task, hyperparams ]
properties {
task {
description: "Task ID"
type: string
}
hyperparams {
description: "Task hyper parameters. The new ones will be added and the already existing ones will be updated"
type: array
items {"$ref": "#/definitions/params_item"}
}
replace_hyperparams {
description: """Can be set to one of the following:
'all' - all the hyper parameters will be replaced with the provided ones
'section' - the sections that present in the new parameters will be replaced with the provided parameters
'none' (the default value) - only the specific parameters will be updated or added"""
"$ref": "#/definitions/replace_hyperparams_enum"
}
}
}
response {
type: object
properties {
updated {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_hyper_params {
"2.9": {
description: "Delete task hyper parameters"
request {
type: object
required: [ task, hyperparams ]
properties {
task {
description: "Task ID"
type: string
}
hyperparams {
description: "List of hyper parameters to delete. In case a parameter with an empty name is passed all the section will be deleted"
type: array
items { "$ref": "#/definitions/param_key" }
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
get_configurations {
"2.9": {
description: "Get the list of task configurations"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
names {
description: "Names of the configuration items to retreive. If not passed or empty then all the configurations will be retreived."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
configurations {
type: object
description: "Configurations (keyed by task ID)"
}
}
}
}
}
get_configuration_names {
"2.9": {
description: "Get the list of task configuration items names"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
configurations {
type: object
description: "Names of task configuration items (keyed by task ID)"
}
}
}
}
}
edit_configuration {
"2.9" {
description: "Add or update task configuration"
request {
type: object
required: [ task, configuration ]
properties {
task {
description: "Task ID"
type: string
}
configuration {
description: "Task configuration items. The new ones will be added and the already existing ones will be updated"
type: array
items {"$ref": "#/definitions/configuration_item"}
}
replace_configuration {
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
type: boolean
}
}
}
response {
type: object
properties {
updated {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_configuration {
"2.9": {
description: "Delete task configuration items"
request {
type: object
required: [ task, configuration ]
properties {
task {
description: "Task ID"
type: string
}
configuration {
description: "List of configuration itemss to delete"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}

View File

@@ -145,6 +145,19 @@ get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
}
"2.8": ${get_all."2.1"} {
request {
type: object
properties {
active_in_projects {
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
type: array
items { type: string }
}
}
}
}
}
get_all {

View File

@@ -135,6 +135,10 @@
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
@@ -144,6 +148,11 @@
type: array
items { "$ref": "#/definitions/queue_entry" }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
@@ -151,11 +160,11 @@
type: object
properties {
id {
description: "Worker ID"
description: "ID"
type: string
}
name {
description: "Worker name"
description: "Name"
type: string
}
}
@@ -301,6 +310,11 @@
type: array
items { type: string }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
response {
@@ -363,6 +377,11 @@
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
tags {
description: "New user tags for the worker"
type: array
items: { type: string }
}
}
}
response {

View File

@@ -1,20 +1,28 @@
import atexit
from argparse import ArgumentParser
from hashlib import md5
from flask import Flask, request, Response
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data
from config import config, info
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
from mongo.initialize import (
init_mongo_data,
pre_populate_data,
check_mongo_empty,
get_last_server_version,
)
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from sync import distributed_lock
from timing_context import TimingContext
from updates import check_updates_thread
from utilities import json
@@ -33,8 +41,42 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
database.initialize()
init_es_data()
init_mongo_data()
# build a key that uniquely identifies specific mongo instance
hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
upgrade_monitoring = config.get(
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
)
try:
empty_es = check_elastic_empty()
except ElasticConnectionError as err:
if not upgrade_monitoring:
raise
log.error(err)
info.es_connection_error = True
empty_db = check_mongo_empty()
if upgrade_monitoring:
if not empty_db and (info.es_connection_error or empty_es):
if get_last_server_version() < Version("0.16.0"):
log.info(f"ES database seems not migrated")
info.missed_es_upgrade = True
proceed_with_init = not (info.es_connection_error or info.missed_es_upgrade)
else:
proceed_with_init = True
if proceed_with_init:
init_es_data()
init_mongo_data()
if (
proceed_with_init
and empty_db
and config.get("apiserver.pre_populate.enabled", False)
):
pre_populate_data()
ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")

View File

@@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
from .service_repo import ServiceRepo
__all__ = ["endpoint"]
__all__ = ["APICall", "endpoint"]
LegacyEndpointFunc = Callable[[APICall], None]

Some files were not shown because too many files have changed in this diff Show More